diff --git a/tests/ops/attention/test_gqa_prefill_paged.py b/tests/ops/attention/test_gqa_prefill_paged.py index d7fd5e0a..0b80cf6a 100644 --- a/tests/ops/attention/test_gqa_prefill_paged.py +++ b/tests/ops/attention/test_gqa_prefill_paged.py @@ -5,7 +5,11 @@ import pytest import torch -from tileops.ops import GroupedQueryAttentionPrefillPagedWithKVCacheFwdOp, RopeNeoxPositionIdsOp +from tileops.ops import ( + GroupedQueryAttentionPrefillPagedWithFP8KVCacheFwdOp, + GroupedQueryAttentionPrefillPagedWithKVCacheFwdOp, + RopeNeoxPositionIdsOp, +) _PREFILL_PAGED_TOLERANCE = { torch.float16: (5e-3, 1e-5), @@ -206,6 +210,107 @@ def test_gqa_prefill_paged_with_kv_cache_fwd( torch.testing.assert_close(v_pages[physical_pos], v_pages_before[physical_pos]) +@pytest.mark.smoke +@pytest.mark.parametrize("is_causal, softcap, dtype, page_size", [ + pytest.param(True, None, torch.float16, 64, id="causal-fp16-page64"), + pytest.param(False, None, torch.float16, 64, id="noncausal-fp16-page64"), + pytest.param(True, 2.0, torch.float16, 64, id="causal-softcap-fp16-page64"), + pytest.param(True, None, torch.bfloat16, 64, id="causal-bf16-page64"), + pytest.param(True, None, torch.float16, 16, id="causal-fp16-page16"), + pytest.param(True, None, torch.float16, 128, id="causal-fp16-page128"), +]) +def test_gqa_prefill_paged_with_fp8_kv_cache_fwd( + is_causal: bool, + softcap: float | None, + dtype: torch.dtype, + page_size: int, +) -> None: + q_lens = [33, 48] + old_lens = [67, 80] + batch, heads, heads_kv, dim = 2, 8, 2, 64 + cache_dtype = torch.float8_e4m3fn + max_pages_per_req = 8 + num_pages = batch * max_pages_per_req + total_q = sum(q_lens) + block_table = _make_block_table(batch, max_pages_per_req) + cu_seqlens_q = _make_cu_seqlens(q_lens) + cache_seqlens = torch.tensor(old_lens, device="cuda", dtype=torch.int32) + k_scale = torch.tensor([0.02], device="cuda", dtype=torch.float32) + v_scale = torch.tensor([0.02], device="cuda", dtype=torch.float32) + + q = torch.randn(total_q, heads, dim, device="cuda", dtype=dtype).contiguous() + k_new = (torch.randn(total_q, heads_kv, dim, device="cuda", dtype=dtype) * + 0.5).contiguous() + v_new = (torch.randn(total_q, heads_kv, dim, device="cuda", dtype=dtype) * + 0.5).contiguous() + k_pages = torch.zeros(num_pages * page_size, heads_kv, dim, device="cuda", + dtype=cache_dtype).contiguous() + v_pages = torch.zeros_like(k_pages) + k_old = [ + (torch.randn(old_len, heads_kv, dim, device="cuda", dtype=dtype) * 0.5).contiguous() + for old_len in old_lens + ] + v_old = [ + (torch.randn(old_len, heads_kv, dim, device="cuda", dtype=dtype) * 0.5).contiguous() + for old_len in old_lens + ] + k_old_quant = [(k_b.float() / k_scale[0]).to(cache_dtype).contiguous() for k_b in k_old] + v_old_quant = [(v_b.float() / v_scale[0]).to(cache_dtype).contiguous() for v_b in v_old] + _fill_paged_cache_from_logical( + k_pages, v_pages, k_old_quant, v_old_quant, block_table, page_size) + k_pages_before = k_pages.clone() + v_pages_before = v_pages.clone() + k_old_dequant = [(k_b.float() * k_scale[0]).to(dtype).contiguous() for k_b in k_old_quant] + v_old_dequant = [(v_b.float() * v_scale[0]).to(dtype).contiguous() for v_b in v_old_quant] + ref = _gqa_prefill_paged_ref( + q, + k_new, + v_new, + k_old_dequant, + v_old_dequant, + cu_seqlens_q, + batch=batch, + heads=heads, + heads_kv=heads_kv, + is_causal=is_causal, + softcap=softcap, + ) + op = GroupedQueryAttentionPrefillPagedWithFP8KVCacheFwdOp( + batch=batch, + heads=heads, + heads_kv=heads_kv, + max_pages_per_req=max_pages_per_req, + 
page_size=page_size, + dim=dim, + is_causal=is_causal, + dtype=dtype, + softcap=softcap, + ) + + output = op( + q, k_new, v_new, k_pages, v_pages, k_scale, v_scale, cu_seqlens_q, cache_seqlens, + block_table, max(q_lens)) + assert isinstance(output, torch.Tensor) + torch.testing.assert_close(output, ref, atol=8e-2, rtol=2e-2) + + for b, (q_len, old_len) in enumerate(zip(q_lens, old_lens, strict=True)): + q_start = int(cu_seqlens_q[b].item()) + for i in range(q_len): + physical_pos = _physical_pos(block_table, b, old_len + i, page_size) + expected_k = (k_new[q_start + i].float() / k_scale[0]).to(cache_dtype).float() + expected_v = (v_new[q_start + i].float() / v_scale[0]).to(cache_dtype).float() + torch.testing.assert_close(k_pages[physical_pos].float(), expected_k, atol=0, rtol=0) + torch.testing.assert_close(v_pages[physical_pos].float(), expected_v, atol=0, rtol=0) + + for b, old_len in enumerate(old_lens): + for pos in range(old_len): + physical_pos = _physical_pos(block_table, b, pos, page_size) + torch.testing.assert_close( + k_pages[physical_pos].float(), k_pages_before[physical_pos].float()) + torch.testing.assert_close( + v_pages[physical_pos].float(), v_pages_before[physical_pos].float()) + + @pytest.mark.smoke @pytest.mark.parametrize("rotary_dim, is_causal, softcap", [ pytest.param(None, True, None, id="full-causal"), diff --git a/tileops/kernels/__init__.py b/tileops/kernels/__init__.py index df01a9a2..ffe84f37 100644 --- a/tileops/kernels/__init__.py +++ b/tileops/kernels/__init__.py @@ -8,6 +8,7 @@ GQAFwdKernel, GQAFwdWgmmaPipelinedKernel, GQAPrefillFwdKernel, + GQAPrefillPagedWithFP8KVCacheFwdKernel, GQAPrefillPagedWithKVCacheFwdKernel, GQAPrefillPagedWithKVCacheRopeAppendKernel, GQAPrefillPagedWithKVCacheRopeFwdKernel, @@ -106,6 +107,7 @@ "GQAFwdKernel", "GQAFwdWgmmaPipelinedKernel", "GQAPrefillFwdKernel", + "GQAPrefillPagedWithFP8KVCacheFwdKernel", "GQAPrefillPagedWithKVCacheFwdKernel", "GQAPrefillPagedWithKVCacheRopeAppendKernel", "GQAPrefillPagedWithKVCacheRopeFwdKernel", diff --git a/tileops/kernels/attention/__init__.py b/tileops/kernels/attention/__init__.py index a7c858d5..d5e95d8c 100644 --- a/tileops/kernels/attention/__init__.py +++ b/tileops/kernels/attention/__init__.py @@ -18,6 +18,7 @@ GQAFwdKernel, GQAFwdWgmmaPipelinedKernel, GQAPrefillFwdKernel, + GQAPrefillPagedWithFP8KVCacheFwdKernel, GQAPrefillPagedWithKVCacheFwdKernel, GQAPrefillPagedWithKVCacheRopeAppendKernel, GQAPrefillPagedWithKVCacheRopeFwdKernel, @@ -52,6 +53,7 @@ "GQAFwdWsPersistentCausalKernel", "GQAFwdWsPersistentKernel", "GQAPrefillFwdKernel", + "GQAPrefillPagedWithFP8KVCacheFwdKernel", "GQAPrefillPagedWithKVCacheFwdKernel", "GQAPrefillPagedWithKVCacheRopeAppendKernel", "GQAPrefillPagedWithKVCacheRopeFwdKernel", diff --git a/tileops/kernels/attention/gqa_fwd.py b/tileops/kernels/attention/gqa_fwd.py index 0a60b145..eef41333 100644 --- a/tileops/kernels/attention/gqa_fwd.py +++ b/tileops/kernels/attention/gqa_fwd.py @@ -20,6 +20,7 @@ 'GQAFwdKernel', 'GQAFwdWgmmaPipelinedKernel', 'GQAPrefillFwdKernel', + 'GQAPrefillPagedWithFP8KVCacheFwdKernel', 'GQAPrefillPagedWithKVCacheFwdKernel', 'GQAPrefillPagedWithKVCacheRopeAppendKernel', 'GQAPrefillPagedWithKVCacheRopeFwdKernel', @@ -2109,6 +2110,378 @@ def forward(self, q: torch.Tensor, k_new: torch.Tensor, v_new: torch.Tensor, k_pages, v_pages, cu_seqlens_q, cache_seqlens, block_table) +@functools.lru_cache(maxsize=32) +def _gqa_prefill_paged_with_fp8_kv_cache_fwd_kernel(batch: int, + heads: int, + heads_kv: int, + total_q: int, + physical_tokens: int, 
+ max_pages_per_req: int, + page_size: int, + dim: int, + is_causal: bool, + sm_scale: Optional[float] = None, + softcap: float = 0.0, + dtype: str = 'float16') -> Callable: + score_scale = dim**-0.5 if sm_scale is None else sm_scale + use_softcap = softcap > 0.0 + scale = LOG2E if use_softcap else score_scale * LOG2E + if heads % heads_kv != 0: + raise ValueError("heads must be divisible by heads_kv") + if page_size <= 0 or page_size & (page_size - 1) != 0: + raise ValueError("page_size must be a positive power of two") + groups = heads // heads_kv + accum_dtype = "float" + cache_dtype = T.float8_e4m3fn + + @tilelang.jit( + out_idx=[10, 11], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + compile_flags=["-O3", "-DENABLE_BF16"]) + def _gqa_prefill_paged_with_fp8_kv_cache_fwd_func( + block_m: int, block_n: int, num_stages: int, threads: int) -> Callable: + + q_shape = (total_q, heads, dim) + kv_new_shape = (total_q, heads_kv, dim) + kv_pages_shape = (physical_tokens, heads_kv, dim) + block_table_shape = (batch, max_pages_per_req) + o_shape = (total_q, heads, dim) + online_softmax = make_online_softmax_with_mask_guard( + scale, accum_dtype, block_m, block_n) + apply_softcap = make_apply_softcap( + score_scale, softcap, accum_dtype, block_m, block_n) if use_softcap else None + rescale = make_rescale(block_m, dim) + page_size_log2 = page_size.bit_length() - 1 + fp8_min = -448.0 + fp8_max = 448.0 + + @T.macro + def quantize_fp8(value, scale_value): + return T.clamp(T.Cast("float32", value) / scale_value, fp8_min, fp8_max) + + @T.macro + def dequantize_fp8(value, scale_value): + return T.Cast(dtype, T.Cast("float32", value) * scale_value) + + @T.prim_func + def _gqa_prefill_paged_with_fp8_kv_cache_fwd_main( + q: T.Tensor(q_shape, dtype), # type: ignore + k_new: T.Tensor(kv_new_shape, dtype), # type: ignore + v_new: T.Tensor(kv_new_shape, dtype), # type: ignore + k_pages: T.Tensor(kv_pages_shape, cache_dtype), # type: ignore + v_pages: T.Tensor(kv_pages_shape, cache_dtype), # type: ignore + k_scale: T.Tensor([1], T.float32), # type: ignore + v_scale: T.Tensor([1], T.float32), # type: ignore + cu_seqlens_q: T.Tensor([batch + 1], T.int32), # type: ignore + cache_seqlens: T.Tensor([batch], T.int32), # type: ignore + block_table: T.Tensor(block_table_shape, T.int32), # type: ignore + output: T.Tensor(o_shape, dtype), # type: ignore + lse: T.Tensor([heads, total_q], accum_dtype), # type: ignore + max_seqlen_q: T.int32, # type: ignore + ) -> None: + with T.Kernel( + T.ceildiv(max_seqlen_q, block_m), heads, batch, threads=threads) as ( + bx, by, bz): + q_shared = T.alloc_shared([block_m, dim], dtype) + k_shared = T.alloc_shared([block_n, dim], dtype) + v_shared = T.alloc_shared([block_n, dim], dtype) + acc_s = T.alloc_fragment([block_m, block_n], accum_dtype) + acc_s_cast = T.alloc_fragment([block_m, block_n], dtype) + acc_o = T.alloc_fragment([block_m, dim], accum_dtype) + scores_max = T.alloc_fragment([block_m], accum_dtype) + scores_max_prev = T.alloc_fragment([block_m], accum_dtype) + scores_scale = T.alloc_fragment([block_m], accum_dtype) + scores_sum = T.alloc_fragment([block_m], accum_dtype) + logsum = T.alloc_fragment([block_m], accum_dtype) + + q_start = cu_seqlens_q[bz] + q_len = cu_seqlens_q[bz + 1] - q_start + old_len = cache_seqlens[bz] + total_len = old_len + q_len + cur_kv_head = by // groups + + if bx * block_m + block_m <= q_len: + T.copy( + 
q[q_start + bx * block_m:q_start + (bx + 1) * block_m, by, :], + q_shared, + disable_tma=True) + else: + for i, d in T.Parallel(block_m, dim): + new_pos = bx * block_m + i + if new_pos < q_len: + q_shared[i, d] = q[q_start + new_pos, by, d] + else: + q_shared[i, d] = T.cast(0, dtype) + + if by < heads_kv: + append_start = old_len + bx * block_m + append_end = append_start + block_m + if bx * block_m + block_m <= q_len: + if append_start >> T.int32(page_size_log2) == ( + append_end - 1) >> T.int32(page_size_log2): + page_idx = append_start >> T.int32(page_size_log2) + page_offset = append_start - page_idx * page_size + physical_start = block_table[bz, page_idx] * page_size + page_offset + for i, d in T.Parallel(block_m, dim): + k_pages[physical_start + i, by, d] = quantize_fp8( + k_new[q_start + bx * block_m + i, by, d], k_scale[0]) + v_pages[physical_start + i, by, d] = quantize_fp8( + v_new[q_start + bx * block_m + i, by, d], v_scale[0]) + else: + for i, d in T.Parallel(block_m, dim): + new_pos = bx * block_m + i + logical_pos = old_len + new_pos + page_idx = logical_pos >> T.int32(page_size_log2) + page_offset = logical_pos - page_idx * page_size + physical_pos = block_table[bz, page_idx] * page_size + page_offset + k_pages[physical_pos, by, d] = quantize_fp8( + k_new[q_start + new_pos, by, d], k_scale[0]) + v_pages[physical_pos, by, d] = quantize_fp8( + v_new[q_start + new_pos, by, d], v_scale[0]) + else: + for i, d in T.Parallel(block_m, dim): + new_pos = bx * block_m + i + safe_new_pos = T.if_then_else(new_pos < q_len, new_pos, 0) + logical_pos = old_len + safe_new_pos + page_idx = logical_pos >> T.int32(page_size_log2) + page_offset = logical_pos - page_idx * page_size + if new_pos < q_len: + physical_pos = block_table[bz, page_idx] * page_size + page_offset + k_pages[physical_pos, by, d] = quantize_fp8( + k_new[q_start + new_pos, by, d], k_scale[0]) + v_pages[physical_pos, by, d] = quantize_fp8( + v_new[q_start + new_pos, by, d], v_scale[0]) + + T.clear(acc_o) + T.clear(logsum) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = ( + T.ceildiv(old_len + (bx + 1) * block_m, block_n) + if is_causal else T.ceildiv(total_len, block_n)) + + for k_idx in T.Pipelined(loop_range, num_stages=num_stages): + tile_start = k_idx * block_n + tile_end = tile_start + block_n + if tile_end <= old_len: + if page_size % block_n == 0: + page_idx = tile_start >> T.int32(page_size_log2) + page_offset = tile_start - page_idx * page_size + physical_start = block_table[bz, page_idx] * page_size + page_offset + for j, d in T.Parallel(block_n, dim): + k_shared[j, d] = dequantize_fp8( + k_pages[physical_start + j, cur_kv_head, d], k_scale[0]) + v_shared[j, d] = dequantize_fp8( + v_pages[physical_start + j, cur_kv_head, d], v_scale[0]) + elif block_n % page_size == 0: + tile_page_start = tile_start >> T.int32(page_size_log2) + for p in range(block_n // page_size): + segment_physical_start = block_table[ + bz, tile_page_start + p] * page_size + for off, d in T.Parallel(page_size, dim): + shared_row = p * page_size + off + k_shared[shared_row, d] = dequantize_fp8( + k_pages[segment_physical_start + off, cur_kv_head, d], + k_scale[0]) + v_shared[shared_row, d] = dequantize_fp8( + v_pages[segment_physical_start + off, cur_kv_head, d], + v_scale[0]) + else: + for j, d in T.Parallel(block_n, dim): + kv_pos = tile_start + j + page_idx = kv_pos >> T.int32(page_size_log2) + page_offset = kv_pos - page_idx * page_size + physical_pos = block_table[bz, page_idx] * page_size + page_offset + k_shared[j, d] = 
dequantize_fp8( + k_pages[physical_pos, cur_kv_head, d], k_scale[0]) + v_shared[j, d] = dequantize_fp8( + v_pages[physical_pos, cur_kv_head, d], v_scale[0]) + elif tile_start >= old_len and tile_end <= total_len: + new_start = tile_start - old_len + for j, d in T.Parallel(block_n, dim): + k_shared[j, d] = k_new[q_start + new_start + j, cur_kv_head, d] + v_shared[j, d] = v_new[q_start + new_start + j, cur_kv_head, d] + else: + for j, d in T.Parallel(block_n, dim): + kv_pos = tile_start + j + new_pos = kv_pos - old_len + safe_kv_pos = T.if_then_else(kv_pos < old_len, kv_pos, 0) + page_idx = safe_kv_pos >> T.int32(page_size_log2) + page_offset = safe_kv_pos - page_idx * page_size + physical_pos = block_table[bz, page_idx] * page_size + page_offset + if kv_pos < old_len: + k_shared[j, d] = dequantize_fp8( + k_pages[physical_pos, cur_kv_head, d], k_scale[0]) + v_shared[j, d] = dequantize_fp8( + v_pages[physical_pos, cur_kv_head, d], v_scale[0]) + elif kv_pos < total_len: + k_shared[j, d] = k_new[q_start + new_pos, cur_kv_head, d] + v_shared[j, d] = v_new[q_start + new_pos, cur_kv_head, d] + else: + k_shared[j, d] = T.cast(0, dtype) + v_shared[j, d] = T.cast(0, dtype) + if is_causal: + for i, j in T.Parallel(block_m, block_n): + kv_pos = k_idx * block_n + j + q_abs_pos = old_len + bx * block_m + i + valid = (bx * block_m + i < q_len) & (kv_pos < total_len) & ( + kv_pos <= q_abs_pos) + acc_s[i, j] = T.if_then_else(valid, 0, -T.infinity(acc_s.dtype)) + else: + for i, j in T.Parallel(block_m, block_n): + kv_pos = k_idx * block_n + j + valid = (bx * block_m + i < q_len) & (kv_pos < total_len) + acc_s[i, j] = T.if_then_else(valid, 0, -T.infinity(acc_s.dtype)) + T.gemm( + q_shared, + k_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow) + if use_softcap: + apply_softcap(acc_s) + online_softmax(acc_s, scores_max, scores_max_prev, scores_scale, + scores_sum, logsum) + T.copy(acc_s, acc_s_cast) + rescale(acc_o, scores_scale) + T.gemm(acc_s_cast, v_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_m, dim): + if bx * block_m + i < q_len: + output[q_start + bx * block_m + i, by, j] = acc_o[i, j] / logsum[i] + for i in T.Parallel(block_m): + if bx * block_m + i < q_len: + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + lse[by, q_start + bx * block_m + i] = logsum[i] + + return _gqa_prefill_paged_with_fp8_kv_cache_fwd_main + + return _gqa_prefill_paged_with_fp8_kv_cache_fwd_func + + +@torch.library.custom_op( + "top::gqa_prefill_paged_with_fp8_kv_cache_fwd_wrapped_kernel", + mutates_args=("k_pages", "v_pages"), +) +def _gqa_prefill_paged_with_fp8_kv_cache_fwd_wrapped_kernel( + batch: int, + heads: int, + heads_kv: int, + total_q: int, + physical_tokens: int, + max_pages_per_req: int, + page_size: int, + dim: int, + is_causal: bool, + sm_scale: float, + softcap: float, + dtype: str, + block_m: int, + block_n: int, + num_stages: int, + threads: int, + max_seqlen_q: int, + q: torch.Tensor, + k_new: torch.Tensor, + v_new: torch.Tensor, + k_pages: torch.Tensor, + v_pages: torch.Tensor, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cache_seqlens: torch.Tensor, + block_table: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _gqa_prefill_paged_with_fp8_kv_cache_fwd_kernel( + batch, heads, heads_kv, total_q, physical_tokens, max_pages_per_req, page_size, dim, + is_causal, sm_scale, softcap, dtype)(block_m, block_n, num_stages, threads)( + q, k_new, v_new, k_pages, v_pages, k_scale, v_scale, cu_seqlens_q, 
cache_seqlens, + block_table, max_seqlen_q) + + +@_gqa_prefill_paged_with_fp8_kv_cache_fwd_wrapped_kernel.register_fake +def _(batch: int, heads: int, heads_kv: int, total_q: int, physical_tokens: int, + max_pages_per_req: int, page_size: int, dim: int, is_causal: bool, sm_scale: float, + softcap: float, dtype: str, block_m: int, block_n: int, num_stages: int, threads: int, + max_seqlen_q: int, + *inputs: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, torch.Tensor]: + fake_o = torch.empty_like(inputs[0]) + fake_lse = fake_o.new_empty([heads, total_q]) + return fake_o, fake_lse + + +class GQAPrefillPagedWithFP8KVCacheFwdKernel(Kernel): + supported_archs: list[int] = [89, 90] + + def __init__(self, + batch: int, + heads: int, + heads_kv: int, + max_pages_per_req: int, + page_size: int, + dim: int, + is_causal: bool, + dtype: torch.dtype, + sm_scale: Optional[float] = None, + softcap: float = 0.0, + config: Optional[dict] = None, + tune: bool = False) -> None: + super().__init__() + self.batch = batch + self.heads = heads + if heads % heads_kv != 0: + raise ValueError("heads must be divisible by heads_kv") + self.heads_kv = heads_kv + self.max_pages_per_req = max_pages_per_req + self.page_size = page_size + self.dim = dim + self.is_causal = is_causal + self.dtype = dtype + self.sm_scale = dim**-0.5 if sm_scale is None else sm_scale + self.softcap = softcap + + self.init_config(config, tune) + + @property + def default_config(self) -> dict: + return { + "block_m": 64, + "block_n": 64 if self.dim <= 128 else 32, + "num_stages": 1, + "threads": 128 + } + + @property + def autotune_configs(self) -> list[dict]: + block_m = [32, 64, 128] + block_n = [32, 64, 128] + num_stages = [1, 2, 3] + threads = [128, 256] + _configs = list(itertools.product(block_m, block_n, num_stages, threads)) + + return [{ + 'block_m': c[0], + 'block_n': c[1], + 'num_stages': c[2], + 'threads': c[3] + } for c in _configs] + + def forward(self, q: torch.Tensor, k_new: torch.Tensor, v_new: torch.Tensor, + k_pages: torch.Tensor, v_pages: torch.Tensor, k_scale: torch.Tensor, + v_scale: torch.Tensor, cu_seqlens_q: torch.Tensor, cache_seqlens: torch.Tensor, + block_table: torch.Tensor, max_seqlen_q: int) -> Tuple[torch.Tensor, + torch.Tensor]: + return _gqa_prefill_paged_with_fp8_kv_cache_fwd_wrapped_kernel( + self.batch, self.heads, self.heads_kv, q.shape[0], k_pages.shape[0], + self.max_pages_per_req, self.page_size, self.dim, self.is_causal, self.sm_scale, + self.softcap, self.dtype_str, self.config["block_m"], self.config["block_n"], + self.config["num_stages"], self.config["threads"], max_seqlen_q, q, k_new, v_new, + k_pages, v_pages, k_scale, v_scale, cu_seqlens_q, cache_seqlens, block_table) + + @functools.lru_cache(maxsize=32) def _gqa_prefill_paged_with_kv_cache_rope_append_kernel(batch: int, heads_kv: int, diff --git a/tileops/manifest/attention.yaml b/tileops/manifest/attention.yaml index 48592a27..e6fa715e 100644 --- a/tileops/manifest/attention.yaml +++ b/tileops/manifest/attention.yaml @@ -299,6 +299,70 @@ GroupedQueryAttentionPrefillPagedWithKVCacheFwdOp: test: tests/ops/attention/test_gqa_prefill_paged.py bench: benchmarks/ops/attention/bench_gqa.py +GroupedQueryAttentionPrefillPagedWithFP8KVCacheFwdOp: + ref_api: "none" + family: attention + status: implemented + + signature: + inputs: + q: {dtype: "float16 | bfloat16"} + k_new: {dtype: "same_as(q)"} + v_new: {dtype: "same_as(q)"} + k_pages: {dtype: "float8_e4m3fn"} + v_pages: {dtype: "float8_e4m3fn"} + k_scale: {dtype: "float32"} + v_scale: {dtype: "float32"} 
+ cu_seqlens_q: {dtype: "int32"} + cache_seqlens: {dtype: "int32"} + block_table: {dtype: "int32"} + outputs: + o: {dtype: "same_as(q)"} + params: + max_pages_per_req: {type: int} + page_size: {type: int} + max_seqlen_q: {type: int} + is_causal: {type: bool, default: true} + sm_scale: {type: float, default: null} + softcap: {type: float, default: null} + cache_dtype: {type: dtype, default: float8_e4m3fn} + shape_rules: + - "q.shape == (total_q, H, D)" + - "k_new.shape == (total_q, H_kv, D)" + - "v_new.shape == (total_q, H_kv, D)" + - "k_pages.shape == (physical_tokens, H_kv, D)" + - "v_pages.shape == (physical_tokens, H_kv, D)" + - "k_scale.shape == (1,)" + - "v_scale.shape == (1,)" + - "cu_seqlens_q.shape == (B + 1,)" + - "cache_seqlens.shape == (B,)" + - "block_table.shape == (B, max_pages_per_req)" + - "o.shape == (total_q, H, D)" + - "H % H_kv == 0" + - "page_size > 0" + - "max_pages_per_req > 0" + - "max_seqlen_q > 0" + - "physical_tokens % page_size == 0" + + workloads: + # Storage-only FP8: the current chunk participates in attention as same_as(q); + # old cache pages are dequantized online with scalar k_scale/v_scale. RoPE is + # external in this entry; fused RoPE + FP8 append is a follow-up kernel variant. + - {total_q: 8192, batch: 8, heads: 32, heads_kv: 8, dim: 256, page_size: 64, max_pages_per_req: 528, max_seqlen_q: 1024, physical_tokens: 270336, is_causal: true, dtypes: [float16], cache_dtype: float8_e4m3fn, label: "qwen35-9b-prefill-paged-fp8-cache-b8-prefix32k-chunk1k-p64-fp16"} + - {total_q: 4096, batch: 8, heads: 32, heads_kv: 8, dim: 128, page_size: 64, max_pages_per_req: 72, max_seqlen_q: 512, physical_tokens: 36864, is_causal: true, dtypes: [float16], cache_dtype: float8_e4m3fn, label: "llama31-8b-prefill-paged-fp8-cache-b8-prefix4k-chunk512-p64-fp16"} + - {total_q: 2048, batch: 4, heads: 8, heads_kv: 2, dim: 64, page_size: 64, max_pages_per_req: 72, max_seqlen_q: 512, physical_tokens: 18432, is_causal: true, softcap: 50.0, dtypes: [float16], cache_dtype: float8_e4m3fn, label: "gqa-prefill-paged-fp8-cache-softcap50-b4-prefix4k-chunk512-p64-fp16"} + + roofline: + func: "tileops.perf.formulas.gqa_prefill_with_kv_cache_fwd_roofline" + + source: + kernel: tileops/kernels/attention/gqa_fwd.py + kernel_map: + gqa_prefill_paged_with_fp8_kv_cache_fwd_kernel: GQAPrefillPagedWithFP8KVCacheFwdKernel + op: tileops/ops/attention/gqa.py + test: tests/ops/attention/test_gqa_prefill_paged.py + bench: benchmarks/ops/attention/bench_gqa.py + GroupedQueryAttentionBwdOp: ref_api: "torch.nn.functional.scaled_dot_product_attention" family: attention diff --git a/tileops/ops/__init__.py b/tileops/ops/__init__.py index f20eed52..d93e0260 100644 --- a/tileops/ops/__init__.py +++ b/tileops/ops/__init__.py @@ -5,6 +5,7 @@ GroupedQueryAttentionDecodeWithKVCacheFwdOp, GroupedQueryAttentionFwdOp, GroupedQueryAttentionPrefillFwdOp, + GroupedQueryAttentionPrefillPagedWithFP8KVCacheFwdOp, GroupedQueryAttentionPrefillPagedWithKVCacheFwdOp, GroupedQueryAttentionPrefillVarlenFwdOp, GroupedQueryAttentionPrefillWithKVCacheFwdOp, @@ -137,6 +138,7 @@ "GroupedQueryAttentionDecodeWithKVCacheFwdOp", "GroupedQueryAttentionFwdOp", "GroupedQueryAttentionPrefillFwdOp", + "GroupedQueryAttentionPrefillPagedWithFP8KVCacheFwdOp", "GroupedQueryAttentionPrefillPagedWithKVCacheFwdOp", "GroupedQueryAttentionPrefillVarlenFwdOp", "GroupedQueryAttentionPrefillWithKVCacheFwdOp", diff --git a/tileops/ops/attention/__init__.py b/tileops/ops/attention/__init__.py index 997be3f9..07dde469 100644 --- 
a/tileops/ops/attention/__init__.py +++ b/tileops/ops/attention/__init__.py @@ -12,6 +12,7 @@ GroupedQueryAttentionDecodeWithKVCacheFwdOp, GroupedQueryAttentionFwdOp, GroupedQueryAttentionPrefillFwdOp, + GroupedQueryAttentionPrefillPagedWithFP8KVCacheFwdOp, GroupedQueryAttentionPrefillPagedWithKVCacheFwdOp, GroupedQueryAttentionPrefillVarlenFwdOp, GroupedQueryAttentionPrefillWithKVCacheFwdOp, @@ -32,6 +33,7 @@ "GroupedQueryAttentionDecodeWithKVCacheFwdOp", "GroupedQueryAttentionFwdOp", "GroupedQueryAttentionPrefillFwdOp", + "GroupedQueryAttentionPrefillPagedWithFP8KVCacheFwdOp", "GroupedQueryAttentionPrefillPagedWithKVCacheFwdOp", "GroupedQueryAttentionPrefillVarlenFwdOp", "GroupedQueryAttentionPrefillWithKVCacheFwdOp", diff --git a/tileops/ops/attention/gqa.py b/tileops/ops/attention/gqa.py index 11d86865..e1241524 100644 --- a/tileops/ops/attention/gqa.py +++ b/tileops/ops/attention/gqa.py @@ -16,6 +16,7 @@ GQAFwdWsPersistentCausalKernel, GQAFwdWsPersistentKernel, GQAPrefillFwdKernel, + GQAPrefillPagedWithFP8KVCacheFwdKernel, GQAPrefillPagedWithKVCacheFwdKernel, GQAPrefillPagedWithKVCacheRopeAppendKernel, GQAPrefillPagedWithKVCacheRopeFwdKernel, @@ -40,6 +41,7 @@ "GroupedQueryAttentionDecodeWithKVCacheFwdOp", "GroupedQueryAttentionFwdOp", "GroupedQueryAttentionPrefillFwdOp", + "GroupedQueryAttentionPrefillPagedWithFP8KVCacheFwdOp", "GroupedQueryAttentionPrefillPagedWithKVCacheFwdOp", "GroupedQueryAttentionPrefillVarlenFwdOp", "GroupedQueryAttentionPrefillWithKVCacheFwdOp", @@ -154,6 +156,10 @@ def _select_gqa_prefill_paged_with_kv_cache_fwd_kernel_cls() -> Type[Kernel]: return GQAPrefillPagedWithKVCacheFwdKernel +def _select_gqa_prefill_paged_with_fp8_kv_cache_fwd_kernel_cls() -> Type[Kernel]: + return GQAPrefillPagedWithFP8KVCacheFwdKernel + + def _select_gqa_prefill_paged_with_kv_cache_rope_fwd_kernel_cls() -> Type[Kernel]: return GQAPrefillPagedWithKVCacheRopeFwdKernel @@ -944,6 +950,238 @@ def total_memory(self) -> int: "compute per-sample from cu_seqlens and cache_seqlens at call time.") +class GroupedQueryAttentionPrefillPagedWithFP8KVCacheFwdOp(Op): + """Packed GQA prefill with paged FP8 KV cache append. Layout: THD. + + This storage-only FP8 variant keeps ``q``, ``k_new`` and ``v_new`` in the + attention dtype, dequantizes old cache pages online with scalar + ``k_scale`` / ``v_scale``, and appends the current chunk quantized into + FP8 E4M3 pages for subsequent requests. 
+ """ + + def __init__( + self, + batch: int, + heads: int, + heads_kv: int, + max_pages_per_req: int, + page_size: int, + dim: int, + is_causal: bool = True, + dtype: torch.dtype = torch.float16, + cache_dtype: torch.dtype = torch.float8_e4m3fn, + sm_scale: Optional[float] = None, + softcap: Optional[float] = None, + kernel_map: Optional[Dict[str, Kernel]] = None, + tune: bool = False, + ) -> None: + _validate_gqa_dims(heads, heads_kv, dim) + _validate_attention_dtype(dtype) + if cache_dtype != torch.float8_e4m3fn: + raise ValueError( + "GroupedQueryAttentionPrefillPagedWithFP8KVCacheFwdOp currently supports " + f"torch.float8_e4m3fn cache only, got {cache_dtype}") + if batch <= 0: + raise ValueError("batch must be positive") + if max_pages_per_req <= 0: + raise ValueError("max_pages_per_req must be positive") + if page_size <= 0: + raise ValueError("page_size must be positive") + if page_size & (page_size - 1) != 0: + raise ValueError("page_size must be a power of two") + self.batch = batch + self.heads = heads + self.heads_kv = heads_kv + self.groups = heads // heads_kv + self.max_pages_per_req = max_pages_per_req + self.page_size = page_size + self.max_cache_len = max_pages_per_req * page_size + self.dim = dim + self.is_causal = is_causal + self.dtype = dtype + self.cache_dtype = cache_dtype + self.sm_scale = _attention_scale(dim, sm_scale) + self.softcap = _score_softcap(softcap) + + self.dispatch_kernel(kernel_map) + self.kernel = self.kernel_map["gqa_prefill_paged_with_fp8_kv_cache_fwd_kernel"]( + batch=batch, + heads=heads, + heads_kv=heads_kv, + max_pages_per_req=max_pages_per_req, + page_size=page_size, + dim=dim, + is_causal=is_causal, + dtype=dtype, + sm_scale=self.sm_scale, + softcap=self.softcap, + tune=tune, + ) + + @property + def default_kernel_map(self) -> Dict[str, Kernel]: + return { + "gqa_prefill_paged_with_fp8_kv_cache_fwd_kernel": + _select_gqa_prefill_paged_with_fp8_kv_cache_fwd_kernel_cls() + } + + def _validate_forward_inputs( + self, + q: torch.Tensor, + k_new: torch.Tensor, + v_new: torch.Tensor, + k_pages: torch.Tensor, + v_pages: torch.Tensor, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cache_seqlens: torch.Tensor, + block_table: torch.Tensor, + max_seqlen_q: int, + ) -> None: + tensors = { + "q": q, + "k_new": k_new, + "v_new": v_new, + "k_pages": k_pages, + "v_pages": v_pages, + "k_scale": k_scale, + "v_scale": v_scale, + "cu_seqlens_q": cu_seqlens_q, + "cache_seqlens": cache_seqlens, + "block_table": block_table, + } + for name, tensor in tensors.items(): + if tensor.device.type != "cuda": + raise ValueError(f"{name} must be on a cuda device, got {tensor.device}") + if not tensor.is_contiguous(): + raise ValueError(f"{name} must be contiguous") + + expected_q_shape_tail = (self.heads, self.dim) + expected_kv_shape_tail = (self.heads_kv, self.dim) + if q.ndim != 3 or tuple(q.shape[1:]) != expected_q_shape_tail: + raise ValueError( + f"q must have shape [total_q, {self.heads}, {self.dim}], got {q.shape}") + if k_new.ndim != 3 or tuple(k_new.shape[1:]) != expected_kv_shape_tail: + raise ValueError( + f"k_new must have shape [total_q, {self.heads_kv}, {self.dim}], got " + f"{k_new.shape}") + if v_new.shape != k_new.shape: + raise ValueError( + f"v_new must have the same shape as k_new, got {v_new.shape} and " + f"{k_new.shape}") + if k_new.shape[0] != q.shape[0]: + raise ValueError( + f"k_new.shape[0] ({k_new.shape[0]}) must equal q.shape[0] ({q.shape[0]})") + if k_pages.ndim != 3 or tuple(k_pages.shape[1:]) != 
expected_kv_shape_tail: + raise ValueError( + f"k_pages must have shape [physical_tokens, {self.heads_kv}, {self.dim}], " + f"got {k_pages.shape}") + if v_pages.shape != k_pages.shape: + raise ValueError( + f"v_pages must have the same shape as k_pages, got {v_pages.shape} and " + f"{k_pages.shape}") + if k_pages.shape[0] % self.page_size != 0: + raise ValueError("k_pages physical token dimension must be divisible by page_size") + if k_scale.shape != (1,) or v_scale.shape != (1,): + raise ValueError( + f"k_scale and v_scale must have shape (1,), got {k_scale.shape} and " + f"{v_scale.shape}") + if cu_seqlens_q.shape != (self.batch + 1,): + raise ValueError( + f"cu_seqlens_q shape must be ({self.batch + 1},), got " + f"{tuple(cu_seqlens_q.shape)}") + if cache_seqlens.shape != (self.batch,): + raise ValueError( + f"cache_seqlens shape must be ({self.batch},), got " + f"{tuple(cache_seqlens.shape)}") + if block_table.shape != (self.batch, self.max_pages_per_req): + raise ValueError( + f"block_table shape must be ({self.batch}, {self.max_pages_per_req}), " + f"got {tuple(block_table.shape)}") + + for name, tensor in [("q", q), ("k_new", k_new), ("v_new", v_new)]: + if tensor.dtype != self.dtype: + raise ValueError(f"Expected {name}.dtype {self.dtype}, got {tensor.dtype}") + for name, tensor in [("k_pages", k_pages), ("v_pages", v_pages)]: + if tensor.dtype != self.cache_dtype: + raise ValueError( + f"Expected {name}.dtype {self.cache_dtype}, got {tensor.dtype}") + for name, tensor in [("k_scale", k_scale), ("v_scale", v_scale)]: + if tensor.dtype != torch.float32: + raise ValueError(f"{name} must have dtype torch.float32, got {tensor.dtype}") + for name, tensor in [("cu_seqlens_q", cu_seqlens_q), + ("cache_seqlens", cache_seqlens), + ("block_table", block_table)]: + if tensor.dtype != torch.int32: + raise ValueError(f"{name} must have dtype torch.int32, got {tensor.dtype}") + + if int(cu_seqlens_q[0].item()) != 0: + raise ValueError("cu_seqlens_q[0] must be 0") + q_lens = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + if torch.any(q_lens < 0).item(): + raise ValueError("cu_seqlens_q must be non-decreasing") + total_q = int(cu_seqlens_q[-1].item()) + if total_q != q.shape[0]: + raise ValueError(f"cu_seqlens_q[-1] ({total_q}) must equal q.shape[0] ({q.shape[0]})") + actual_max_q = int(q_lens.max().item()) + if max_seqlen_q < actual_max_q: + raise ValueError( + f"max_seqlen_q ({max_seqlen_q}) must be >= actual max Q " + f"sequence length ({actual_max_q})") + + min_cache_len = int(cache_seqlens.min().item()) + max_total_len = int((cache_seqlens + q_lens).max().item()) + if min_cache_len < 0: + raise ValueError("cache_seqlens must be non-negative") + if max_total_len > self.max_cache_len: + raise ValueError( + "cache_seqlens + q_len exceeds paged KV capacity: " + f"max total length {max_total_len}, capacity {self.max_cache_len}") + + num_pages = k_pages.shape[0] // self.page_size + min_page = int(block_table.min().item()) + max_page = int(block_table.max().item()) + if min_page < 0: + raise ValueError("block_table must contain non-negative physical page ids") + if max_page >= num_pages: + raise ValueError( + f"block_table references page {max_page}, but only {num_pages} pages exist") + + def forward( + self, + q: torch.Tensor, + k_new: torch.Tensor, + v_new: torch.Tensor, + k_pages: torch.Tensor, + v_pages: torch.Tensor, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cache_seqlens: torch.Tensor, + block_table: torch.Tensor, + max_seqlen_q: int, + ) -> torch.Tensor: + 
self._validate_forward_inputs( + q, k_new, v_new, k_pages, v_pages, k_scale, v_scale, cu_seqlens_q, + cache_seqlens, block_table, max_seqlen_q) + return _attention_output( + self.kernel(q, k_new, v_new, k_pages, v_pages, k_scale, v_scale, cu_seqlens_q, + cache_seqlens, block_table, max_seqlen_q)) + + @property + def total_flops(self) -> int: + raise NotImplementedError( + "total_flops is not defined for paged varlen ops; " + "compute per-sample from cu_seqlens and cache_seqlens at call time.") + + @property + def total_memory(self) -> int: + raise NotImplementedError( + "total_memory is not defined for paged varlen ops; " + "compute per-sample from cu_seqlens and cache_seqlens at call time.") + + class GroupedQueryAttentionBwdOp(Op): """Layout: BSHD"""
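
# A minimal usage sketch of the new GroupedQueryAttentionPrefillPagedWithFP8KVCacheFwdOp,
# mirroring the smoke test added above. Shapes, scales, and the identity block_table are
# illustrative assumptions, not part of the patch; requires a CUDA device with
# float8_e4m3fn support (SM89+ per supported_archs).
import torch
from tileops.ops import GroupedQueryAttentionPrefillPagedWithFP8KVCacheFwdOp

batch, heads, heads_kv, dim = 2, 8, 2, 64
page_size, max_pages_per_req = 64, 8
q_lens, old_lens = [33, 48], [67, 80]
total_q = sum(q_lens)
num_pages = batch * max_pages_per_req

q = torch.randn(total_q, heads, dim, device="cuda", dtype=torch.float16)
k_new = torch.randn(total_q, heads_kv, dim, device="cuda", dtype=torch.float16)
v_new = torch.randn(total_q, heads_kv, dim, device="cuda", dtype=torch.float16)
# FP8 pages hold the previously quantized cache; the op appends the new chunk in place.
k_pages = torch.zeros(num_pages * page_size, heads_kv, dim, device="cuda",
                      dtype=torch.float8_e4m3fn)
v_pages = torch.zeros_like(k_pages)
k_scale = torch.tensor([0.02], device="cuda", dtype=torch.float32)
v_scale = torch.tensor([0.02], device="cuda", dtype=torch.float32)
cu_seqlens_q = torch.tensor([0, 33, 81], device="cuda", dtype=torch.int32)
cache_seqlens = torch.tensor(old_lens, device="cuda", dtype=torch.int32)
block_table = torch.arange(num_pages, device="cuda",
                           dtype=torch.int32).reshape(batch, max_pages_per_req)

op = GroupedQueryAttentionPrefillPagedWithFP8KVCacheFwdOp(
    batch=batch, heads=heads, heads_kv=heads_kv,
    max_pages_per_req=max_pages_per_req, page_size=page_size, dim=dim,
    is_causal=True, dtype=torch.float16)
out = op(q, k_new, v_new, k_pages, v_pages, k_scale, v_scale,
         cu_seqlens_q, cache_seqlens, block_table, max(q_lens))  # (total_q, heads, dim)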