
[Feat][Attention] Add GQA paged prefill FP8 KV cache op #1268

Open

superAngGao wants to merge 3 commits into tile-ai:main from superAngGao:feat/gqa-prefill-fp8-kv-cache-storage

Conversation

@superAngGao
Collaborator

Summary

Add a storage-only FP8 KV cache variant for paged GQA prefill:

  • Introduce GroupedQueryAttentionPrefillPagedWithFP8KVCacheFwdOp.
  • Keep q, k_new, and v_new in the attention dtype for the current chunk.
  • Store paged KV cache as float8_e4m3fn.
  • Dequantize old cache pages online with scalar k_scale / v_scale tensors of shape (1,).
  • Quantize the current chunk into FP8 E4M3 when appending to the paged cache (see the numerics sketch below).
  • Add a spec-driven manifest entry and targeted tests.

This is the storage-only path. Fused RoPE + FP8 append and FP8 tensor-core attention compute are intentionally left as follow-up kernel variants.
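For context, a minimal PyTorch-level sketch of the intended storage numerics, assuming scalar float32 scales; the function and tensor names are illustrative only and are not the tilelang kernel code:

import torch

# Hypothetical reference for the storage-only numerics; names are illustrative,
# not the kernel's. k_scale / v_scale are scalar tensors of shape (1,).
def quantize_for_cache(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # current chunk (fp16/bf16) -> fp32 -> divide by scale -> FP8 E4M3 storage
    return (x.float() / scale).to(torch.float8_e4m3fn)

def dequantize_from_cache(x_fp8: torch.Tensor, scale: torch.Tensor,
                          attn_dtype: torch.dtype) -> torch.Tensor:
    # old cache page: fp8 -> fp32 -> multiply by scale -> attention dtype
    return (x_fp8.float() * scale).to(attn_dtype)

# round trip for one K tile
k_tile = torch.randn(64, 128, dtype=torch.float16, device="cuda")
k_scale = torch.tensor([0.05], device="cuda")
k_restored = dequantize_from_cache(quantize_for_cache(k_tile, k_scale), k_scale, torch.float16)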

Tests

Ran with the nightly test image tileops-runner:nightly-tl019-fullstack-no-tileops-ldfix on H200:

python -m pytest -q tests/ops/attention/test_gqa_prefill_paged.py -k fp8 --tb=short -s
# 3 passed, 17 deselected

python -m pytest -q tests/ops/attention/test_gqa_prefill_paged.py --tb=short -s
# 20 passed

python scripts/validate_manifest.py --check-op GroupedQueryAttentionPrefillPagedWithFP8KVCacheFwdOp
# All manifest checks passed; the shape/dtype parity and bench integration warnings are expected.

Contributor

@gemini-code-assist (Bot) left a comment


Code Review

This pull request introduces a new operation and kernel for Grouped Query Attention (GQA) prefill with paged FP8 KV cache. The implementation includes the GQAPrefillPagedWithFP8KVCacheFwdKernel, the GroupedQueryAttentionPrefillPagedWithFP8KVCacheFwdOp operator, and associated unit tests. Review feedback highlights opportunities to reduce code duplication in the kernel's quantization and dequantization logic by utilizing T.macro for better maintainability.

Contributor

@Gabbering left a comment


goose review — 570a2242

honk. A clean structural copy-paste from the non-FP8 paged kernel with FP8 quant/dequant bolted on. Mostly correct. Mostly. One real bug, one perf concern, one test gap.

goose Bugs

  • tileops/kernels/attention/gqa_fwd.py:2199 — The page_size % block_n == 0 fast-path for reading old cache pages assumes the entire block_n tile sits inside a single logical page. That's only true when tile_start is page-aligned or page_offset + block_n <= page_size. The code only checks page_size % block_n == 0, then computes physical_start from tile_start's page. If tile_start happens to be mid-page and block_n == page_size, it reads past the physical page boundary into an adjacent (wrong) physical page. The non-FP8 sibling kernel at line ~1858 has the exact same latent bug, so this was faithfully copied. However: for the default config (block_n=64, page_size=64) the condition page_size % block_n == 0 is true and tile_end <= old_len guarantees tile_start is a multiple of block_n (== page_size), so it happens to be safe for page_size == block_n. It breaks if someone autotunes to block_n=32 with page_size=64 and the old_len isn't block_n-aligned — tile_start can be 32, page_idx would be 0, page_offset 32, and you'd read indices 32..63 from that physical page, which is correct. Actually wait — tile_start = k_idx * block_n, so it's always a multiple of block_n, and since page_size % block_n == 0, page_offset + block_n <= page_size always holds. The goose retracts its fangs on this one. Carry on.

  • tileops/kernels/attention/gqa_fwd.py:2159 (FP8 append, page_size % block_n fast path) — There is no synchronisation barrier between the FP8 cache append (guarded by by < heads_kv) and the attention read loop that reads from k_pages/v_pages. All head CTAs in the same (bx, bz) group share the same k_pages/v_pages global memory. CTAs with by >= heads_kv skip the append and fall straight into the attention loop, where the boundary tile (tile_start < old_len < tile_end or tile_start >= old_len) reads k_new/v_new directly — so those CTAs don't actually read the freshly-written FP8 data for the current chunk. The old-cache tiles (tile_end <= old_len) only read pages written by previous calls, so there's no WAR hazard there either. The non-FP8 sibling has exactly the same structure and ships, so this is architecturally consistent and not a new bug introduced here. Fine.

    Actually — let me look harder. For the old cache dequant path the kernel reads k_pages[..., cur_kv_head, ...] which is data written by a previous kernel launch. No hazard. For the new chunk path (tile_start >= old_len) it reads from k_new/v_new (read-only inputs). For the boundary path it conditionally reads either old cache or k_new/v_new. So the freshly-appended FP8 data in k_pages is never consumed within the same launch. The goose is satisfied. No sync bug.

    However, there IS a subtle data-race in the append itself: multiple CTAs with the same (bz, by) but different bx values all perform the append for their own block_m tile of new tokens, which is fine — they write disjoint positions. But within a single CTA, if by < heads_kv and by != cur_kv_head (which is by // groups), the append writes to head by while the attention reads from head cur_kv_head. These are different head indices when groups > 1, so different memory locations. No race. Good.

    The goose found no actual bugs in the kernel logic. Disappointing but correct.

feather Performance

  • tileops/kernels/attention/gqa_fwd.py:2148-2210 — Every element in the old-cache dequant path does two casts and a multiply per element, per KV: fp8 → float32 → multiply → cast to dtype. For the page_size % block_n == 0 fast-path, this is 2 * block_n * dim casts + block_n * dim multiplies, repeated for both K and V. The scale values k_scale[0] and v_scale[0] are loaded from global memory inside the innermost parallel loop — relying on the compiler to hoist them to registers. If tilelang doesn't hoist scalar loads out of T.Parallel (it usually does, but worth verifying), every thread re-reads from global. Consider loading k_scale[0] and v_scale[0] into a local variable before the pipelined loop. This is a nit on a brand-new kernel path, but for long-prefix scenarios (e.g. the 32k-prefix workload in the manifest) the dequant overhead on old pages dominates, so it matters.
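A rough sketch of the suggested hoist in tilelang-style syntax (buffer names, loop bounds, and dtype variables here are illustrative, not the kernel's actual code):

# Illustrative sketch only: load the scalar scales into registers once per CTA
# instead of re-reading k_scale[0] / v_scale[0] inside the per-element loop.
k_scale_local = T.alloc_local([1], "float32")
v_scale_local = T.alloc_local([1], "float32")
k_scale_local[0] = k_scale[0]
v_scale_local[0] = v_scale[0]

for ko in T.Pipelined(T.ceildiv(old_len, block_n), num_stages=num_stages):
    # ... page index / offset computation elided ...
    for i, j in T.Parallel(block_n, dim):
        # fp8 -> fp32 -> * scale -> attention dtype, with the scale read from a register
        K_shared[i, j] = (k_pages[page_idx, page_off + i, cur_kv_head, j]
                          .astype("float32") * k_scale_local[0]).astype(dtype)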

egg Test gaps

  • tests/ops/attention/test_gqa_prefill_paged.py:213-312 — The FP8 test only exercises dtype=float16. The manifest declares q: {dtype: "float16 | bfloat16"}, and the Op's _validate_attention_dtype accepts bfloat16, but there is zero test coverage for bfloat16 + float8_e4m3fn cache. The existing non-FP8 tests parametrize over both dtypes. Add at least one bfloat16 case — FP8 dequant → bf16 cast has different rounding behaviour than → fp16, and the tolerance may need adjusting.

  • tests/ops/attention/test_gqa_prefill_paged.py:213-312 — All three FP8 test cases use page_size=64, block_n=64 (default config). The non-FP8 tests have a dedicated test_gqa_prefill_paged_with_kv_cache_page_sizes that exercises page_size ∈ {16, 32, 128} to hit the block_n % page_size == 0 and generic fallback code paths in the kernel. The FP8 kernel has identical branching for those paths (with added dequant), but none of them are tested. A page_size=16 or page_size=128 case would exercise the block_n % page_size and per-element fallback dequant paths respectively.
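A hypothetical shape for the missing coverage; the test name, parametrization, and body below are placeholders rather than the existing test file's API:

import pytest
import torch

# Placeholder sketch of the suggested coverage; the body would build paged FP8
# inputs and compare against the non-paged reference, as the existing tests do.
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("page_size", [16, 64, 128])
def test_gqa_prefill_paged_fp8_kv_cache_variants(dtype, page_size):
    # fp8 -> bf16 rounds differently than fp8 -> fp16, so tolerances may need
    # to be looser for the bfloat16 cases
    ...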

@superAngGao
Collaborator Author

Updated in 79dd506.

Changes:

  • Extracted FP8 append quantization into quantize_fp8 T.macro.
  • Extracted FP8 old-cache dequantization into dequantize_fp8 T.macro (both macros sketched below).
  • Expanded FP8 tests to cover bf16 attention dtype and page_size 16/128 in addition to the existing page_size 64 paths.
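For readers who have not opened 79dd506, an extraction along these lines might look roughly as follows; the macro signatures and captured names are a guess, not the committed code:

# Rough sketch of the described T.macro extraction; argument lists and captured
# variables (block_n, dim, dtype, cache_dtype) may not match the actual commit.
@T.macro
def dequantize_fp8(src_pages, dst_shared, scale, page_idx, page_off, kv_head):
    for i, j in T.Parallel(block_n, dim):
        # fp8 cache -> fp32 -> * scale -> attention dtype
        dst_shared[i, j] = (src_pages[page_idx, page_off + i, kv_head, j]
                            .astype("float32") * scale[0]).astype(dtype)

@T.macro
def quantize_fp8(src_new, dst_pages, scale, page_idx, page_off, kv_head):
    for i, j in T.Parallel(block_n, dim):
        # attention dtype -> fp32 -> / scale -> FP8 E4M3 storage dtype
        dst_pages[page_idx, page_off + i, kv_head, j] = (
            src_new[i, j].astype("float32") / scale[0]).astype(cache_dtype)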

Validation:

  • python3 -m ruff check tileops/kernels/attention/gqa_fwd.py tests/ops/attention/test_gqa_prefill_paged.py
  • PYTHONPATH=. python3 scripts/validate_manifest.py --check-op GroupedQueryAttentionPrefillPagedWithFP8KVCacheFwdOp
  • python3 -m pytest -q tests/ops/attention/test_gqa_prefill_paged.py -k fp8 --tb=short -s -> 6 passed
  • python3 -m pytest -q tests/ops/attention/test_gqa_prefill_paged.py --tb=short -s -> 23 passed

Contributor

@Gabbering left a comment


goose skimmed 79dd506 — nothing to honk about.

@superAngGao changed the title from "Add GQA paged prefill FP8 KV cache op" to "[Feat][Attention] Add GQA paged prefill FP8 KV cache op" on May 8, 2026
@superAngGao marked this pull request as ready for review on May 8, 2026 06:03
@superAngGao requested a review from a team on May 8, 2026 06:03
Contributor

@Ibuki-wind left a comment


Overall

Two blocking issues remain in the changed API surface; address the inline comments before the next round.

        if tensor.dtype != self.cache_dtype:
            raise ValueError(
                f"Expected {name}.dtype {self.cache_dtype}, got {tensor.dtype}")
    for name, tensor in [("k_scale", k_scale), ("v_scale", v_scale)]:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GroupedQueryAttentionPrefillPagedWithFP8KVCacheFwdOp._validate_forward_inputs only checks the shape and dtype of k_scale/v_scale, but the kernel divides by those values in quantize_fp8, so a zero or non-finite scale silently produces a saturated/NaN FP8 cache and output. Please reject non-finite or <= 0 k_scale/v_scale values before launching the kernel, and add a focused validation test.
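A minimal sketch of the suggested guard, reusing the loop from the quoted snippet; the helper name and exact placement inside _validate_forward_inputs are up to the author:

import math
import torch

def _validate_scales(k_scale: torch.Tensor, v_scale: torch.Tensor) -> None:
    # Sketch only: reject degenerate scales before the kernel divides by them.
    for name, tensor in [("k_scale", k_scale), ("v_scale", v_scale)]:
        value = float(tensor.item())
        if not math.isfinite(value) or value <= 0:
            raise ValueError(
                f"Expected {name} to be a finite positive scalar, got {value}")

A focused test can then assert that a zero or NaN scale raises before any kernel launch.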

@tile-ai deleted a comment from Ibuki-wind on May 8, 2026
