Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 106 additions & 1 deletion tests/ops/attention/test_gqa_prefill_paged.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
import pytest
import torch

from tileops.ops import GroupedQueryAttentionPrefillPagedWithKVCacheFwdOp, RopeNeoxPositionIdsOp
from tileops.ops import (
GroupedQueryAttentionPrefillPagedWithFP8KVCacheFwdOp,
GroupedQueryAttentionPrefillPagedWithKVCacheFwdOp,
RopeNeoxPositionIdsOp,
)

_PREFILL_PAGED_TOLERANCE = {
torch.float16: (5e-3, 1e-5),
Expand Down Expand Up @@ -206,6 +210,107 @@ def test_gqa_prefill_paged_with_kv_cache_fwd(
torch.testing.assert_close(v_pages[physical_pos], v_pages_before[physical_pos])


@pytest.mark.smoke
@pytest.mark.parametrize("is_causal, softcap, dtype, page_size", [
pytest.param(True, None, torch.float16, 64, id="causal-fp16-page64"),
pytest.param(False, None, torch.float16, 64, id="noncausal-fp16-page64"),
pytest.param(True, 2.0, torch.float16, 64, id="causal-softcap-fp16-page64"),
pytest.param(True, None, torch.bfloat16, 64, id="causal-bf16-page64"),
pytest.param(True, None, torch.float16, 16, id="causal-fp16-page16"),
pytest.param(True, None, torch.float16, 128, id="causal-fp16-page128"),
])
def test_gqa_prefill_paged_with_fp8_kv_cache_fwd(
is_causal: bool,
softcap: float | None,
dtype: torch.dtype,
page_size: int,
) -> None:
q_lens = [33, 48]
old_lens = [67, 80]
batch, heads, heads_kv, dim = 2, 8, 2, 64
cache_dtype = torch.float8_e4m3fn
max_pages_per_req = 8
num_pages = batch * max_pages_per_req
total_q = sum(q_lens)
block_table = _make_block_table(batch, max_pages_per_req)
cu_seqlens_q = _make_cu_seqlens(q_lens)
cache_seqlens = torch.tensor(old_lens, device="cuda", dtype=torch.int32)
k_scale = torch.tensor([0.02], device="cuda", dtype=torch.float32)
v_scale = torch.tensor([0.02], device="cuda", dtype=torch.float32)

q = torch.randn(total_q, heads, dim, device="cuda", dtype=dtype).contiguous()
k_new = (torch.randn(total_q, heads_kv, dim, device="cuda", dtype=dtype) *
0.5).contiguous()
v_new = (torch.randn(total_q, heads_kv, dim, device="cuda", dtype=dtype) *
0.5).contiguous()
k_pages = torch.zeros(num_pages * page_size, heads_kv, dim, device="cuda",
dtype=cache_dtype).contiguous()
v_pages = torch.zeros_like(k_pages)
k_old = [
(torch.randn(old_len, heads_kv, dim, device="cuda", dtype=dtype) * 0.5).contiguous()
for old_len in old_lens
]
v_old = [
(torch.randn(old_len, heads_kv, dim, device="cuda", dtype=dtype) * 0.5).contiguous()
for old_len in old_lens
]
k_old_quant = [(k_b.float() / k_scale[0]).to(cache_dtype).contiguous() for k_b in k_old]
v_old_quant = [(v_b.float() / v_scale[0]).to(cache_dtype).contiguous() for v_b in v_old]
_fill_paged_cache_from_logical(
k_pages, v_pages, k_old_quant, v_old_quant, block_table, page_size)
k_pages_before = k_pages.clone()
v_pages_before = v_pages.clone()
k_old_dequant = [(k_b.float() * k_scale[0]).to(dtype).contiguous() for k_b in k_old_quant]
v_old_dequant = [(v_b.float() * v_scale[0]).to(dtype).contiguous() for v_b in v_old_quant]
ref = _gqa_prefill_paged_ref(
q,
k_new,
v_new,
k_old_dequant,
v_old_dequant,
cu_seqlens_q,
batch=batch,
heads=heads,
heads_kv=heads_kv,
is_causal=is_causal,
softcap=softcap,
)
op = GroupedQueryAttentionPrefillPagedWithFP8KVCacheFwdOp(
batch=batch,
heads=heads,
heads_kv=heads_kv,
max_pages_per_req=max_pages_per_req,
page_size=page_size,
dim=dim,
is_causal=is_causal,
dtype=dtype,
softcap=softcap,
)

output = op(
q, k_new, v_new, k_pages, v_pages, k_scale, v_scale, cu_seqlens_q, cache_seqlens,
block_table, max(q_lens))
assert isinstance(output, torch.Tensor)
torch.testing.assert_close(output, ref, atol=8e-2, rtol=2e-2)

for b, (q_len, old_len) in enumerate(zip(q_lens, old_lens, strict=True)):
q_start = int(cu_seqlens_q[b].item())
for i in range(q_len):
physical_pos = _physical_pos(block_table, b, old_len + i, page_size)
expected_k = (k_new[q_start + i].float() / k_scale[0]).to(cache_dtype).float()
expected_v = (v_new[q_start + i].float() / v_scale[0]).to(cache_dtype).float()
torch.testing.assert_close(k_pages[physical_pos].float(), expected_k, atol=0, rtol=0)
torch.testing.assert_close(v_pages[physical_pos].float(), expected_v, atol=0, rtol=0)

for b, old_len in enumerate(old_lens):
for pos in range(old_len):
physical_pos = _physical_pos(block_table, b, pos, page_size)
torch.testing.assert_close(
k_pages[physical_pos].float(), k_pages_before[physical_pos].float())
torch.testing.assert_close(
v_pages[physical_pos].float(), v_pages_before[physical_pos].float())


@pytest.mark.smoke
@pytest.mark.parametrize("rotary_dim, is_causal, softcap", [
pytest.param(None, True, None, id="full-causal"),
Expand Down
2 changes: 2 additions & 0 deletions tileops/kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
GQAFwdKernel,
GQAFwdWgmmaPipelinedKernel,
GQAPrefillFwdKernel,
GQAPrefillPagedWithFP8KVCacheFwdKernel,
GQAPrefillPagedWithKVCacheFwdKernel,
GQAPrefillPagedWithKVCacheRopeAppendKernel,
GQAPrefillPagedWithKVCacheRopeFwdKernel,
Expand Down Expand Up @@ -106,6 +107,7 @@
"GQAFwdKernel",
"GQAFwdWgmmaPipelinedKernel",
"GQAPrefillFwdKernel",
"GQAPrefillPagedWithFP8KVCacheFwdKernel",
"GQAPrefillPagedWithKVCacheFwdKernel",
"GQAPrefillPagedWithKVCacheRopeAppendKernel",
"GQAPrefillPagedWithKVCacheRopeFwdKernel",
Expand Down
2 changes: 2 additions & 0 deletions tileops/kernels/attention/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
GQAFwdKernel,
GQAFwdWgmmaPipelinedKernel,
GQAPrefillFwdKernel,
GQAPrefillPagedWithFP8KVCacheFwdKernel,
GQAPrefillPagedWithKVCacheFwdKernel,
GQAPrefillPagedWithKVCacheRopeAppendKernel,
GQAPrefillPagedWithKVCacheRopeFwdKernel,
Expand Down Expand Up @@ -52,6 +53,7 @@
"GQAFwdWsPersistentCausalKernel",
"GQAFwdWsPersistentKernel",
"GQAPrefillFwdKernel",
"GQAPrefillPagedWithFP8KVCacheFwdKernel",
"GQAPrefillPagedWithKVCacheFwdKernel",
"GQAPrefillPagedWithKVCacheRopeAppendKernel",
"GQAPrefillPagedWithKVCacheRopeFwdKernel",
Expand Down
Loading
Loading