diff --git a/benchmark/test_vllm_perf.py b/benchmark/test_vllm_perf.py
index 888db878fc..4bd0ee024a 100644
--- a/benchmark/test_vllm_perf.py
+++ b/benchmark/test_vllm_perf.py
@@ -18,6 +18,21 @@ def is_vllm_available():
         return False
 
 
+try:
+    from vllm.utils.deep_gemm import fp8_paged_mqa_logits as vllm_fp8_paged_mqa_logits
+    from vllm.utils.deep_gemm import (
+        get_num_sms,
+        get_paged_mqa_logits_metadata,
+        has_deep_gemm,
+    )
+
+    DEEPGEMM_AVAILABLE = has_deep_gemm()
+except ImportError:
+    DEEPGEMM_AVAILABLE = False
+    vllm_fp8_paged_mqa_logits = None
+    get_num_sms = None
+    get_paged_mqa_logits_metadata = None
+
 VLLM_AVAILABLE = is_vllm_available()
 
 
@@ -1058,3 +1073,222 @@ def test_get_paged_mqa_logits_metadata_benchmark():
     )
     bench.set_gems(flag_gems.get_paged_mqa_logits_metadata)
     bench.run()
+
+
+def kv_cache_cast_to_fp8_deepgemm(x: torch.Tensor) -> torch.Tensor:
+    num_blocks, block_size, num_heads, head_dim = x.shape
+    assert num_heads == 1
+    x_amax = x.abs().float().amax(dim=3, keepdim=True).clamp(1e-4)
+    sf = x_amax / 448.0
+    x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
+
+    x_fp8 = torch.empty(
+        (num_blocks, block_size * (head_dim + 4)),
+        device=x.device,
+        dtype=torch.uint8,
+    )
+    x_fp8[:, : block_size * head_dim] = x_scaled.view(
+        num_blocks, block_size * head_dim
+    ).view(torch.uint8)
+
+    sf_scaled = sf.squeeze(-1).squeeze(-1)
+    sf_bytes = sf_scaled.view(torch.int32).view(torch.uint8)
+    x_fp8[:, block_size * head_dim :] = sf_bytes
+    return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4)
+
+
+def kv_cache_cast_to_fp8_triton(x: torch.Tensor) -> torch.Tensor:
+    num_blocks, block_size, num_heads, head_dim = x.shape
+    assert num_heads == 1
+    x_amax = x.abs().float().amax(dim=3, keepdim=True).clamp(1e-4)
+    sf = x_amax / 448.0
+    x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
+
+    out = torch.empty(
+        (num_blocks, block_size, num_heads, head_dim + 4),
+        device=x.device,
+        dtype=torch.uint8,
+    )
+    out[..., :head_dim] = x_scaled.view(torch.uint8)
+
+    sf_scaled = sf.squeeze(-1).squeeze(-1)
+    sf_bytes = sf_scaled.view(torch.int32).view(torch.uint8)
+    out[..., head_dim:] = sf_bytes.view(num_blocks, block_size, num_heads, 4)
+    return out
+
+
+def _build_case(
+    batch_size, next_n, heads, head_dim, avg_kv, blocksize, q_dtype, max_model_len=4096
+):
+    num_blocks = max_model_len * 20
+
+    q = torch.randn(
+        (batch_size, next_n, heads, head_dim),
+        device=flag_gems.device,
+        dtype=q_dtype,
+    )
+    q_fp8 = q.to(torch.float8_e4m3fn)
+
+    kv_cache = torch.randn(
+        (num_blocks, blocksize, 1, head_dim),
+        device=flag_gems.device,
+        dtype=torch.bfloat16,
+    )
+
+    weights = torch.randn(
+        (batch_size * next_n, heads),
+        device=flag_gems.device,
+        dtype=torch.float32,
+    )
+
+    context_lens = torch.randint(
+        int(0.8 * avg_kv),
+        int(1.2 * avg_kv),
+        (batch_size,),
+        device=flag_gems.device,
+        dtype=torch.int32,
+    )
+
+    max_num_blocks_per_seq = (
+        int(context_lens.max().item()) + blocksize - 1
+    ) // blocksize
+    block_tables = torch.zeros(
+        (batch_size, max_num_blocks_per_seq),
+        device=flag_gems.device,
+        dtype=torch.int32,
+    )
+
+    counter = 0
+    block_idx_pool = list(range(num_blocks))
+    random.shuffle(block_idx_pool)
+    for i in range(batch_size):
+        ctx_len = int(context_lens[i].item())
+        for j in range((ctx_len + blocksize - 1) // blocksize):
+            block_tables[i, j] = block_idx_pool[counter]
+            counter += 1
+
+    kv_cache_fp8_deepgemm = kv_cache_cast_to_fp8_deepgemm(kv_cache)
+    kv_cache_fp8_triton = kv_cache_cast_to_fp8_triton(kv_cache)
+
+    return (
+        q_fp8,
+        kv_cache_fp8_deepgemm,
+        kv_cache_fp8_triton,
+        weights,
+        context_lens,
+        block_tables,
+        max_model_len,
+    )
+
+
+class FP8PagedMQACompareBenchmark(Benchmark):
+    def __init__(self):
+        super().__init__(
+            "fp8_paged_mqa_logits_gems_vs_deepgemm",
+            self._vllm_wrapper,
+            [torch.bfloat16],
+        )
+        self.set_gems(self._gems_wrapper)
+
+    def set_shapes(self, shape_file_path=None):
+        self.shapes = []
+
+    def get_input_iter(self, _dtype):
+        compare_shapes = [
+            (1, 1, 16, 64, 1024),
+            (2, 1, 32, 128, 2048),
+            (4, 1, 32, 128, 2048),
+            (2, 2, 32, 128, 2048),
+            (8, 1, 32, 128, 3072),
+        ]
+        q_dtypes = [torch.bfloat16, torch.float16]
+        blocksize = 64
+
+        for (bs, nn, h, d, avg_kv), q_dtype in product(compare_shapes, q_dtypes):
+            case = _build_case(bs, nn, h, d, avg_kv, blocksize, q_dtype)
+            (
+                q_fp8,
+                kv_dg,
+                kv_tr,
+                weights,
+                context_lens,
+                block_tables,
+                max_model_len,
+            ) = case
+            schedule_metadata = get_paged_mqa_logits_metadata(
+                context_lens, blocksize, get_num_sms()
+            )
+            yield (
+                q_fp8,
+                kv_dg,
+                kv_tr,
+                weights,
+                context_lens,
+                block_tables,
+                schedule_metadata,
+                max_model_len,
+                q_dtype,
+                blocksize,
+            )
+
+    @staticmethod
+    def _vllm_wrapper(
+        q_fp8,
+        kv_cache_fp8_deepgemm,
+        kv_cache_fp8_triton,
+        weights,
+        context_lens,
+        block_tables,
+        schedule_metadata,
+        max_model_len,
+        q_dtype,
+        blocksize,
+    ):
+        return vllm_fp8_paged_mqa_logits(
+            q_fp8,
+            kv_cache_fp8_deepgemm,
+            weights,
+            context_lens,
+            block_tables,
+            schedule_metadata,
+            max_model_len,
+            clean_logits=True,
+        )
+
+    @staticmethod
+    def _gems_wrapper(
+        q_fp8,
+        kv_cache_fp8_deepgemm,
+        kv_cache_fp8_triton,
+        weights,
+        context_lens,
+        block_tables,
+        schedule_metadata,
+        max_model_len,
+        q_dtype,
+        blocksize,
+    ):
+        return flag_gems.fp8_paged_mqa_logits(
+            q_fp8,
+            kv_cache_fp8_triton,
+            weights,
+            context_lens,
+            block_tables,
+            max_model_len,
+        )
+
+
+@pytest.mark.skipif(
+    not (
+        torch.cuda.is_available()
+        and VLLM_AVAILABLE
+        and DEEPGEMM_AVAILABLE
+        and CUDA_AVAILABLE
+    ),
+    reason="requires CUDA + vLLM + DeepGEMM",
+)
+@pytest.mark.performance
+@pytest.mark.fp8_paged_mqa_logits
+def test_perf_fp8_paged_mqa_logits_gems_vs_deepgemm():
+    bench = FP8PagedMQACompareBenchmark()
+    bench.run()
diff --git a/src/flag_gems/__init__.py b/src/flag_gems/__init__.py
index 002ce46871..11d2f7d0bd 100644
--- a/src/flag_gems/__init__.py
+++ b/src/flag_gems/__init__.py
@@ -191,6 +191,7 @@ def torch_ge(v):
         ("floor_divide_.Tensor", floor_divide_),
         ("fmin", fmin),
         ("fmin.out", fmin_out),
+        ("fp8_paged_mqa_logits", fp8_paged_mqa_logits),
         ("full", full),
         ("full_like", full_like),
         ("gather", gather),
diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py
index ac589f5340..e988f8b779 100644
--- a/src/flag_gems/ops/__init__.py
+++ b/src/flag_gems/ops/__init__.py
@@ -115,6 +115,7 @@ from flag_gems.ops.flip import flip
 from flag_gems.ops.floor_ import floor_
 from flag_gems.ops.fmin import fmin, fmin_out
+from flag_gems.ops.fp8_paged_mqa_logits import fp8_paged_mqa_logits
 from flag_gems.ops.full import full
 from flag_gems.ops.full_like import full_like
 from flag_gems.ops.gather import gather, gather_backward
@@ -433,6 +434,7 @@
     "floor_divide_",
     "fmin",
     "fmin_out",
+    "fp8_paged_mqa_logits",
     "full",
     "full_like",
     "gather",
diff --git a/src/flag_gems/ops/fp8_paged_mqa_logits.py b/src/flag_gems/ops/fp8_paged_mqa_logits.py
new file mode 100644
index 0000000000..6c276c414a
--- /dev/null
+++ b/src/flag_gems/ops/fp8_paged_mqa_logits.py
@@ -0,0 +1,331 @@
+import torch
+import triton
+import triton.language as tl
+
+
+def cdiv(x: int, y: int) -> int:
+    return (x + y - 1) // y
+
+
+@triton.autotune(
+    configs=[
+        triton.Config(
+            {"BLOCK_KV": 64, "BLOCK_D": 128, "NUM_D_TILES": 1, "BLOCK_H": 32},
+            num_warps=8,
+            num_stages=2,
+        ),
+        triton.Config(
+            {"BLOCK_KV": 64, "BLOCK_D": 128, "NUM_D_TILES": 1, "BLOCK_H": 16},
+            num_warps=8,
+            num_stages=2,
+        ),
+        triton.Config(
+            {"BLOCK_KV": 64, "BLOCK_D": 128, "NUM_D_TILES": 1, "BLOCK_H": 8},
+            num_warps=4,
+            num_stages=2,
+        ),
+        triton.Config(
+            {"BLOCK_KV": 128, "BLOCK_D": 128, "NUM_D_TILES": 1, "BLOCK_H": 32},
+            num_warps=8,
+            num_stages=2,
+        ),
+        triton.Config(
+            {"BLOCK_KV": 128, "BLOCK_D": 128, "NUM_D_TILES": 1, "BLOCK_H": 16},
+            num_warps=8,
+            num_stages=2,
+        ),
+        triton.Config(
+            {"BLOCK_KV": 32, "BLOCK_D": 128, "NUM_D_TILES": 1, "BLOCK_H": 32},
+            num_warps=4,
+            num_stages=2,
+        ),
+        triton.Config(
+            {"BLOCK_KV": 64, "BLOCK_D": 64, "NUM_D_TILES": 2, "BLOCK_H": 16},
+            num_warps=8,
+            num_stages=2,
+        ),
+        triton.Config(
+            {"BLOCK_KV": 64, "BLOCK_D": 64, "NUM_D_TILES": 2, "BLOCK_H": 8},
+            num_warps=4,
+            num_stages=2,
+        ),
+        triton.Config(
+            {"BLOCK_KV": 128, "BLOCK_D": 64, "NUM_D_TILES": 2, "BLOCK_H": 16},
+            num_warps=8,
+            num_stages=2,
+        ),
+        triton.Config(
+            {"BLOCK_KV": 64, "BLOCK_D": 64, "NUM_D_TILES": 1, "BLOCK_H": 16},
+            num_warps=4,
+            num_stages=2,
+        ),
+        triton.Config(
+            {"BLOCK_KV": 64, "BLOCK_D": 64, "NUM_D_TILES": 1, "BLOCK_H": 8},
+            num_warps=4,
+            num_stages=2,
+        ),
+        triton.Config(
+            {"BLOCK_KV": 128, "BLOCK_D": 64, "NUM_D_TILES": 1, "BLOCK_H": 16},
+            num_warps=8,
+            num_stages=2,
+        ),
+    ],
+    key=["heads", "dim", "block_size"],
+)
+@triton.jit
+def fp8_paged_mqa_logits_kernel(
+    q_ptr,
+    kv_ptr,
+    weights_ptr,
+    logits_ptr,
+    block_tables_ptr,
+    context_lens_ptr,
+    stride_qb,
+    stride_qn,
+    stride_qh,
+    stride_qd,
+    stride_kvblk,
+    stride_kvpos,
+    stride_kvone,
+    stride_kvbyte,
+    stride_wrow,
+    stride_wh,
+    stride_lrow,
+    stride_lcol,
+    stride_btb,
+    stride_bts,
+    next_n: tl.constexpr,
+    heads: tl.constexpr,
+    dim: tl.constexpr,
+    block_size: tl.constexpr,
+    max_model_len,
+    dim_plus_4: tl.constexpr,
+    BLOCK_KV: tl.constexpr,
+    BLOCK_D: tl.constexpr,
+    NUM_D_TILES: tl.constexpr,
+    BLOCK_H: tl.constexpr,
+):
+    pid_row = tl.program_id(0)
+    pid_kv_tile = tl.program_id(1)
+
+    batch_idx = pid_row // next_n
+    next_n_idx = pid_row % next_n
+
+    context_len = tl.load(context_lens_ptr + batch_idx)
+    query_seq_pos = context_len - next_n + next_n_idx
+
+    kv_start = pid_kv_tile * BLOCK_KV
+    # This KV tile lies entirely past the sequence end: emit -inf and exit.
+    if kv_start >= context_len:
+        offs_kv = tl.arange(0, BLOCK_KV)
+        kv_pos = kv_start + offs_kv
+        out_mask = kv_pos < max_model_len
+        out_ptrs = logits_ptr + pid_row * stride_lrow + kv_pos * stride_lcol
+        tl.store(out_ptrs, float("-inf"), mask=out_mask)
+        return
+
+    offs_kv = tl.arange(0, BLOCK_KV)
+    kv_global_pos = kv_start + offs_kv
+
+    context_mask = kv_global_pos < context_len
+    causal_mask = kv_global_pos <= query_seq_pos
+    valid_mask = context_mask & causal_mask
+
+    phys_block_idx = kv_global_pos // block_size
+    intra_block_pos = kv_global_pos % block_size
+
+    phys_block_ids = tl.load(
+        block_tables_ptr + batch_idx * stride_btb + phys_block_idx * stride_bts,
+        mask=valid_mask,
+        other=0,
+    )
+
+    kv_base = phys_block_ids * stride_kvblk + intra_block_pos * stride_kvpos
+
+    # Each cache entry stores dim fp8 bytes followed by a 4-byte fp32 scale.
+    scale_addr = kv_base + dim * stride_kvbyte
+    scale_ptr = (kv_ptr + scale_addr).to(tl.pointer_type(tl.uint32, 1), bitcast=True)
+    scale_u32 = tl.load(scale_ptr, mask=valid_mask, other=0)
+    scale_f32 = scale_u32.to(tl.float32, bitcast=True)
+
+    logit_accum = tl.zeros([BLOCK_KV], dtype=tl.float32)
+    offs_d = tl.arange(0, BLOCK_D)
+    q_base = q_ptr + batch_idx * stride_qb + next_n_idx * stride_qn
+
+    # Head dim fits in a single BLOCK_D tile when NUM_D_TILES == 1.
+    if NUM_D_TILES == 1:
+        d_mask = offs_d < dim
+
+        kv_byte_ptrs = kv_ptr + kv_base[:, None] + offs_d[None, :] * stride_kvbyte
+        load_mask = valid_mask[:, None] & d_mask[None, :]
+        kv_u8 = tl.load(kv_byte_ptrs, mask=load_mask, other=0)
+        kv_fp8 = kv_u8.to(tl.float8e4nv, bitcast=True)
+        kv_f32 = kv_fp8.to(tl.float32)
+
+        for h_tile in tl.static_range(0, heads, BLOCK_H):
+            offs_h = h_tile + tl.arange(0, BLOCK_H)
+            h_mask = offs_h < heads
+
+            q_ptrs = q_base + offs_h[:, None] * stride_qh + offs_d[None, :] * stride_qd
+            q_vals = tl.load(
+                q_ptrs, mask=h_mask[:, None] & d_mask[None, :], other=0.0
+            ).to(tl.float32)
+            weights = tl.load(
+                weights_ptr + pid_row * stride_wrow + offs_h * stride_wh,
+                mask=h_mask,
+                other=0.0,
+            )
+
+            q_tile = tl.trans(q_vals)
+            partial_dot = tl.dot(kv_f32, q_tile, out_dtype=tl.float32)
+            partial_dot = partial_dot * scale_f32[:, None]
+            partial_dot = tl.maximum(partial_dot, 0.0)
+            logit_accum += tl.sum(partial_dot * weights[None, :], axis=1)
+
+    else:
+        # Head dim spans two BLOCK_D tiles; accumulate two dots per head tile.
+        d_offs0 = offs_d
+        d_mask0 = d_offs0 < dim
+        d_offs1 = BLOCK_D + offs_d
+        d_mask1 = d_offs1 < dim
+
+        kv_byte_ptrs0 = kv_ptr + kv_base[:, None] + d_offs0[None, :] * stride_kvbyte
+        load_mask0 = valid_mask[:, None] & d_mask0[None, :]
+        kv_u80 = tl.load(kv_byte_ptrs0, mask=load_mask0, other=0)
+        kv_fp80 = kv_u80.to(tl.float8e4nv, bitcast=True)
+        kv_f320 = kv_fp80.to(tl.float32)
+
+        kv_byte_ptrs1 = kv_ptr + kv_base[:, None] + d_offs1[None, :] * stride_kvbyte
+        load_mask1 = valid_mask[:, None] & d_mask1[None, :]
+        kv_u81 = tl.load(kv_byte_ptrs1, mask=load_mask1, other=0)
+        kv_fp81 = kv_u81.to(tl.float8e4nv, bitcast=True)
+        kv_f321 = kv_fp81.to(tl.float32)
+
+        for h_tile in tl.static_range(0, heads, BLOCK_H):
+            offs_h = h_tile + tl.arange(0, BLOCK_H)
+            h_mask = offs_h < heads
+
+            q_ptrs0 = (
+                q_base + offs_h[:, None] * stride_qh + d_offs0[None, :] * stride_qd
+            )
+            q_vals0 = tl.load(
+                q_ptrs0, mask=h_mask[:, None] & d_mask0[None, :], other=0.0
+            ).to(tl.float32)
+
+            q_ptrs1 = (
+                q_base + offs_h[:, None] * stride_qh + d_offs1[None, :] * stride_qd
+            )
+            q_vals1 = tl.load(
+                q_ptrs1, mask=h_mask[:, None] & d_mask1[None, :], other=0.0
+            ).to(tl.float32)
+
+            weights = tl.load(
+                weights_ptr + pid_row * stride_wrow + offs_h * stride_wh,
+                mask=h_mask,
+                other=0.0,
+            )
+
+            q_T0 = tl.trans(q_vals0)
+            q_T1 = tl.trans(q_vals1)
+
+            partial_dot = tl.dot(kv_f320, q_T0, out_dtype=tl.float32)
+            partial_dot = tl.dot(kv_f321, q_T1, acc=partial_dot, out_dtype=tl.float32)
+
+            partial_dot = partial_dot * scale_f32[:, None]
+            partial_dot = tl.maximum(partial_dot, 0.0)
+            logit_accum += tl.sum(partial_dot * weights[None, :], axis=1)
+
+    out_vals = tl.where(valid_mask, logit_accum, float("-inf"))
+    out_ptrs = logits_ptr + pid_row * stride_lrow + kv_global_pos * stride_lcol
+    out_mask = valid_mask & (kv_global_pos < max_model_len)
+    tl.store(out_ptrs, out_vals, mask=out_mask)
+
+
+@triton.jit
+def fill_neg_inf_kernel(
+    out_ptr,
+    n_elements,
+    BLOCK: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    offs = pid * BLOCK + tl.arange(0, BLOCK)
+    mask = offs < n_elements
+    tl.store(out_ptr + offs, float("-inf"), mask=mask)
+
+
+def fp8_paged_mqa_logits(
+    q: torch.Tensor,
+    kv_cache: torch.Tensor,
+    weights: torch.Tensor,
+    context_lens: torch.Tensor,
+    block_tables: torch.Tensor,
+    max_model_len: int,
+) -> torch.Tensor:
+    assert q.is_cuda and kv_cache.is_cuda and weights.is_cuda
+    assert context_lens.is_cuda and block_tables.is_cuda
+
+    batch_size, next_n, heads, dim = q.size()
+    num_blocks, block_size, one, dim_plus_4 = kv_cache.size()
+
+    assert one == 1
+    assert dim_plus_4 == dim + 4
+    assert weights.shape == (batch_size * next_n, heads)
+    assert kv_cache.dtype == torch.uint8
+    assert context_lens.dtype == torch.int32
+    assert block_tables.dtype == torch.int32
+
+    q_contig = q.contiguous()
+    kv_contig = kv_cache.contiguous()
+    weights_contig = weights.contiguous()
+    context_lens_contig = context_lens.contiguous()
+    block_tables_contig = block_tables.contiguous()
+
+    total_rows = batch_size * next_n
+
+    logits = torch.empty(
+        (total_rows, max_model_len),
+        device=q.device,
+        dtype=torch.float32,
+    )
+    n_elements = total_rows * max_model_len
+    FILL_BLOCK = 1024
+    fill_grid = (cdiv(n_elements, FILL_BLOCK),)
+    fill_neg_inf_kernel[fill_grid](logits, n_elements, BLOCK=FILL_BLOCK)
+
+    max_context = block_tables_contig.shape[1] * block_size
+
+    def grid(meta):
+        BLOCK_KV = meta["BLOCK_KV"]
+        num_kv_tiles = cdiv(max_context, BLOCK_KV)
+        return (total_rows, num_kv_tiles)
+
+    fp8_paged_mqa_logits_kernel[grid](
+        q_contig,
+        kv_contig,
+        weights_contig,
+        logits,
+        block_tables_contig,
+        context_lens_contig,
+        q_contig.stride(0),
+        q_contig.stride(1),
+        q_contig.stride(2),
+        q_contig.stride(3),
+        kv_contig.stride(0),
+        kv_contig.stride(1),
+        kv_contig.stride(2),
+        kv_contig.stride(3),
+        weights_contig.stride(0),
+        weights_contig.stride(1),
+        logits.stride(0),
+        logits.stride(1),
+        block_tables_contig.stride(0),
+        block_tables_contig.stride(1),
+        next_n,
+        heads,
+        dim,
+        block_size,
+        max_model_len,
+        dim_plus_4,
+    )
+
+    return logits
diff --git a/tests/test_vllm_ops.py b/tests/test_vllm_ops.py
index 6d58a61fbc..194ed47595 100644
--- a/tests/test_vllm_ops.py
+++ b/tests/test_vllm_ops.py
@@ -8,8 +8,30 @@
 
 import flag_gems
 
+from .accuracy_utils import gems_assert_close, to_reference
 from .conftest import QUICK_MODE
 
+try:
+    from vllm.platforms import current_platform
+    from vllm.utils.deep_gemm import calc_diff
+    from vllm.utils.deep_gemm import (
+        fp8_paged_mqa_logits as fp8_paged_mqa_logits_deepgemm,
+    )
+    from vllm.utils.deep_gemm import get_num_sms, get_paged_mqa_logits_metadata
+    from vllm.utils.import_utils import has_deep_gemm
+
+    VLLM_AVAILABLE = True
+    DEEPGEMM_AVAILABLE = has_deep_gemm()
+except ImportError:
+    VLLM_AVAILABLE = False
+    DEEPGEMM_AVAILABLE = False
+    current_platform = None
+    get_num_sms = None
+    get_paged_mqa_logits_metadata = None
+    calc_diff = None
+    fp8_paged_mqa_logits_deepgemm = None
+
+
 random.seed(42)
@@ -1017,3 +1039,281 @@ def test_get_paged_mqa_logits_metadata(batch_size, next_n, avg_ctx_len):
     res = flag_gems.get_paged_mqa_logits_metadata(context_lens_2d, 64, get_num_sms())
 
     assert torch.equal(ref, res)
+
+
+def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor:
+    num_blocks, block_size, num_heads, head_dim = x.shape
+    assert num_heads == 1
+
+    x_amax = x.abs().float().amax(dim=3, keepdim=True).clamp(1e-4)
+    sf = x_amax / 448.0
+    x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
+
+    x_fp8 = torch.empty(
+        (num_blocks, block_size * (head_dim + 4)),
+        device=x.device,
+        dtype=torch.uint8,
+    )
+    x_fp8[:, : block_size * head_dim] = x_scaled.view(
+        num_blocks, block_size * head_dim
+    ).view(torch.uint8)
+
+    sf_scaled = sf.squeeze(-1).squeeze(-1)
+    sf_bytes = sf_scaled.view(torch.int32).view(torch.uint8)
+    x_fp8[:, block_size * head_dim :] = sf_bytes
+
+    return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4)
+
+
+def kv_cache_cast_to_fp8_triton(x: torch.Tensor) -> torch.Tensor:
+    num_blocks, block_size, num_heads, head_dim = x.shape
+    assert num_heads == 1
+
+    x_amax = x.abs().float().amax(dim=3, keepdim=True).clamp(1e-4)
+    sf = x_amax / 448.0
+    x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
+
+    out = torch.empty(
+        (num_blocks, block_size, num_heads, head_dim + 4),
+        device=x.device,
+        dtype=torch.uint8,
+    )
+    out[..., :head_dim] = x_scaled.view(torch.uint8)
+
+    sf_scaled = sf.squeeze(-1).squeeze(-1)
+    sf_bytes = sf_scaled.view(torch.int32).view(torch.uint8)
+    out[..., head_dim:] = sf_bytes.view(num_blocks, block_size, num_heads, 4)
+
+    return out
+
+
+def _build_mask(context_lens, batch_size, next_n, max_model_len, device):
+    positions = (
+        torch.arange(max_model_len, device=device)
+        .unsqueeze(0)
+        .expand(batch_size * next_n, -1)
+    )
+    row_indices = torch.arange(batch_size * next_n, device=device) // next_n
+    next_n_offset = torch.arange(batch_size * next_n, device=device) % next_n
+    return positions <= (context_lens[row_indices] - next_n + next_n_offset).unsqueeze(
+        1
+    )
+
+
+@pytest.mark.fp8
+@pytest.mark.paged_mqa_logits
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only")
+@pytest.mark.skipif(not DEEPGEMM_AVAILABLE, reason="DeepGEMM not available")
+@pytest.mark.skipif(
+    not (VLLM_AVAILABLE and current_platform.has_device_capability(90)),
+    reason="SM90 and SM100 only",
+)
+@pytest.mark.parametrize("clean_logits", [True, False])
+def test_accuracy_fp8_paged_mqa_logits(clean_logits: bool):
+    torch.manual_seed(0)
+    random.seed(0)
+
+    max_model_len = 4096
+    batch_size, next_n = 4, 1
+    heads, index_dim = 32, 128
+    avg_kv = 2048
+    num_blocks, blocksize = max_model_len * 2, 64
+
+    q = torch.randn(
+        (batch_size, next_n, heads, index_dim),
+        device=flag_gems.device,
+        dtype=torch.bfloat16,
+    )
+    kv_cache = torch.randn(
+        (num_blocks, blocksize, 1, index_dim),
+        device=flag_gems.device,
+        dtype=torch.bfloat16,
+    )
+    weights = torch.randn(
+        (batch_size * next_n, heads), device=flag_gems.device, dtype=torch.float32
+    )
+
+    context_lens = torch.randint(
+        int(0.8 * avg_kv), int(1.2 * avg_kv), (batch_size,), device=flag_gems.device
+    ).to(torch.int32)
+    max_num_blocks_per_seq = (context_lens.max().item() + blocksize - 1) // blocksize
+    block_tables = torch.zeros(
+        (batch_size, max_num_blocks_per_seq), device=flag_gems.device, dtype=torch.int32
+    )
+
+    counter = 0
+    block_idx_pool = list(range(num_blocks))
+    random.shuffle(block_idx_pool)
+    for i in range(batch_size):
+        ctx_len = int(context_lens[i].item())
+        for j in range((ctx_len + blocksize - 1) // blocksize):
+            block_tables[i, j] = block_idx_pool[counter]
+            counter += 1
+
+    q_fp8 = q.to(torch.float8_e4m3fn)
+
+    kv_cache_fp8_deepgemm = kv_cache_cast_to_fp8(kv_cache)
+    kv_cache_fp8_triton = kv_cache_cast_to_fp8_triton(kv_cache)
+
+    schedule_metadata = get_paged_mqa_logits_metadata(
+        context_lens, blocksize, get_num_sms()
+    )
+    ref_out = fp8_paged_mqa_logits_deepgemm(
+        q_fp8,
+        kv_cache_fp8_deepgemm,
+        weights,
+        context_lens,
+        block_tables,
+        schedule_metadata,
+        max_model_len,
+        clean_logits=clean_logits,
+    )
+
+    with flag_gems.use_gems():
+        res_out = flag_gems.fp8_paged_mqa_logits(
+            q_fp8,
+            kv_cache_fp8_triton,
+            weights,
+            context_lens,
+            block_tables,
+            max_model_len,
+        )
+
+    mask = _build_mask(
+        context_lens, batch_size, next_n, max_model_len, flag_gems.device
+    )
+
+    ref_out = to_reference(ref_out)
+    res_out = to_reference(res_out)
+    mask = to_reference(mask)
+
+    res_out_masked = torch.nan_to_num(res_out.masked_fill(~mask, 0), 0.0)
+    ref_out_masked = torch.nan_to_num(ref_out.masked_fill(~mask, 0), 0.0)
+
+    gems_assert_close(
+        res_out_masked,
+        ref_out_masked,
+        res_out_masked.dtype,
+        equal_nan=True,
+        atol=5e-2,
+        reduce_dim=1,
+    )
+
+    diff = calc_diff(res_out_masked, ref_out_masked)
+    assert (
+        diff < 1e-3
+    ), f"Large discrepancy between Triton and DeepGEMM: {diff=} (expected < 1e-3)"
+
+
+@pytest.mark.fp8
+@pytest.mark.paged_mqa_logits
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only")
+@pytest.mark.skipif(not DEEPGEMM_AVAILABLE, reason="DeepGEMM not available")
+@pytest.mark.skipif(
+    not (VLLM_AVAILABLE and current_platform.has_device_capability(90)),
+    reason="SM90 and SM100 only",
+)
+@pytest.mark.parametrize(
+    "batch_size, next_n",
+    [(1, 1), (2, 1), (4, 1), (8, 1), (16, 1), (32, 1), (2, 2), (4, 2)],
+)
+@pytest.mark.parametrize(
+    "heads, index_dim",
+    [(16, 64), (32, 128)],
+)
+def test_accuracy_fp8_paged_mqa_logits_param(batch_size, next_n, heads, index_dim):
+    torch.manual_seed(0)
+    random.seed(0)
+
+    max_model_len = 4096
+    avg_kv = 2048
+    num_blocks, blocksize = max_model_len * 2, 64
+
+    q = torch.randn(
+        (batch_size, next_n, heads, index_dim),
+        device=flag_gems.device,
+        dtype=torch.bfloat16,
+    )
+    kv_cache = torch.randn(
+        (num_blocks, blocksize, 1, index_dim),
+        device=flag_gems.device,
+        dtype=torch.bfloat16,
+    )
+    weights = torch.randn(
+        (batch_size * next_n, heads), device=flag_gems.device, dtype=torch.float32
+    )
+
+    context_lens = torch.randint(
+        int(0.8 * avg_kv),
+        int(1.2 * avg_kv),
+        (batch_size,),
+        device=flag_gems.device,
+        dtype=torch.int32,
+    )
+    max_num_blocks_per_seq = (context_lens.max().item() + blocksize - 1) // blocksize
+    block_tables = torch.zeros(
+        (batch_size, max_num_blocks_per_seq), device=flag_gems.device, dtype=torch.int32
+    )
+
+    counter = 0
+    block_idx_pool = list(range(num_blocks))
+    random.shuffle(block_idx_pool)
+    for i in range(batch_size):
+        ctx_len = int(context_lens[i].item())
+        for j in range((ctx_len + blocksize - 1) // blocksize):
+            block_tables[i, j] = block_idx_pool[counter]
+            counter += 1
+
+    q_fp8 = q.to(torch.float8_e4m3fn)
+
+    kv_cache_fp8_deepgemm = kv_cache_cast_to_fp8(kv_cache)
+    kv_cache_fp8_triton = kv_cache_cast_to_fp8_triton(kv_cache)
+
+    schedule_metadata = get_paged_mqa_logits_metadata(
+        context_lens, blocksize, get_num_sms()
+    )
+    ref_out = fp8_paged_mqa_logits_deepgemm(
+        q_fp8,
+        kv_cache_fp8_deepgemm,
+        weights,
+        context_lens,
+        block_tables,
+        schedule_metadata,
+        max_model_len,
+        clean_logits=True,
+    )
+
+    with flag_gems.use_gems():
+        res_out = flag_gems.fp8_paged_mqa_logits(
+            q_fp8,
+            kv_cache_fp8_triton,
+            weights,
+            context_lens,
+            block_tables,
+            max_model_len,
+        )
+
+    mask = _build_mask(
+        context_lens, batch_size, next_n, max_model_len, flag_gems.device
+    )
+
+    ref_out = to_reference(ref_out)
+    res_out = to_reference(res_out)
+    mask = to_reference(mask)
+
+    res_out_masked = torch.nan_to_num(res_out.masked_fill(~mask, 0), 0.0)
+    ref_out_masked = torch.nan_to_num(ref_out.masked_fill(~mask, 0), 0.0)
+
+    gems_assert_close(
+        res_out_masked,
+        ref_out_masked,
+        res_out_masked.dtype,
+        equal_nan=True,
+        atol=5e-2,
+        reduce_dim=1,
+    )
+
+    diff = calc_diff(res_out_masked, ref_out_masked)
+    assert (
+        diff < 1e-3
+    ), f"Large discrepancy between Triton and DeepGEMM: {diff=} (expected < 1e-3)"
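
Reviewer note: the following is a minimal eager-mode PyTorch sketch of the math the Triton kernel above implements, useful for checking semantics during review. It assumes the interleaved per-entry cache layout produced by kv_cache_cast_to_fp8_triton (dim fp8 bytes followed by a 4-byte fp32 scale). The helper name naive_fp8_paged_mqa_logits is hypothetical and is not part of this patch.

import torch

def naive_fp8_paged_mqa_logits(q_fp8, kv_cache_u8, weights, context_lens,
                               block_tables, max_model_len):
    # logits[row, pos] = sum_h weights[row, h] * relu(<q[row, h], dequant(kv[pos])>)
    batch_size, next_n, heads, dim = q_fp8.shape
    num_blocks, block_size, _, dim_plus_4 = kv_cache_u8.shape
    assert dim_plus_4 == dim + 4

    q = q_fp8.to(torch.float32)
    logits = torch.full(
        (batch_size * next_n, max_model_len),
        float("-inf"), device=q.device, dtype=torch.float32,
    )
    for b in range(batch_size):
        ctx = int(context_lens[b].item())
        num_seq_blocks = (ctx + block_size - 1) // block_size
        blk_ids = block_tables[b, :num_seq_blocks].long()
        # Gather this sequence's cache entries: [ctx, dim + 4] raw bytes.
        kv_bytes = kv_cache_u8[blk_ids].reshape(-1, dim_plus_4)[:ctx]
        kv = kv_bytes[:, :dim].view(torch.float8_e4m3fn).to(torch.float32)
        # The trailing 4 bytes of each entry are the fp32 dequant scale.
        scale = kv_bytes[:, dim:].contiguous().view(torch.float32)
        kv = kv * scale  # [ctx, dim], dequantized
        for n in range(next_n):
            row = b * next_n + n
            qpos = ctx - next_n + n  # last position this query may attend to
            dots = torch.relu(q[b, n] @ kv.T)            # [heads, ctx]
            out = (weights[row, :, None] * dots).sum(0)  # [ctx]
            out[qpos + 1:] = float("-inf")               # causal mask
            logits[row, :ctx] = out
    return logits

The Triton kernel multiplies the per-position scale in after the dot product instead of dequantizing first; the two orderings are equivalent because the scale is a scalar that factors out of the inner product. Expect small numerical differences from tf32 accumulation in tl.dot.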