"""Performance benchmark: FlagGems Triton fp8_paged_mqa_logits vs. the
DeepGEMM implementation shipped with vLLM.

Builds identical randomized paged-MQA inputs in both KV-cache packings
(DeepGEMM's block layout and the Triton kernel's interleaved layout) and
times the two implementations through the project's ``Benchmark`` harness.
"""
import random
from itertools import product

import pytest
import torch
from vllm.utils.deep_gemm import fp8_paged_mqa_logits as vllm_fp8_paged_mqa_logits
from vllm.utils.deep_gemm import get_num_sms, get_paged_mqa_logits_metadata
from vllm.utils.import_utils import has_deep_gemm

import flag_gems
from benchmark.performance_utils import Benchmark
from flag_gems.ops.fp8_paged_mqa_logits import (
    fp8_paged_mqa_logits as gems_fp8_paged_mqa_logits,
)

random.seed(42)


def is_vllm_available():
    """Return True when vLLM can actually be imported.

    Bug fix: the original body was ``try: return True`` — ``return True``
    cannot raise, so the probe unconditionally reported vLLM as available
    and the skip-guard below was dead.  NOTE(review): the module-level
    ``from vllm...`` imports above would already fail when vLLM is absent;
    consider moving them behind this guard as a follow-up.
    """
    try:
        import vllm  # noqa: F401

        return True
    except Exception:
        return False


def is_hopper_available():
    """True when running on CUDA with compute capability >= 9.0 (Hopper+)."""
    if flag_gems.device != "cuda":
        return False
    major, minor = torch.cuda.get_device_capability()
    return (major * 10 + minor) >= 90


VLLM_AVAILABLE = is_vllm_available()
DEEPGEMM_AVAILABLE = has_deep_gemm()
HOPPER_AVAILABLE = is_hopper_available()


def kv_cache_cast_to_fp8_deepgemm(x: torch.Tensor) -> torch.Tensor:
    """Pack a bf16 KV cache into DeepGEMM's fp8 layout.

    Per block: ``block_size * head_dim`` fp8 payload bytes, followed by
    ``block_size`` little-endian fp32 scales (4 bytes each), viewed back as
    ``(num_blocks, block_size, 1, head_dim + 4)`` uint8.
    """
    num_blocks, block_size, num_heads, head_dim = x.shape
    assert num_heads == 1
    # Per-position absmax scale, clamped so we never divide by ~0.
    x_amax = x.abs().float().amax(dim=3, keepdim=True).clamp(1e-4)
    sf = x_amax / 448.0  # 448 = max finite magnitude of float8_e4m3fn
    x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)

    x_fp8 = torch.empty(
        (num_blocks, block_size * (head_dim + 4)),
        device=x.device,
        dtype=torch.uint8,
    )
    x_fp8[:, : block_size * head_dim] = x_scaled.view(
        num_blocks, block_size * head_dim
    ).view(torch.uint8)

    sf_scaled = sf.squeeze(-1).squeeze(-1)
    sf_bytes = sf_scaled.view(torch.int32).view(torch.uint8)
    x_fp8[:, block_size * head_dim :] = sf_bytes
    return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4)


def kv_cache_cast_to_fp8_triton(x: torch.Tensor) -> torch.Tensor:
    """Pack a bf16 KV cache into the interleaved layout the Triton kernel
    reads: each position stores ``head_dim`` fp8 bytes immediately followed
    by its own 4-byte little-endian fp32 scale.
    """
    num_blocks, block_size, num_heads, head_dim = x.shape
    assert num_heads == 1
    x_amax = x.abs().float().amax(dim=3, keepdim=True).clamp(1e-4)
    sf = x_amax / 448.0
    x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)

    out = torch.empty(
        (num_blocks, block_size, num_heads, head_dim + 4),
        device=x.device,
        dtype=torch.uint8,
    )
    out[..., :head_dim] = x_scaled.view(torch.uint8)

    sf_scaled = sf.squeeze(-1).squeeze(-1)
    sf_bytes = sf_scaled.view(torch.int32).view(torch.uint8)
    out[..., head_dim:] = sf_bytes.view(num_blocks, block_size, num_heads, 4)
    return out


def _build_case(
    batch_size, next_n, heads, head_dim, avg_kv, blocksize, q_dtype, max_model_len=4096
):
    """Construct one randomized benchmark case.

    Returns the fp8 query, both KV-cache packings, per-head weights, context
    lengths, block tables, and ``max_model_len``.
    """
    num_blocks = max_model_len * 2

    q = torch.randn(
        (batch_size, next_n, heads, head_dim),
        device=flag_gems.device,
        dtype=q_dtype,
    )
    q_fp8 = q.to(torch.float8_e4m3fn)

    kv_cache = torch.randn(
        (num_blocks, blocksize, 1, head_dim),
        device=flag_gems.device,
        dtype=torch.bfloat16,
    )

    weights = torch.randn(
        (batch_size * next_n, heads),
        device=flag_gems.device,
        dtype=torch.float32,
    )

    # Context lengths jitter +/-20% around the requested average.
    context_lens = torch.randint(
        int(0.8 * avg_kv),
        int(1.2 * avg_kv),
        (batch_size,),
        device=flag_gems.device,
        dtype=torch.int32,
    )

    max_num_blocks_per_seq = (
        int(context_lens.max().item()) + blocksize - 1
    ) // blocksize
    block_tables = torch.zeros(
        (batch_size, max_num_blocks_per_seq),
        device=flag_gems.device,
        dtype=torch.int32,
    )

    # Map each sequence's logical blocks onto distinct random physical blocks.
    counter = 0
    block_idx_pool = list(range(num_blocks))
    random.shuffle(block_idx_pool)
    for i in range(batch_size):
        ctx_len = int(context_lens[i].item())
        for j in range((ctx_len + blocksize - 1) // blocksize):
            block_tables[i, j] = block_idx_pool[counter]
            counter += 1

    kv_cache_fp8_deepgemm = kv_cache_cast_to_fp8_deepgemm(kv_cache)
    kv_cache_fp8_triton = kv_cache_cast_to_fp8_triton(kv_cache)

    return (
        q_fp8,
        kv_cache_fp8_deepgemm,
        kv_cache_fp8_triton,
        weights,
        context_lens,
        block_tables,
        max_model_len,
    )


class FP8PagedMQACompareBenchmark(Benchmark):
    """Benchmark harness comparing the gems and DeepGEMM implementations."""

    def __init__(self):
        super().__init__(
            "fp8_paged_mqa_logits_gems_vs_deepgemm",
            self._vllm_wrapper,
            [torch.bfloat16],
        )
        self.set_gems(self._gems_wrapper)

    def set_shapes(self, shape_file_path=None):
        # Shapes are generated in get_input_iter; no external shape file.
        self.shapes = []

    def get_input_iter(self, _dtype):
        """Yield one fully-built argument tuple per (shape, q_dtype) combo."""
        compare_shapes = [
            # (batch_size, next_n, heads, head_dim, avg_kv)
            (1, 1, 16, 64, 1024),
            (2, 1, 32, 128, 2048),
            (4, 1, 32, 128, 2048),
            (2, 2, 32, 128, 2048),
            (8, 1, 32, 128, 3072),
        ]
        q_dtypes = [torch.bfloat16, torch.float16]
        blocksize = 64

        for (bs, nn, h, d, avg_kv), q_dtype in product(compare_shapes, q_dtypes):
            case = _build_case(bs, nn, h, d, avg_kv, blocksize, q_dtype)
            (
                q_fp8,
                kv_dg,
                kv_tr,
                weights,
                context_lens,
                block_tables,
                max_model_len,
            ) = case
            schedule_metadata = get_paged_mqa_logits_metadata(
                context_lens, blocksize, get_num_sms()
            )
            yield (
                q_fp8,
                kv_dg,
                kv_tr,
                weights,
                context_lens,
                block_tables,
                schedule_metadata,
                max_model_len,
                q_dtype,
                blocksize,
            )

    @staticmethod
    def _vllm_wrapper(
        q_fp8,
        kv_cache_fp8_deepgemm,
        kv_cache_fp8_triton,
        weights,
        context_lens,
        block_tables,
        schedule_metadata,
        max_model_len,
        q_dtype,
        blocksize,
    ):
        # Reference path: DeepGEMM kernel consumes its own KV packing.
        return vllm_fp8_paged_mqa_logits(
            q_fp8,
            kv_cache_fp8_deepgemm,
            weights,
            context_lens,
            block_tables,
            schedule_metadata,
            max_model_len,
            clean_logits=True,
        )

    @staticmethod
    def _gems_wrapper(
        q_fp8,
        kv_cache_fp8_deepgemm,
        kv_cache_fp8_triton,
        weights,
        context_lens,
        block_tables,
        schedule_metadata,
        max_model_len,
        q_dtype,
        blocksize,
    ):
        # Gems path: Triton kernel consumes the interleaved KV packing and
        # needs no schedule metadata.
        return gems_fp8_paged_mqa_logits(
            q_fp8,
            kv_cache_fp8_triton,
            weights,
            context_lens,
            block_tables,
            max_model_len,
        )


@pytest.mark.skipif(
    not (
        torch.cuda.is_available()
        and VLLM_AVAILABLE
        and DEEPGEMM_AVAILABLE
        and HOPPER_AVAILABLE
    ),
    reason="requires CUDA + vLLM + DeepGEMM + Hopper",
)
@pytest.mark.performance
@pytest.mark.fp8_paged_mqa_logits
def test_perf_fp8_paged_mqa_logits_gems_vs_deepgemm():
    bench = FP8PagedMQACompareBenchmark()
    bench.run()
"""Triton implementation of fp8 paged-MQA logits.

Computes, for every (batch, next_n) query row, the weighted ReLU'd dot
products against an fp8-packed paged KV cache:
``logits[row, pos] = sum_h weights[row, h] * relu(q[row, h] . kv[pos])``
with a causal mask, writing ``-inf`` to masked / out-of-context positions.
"""
import torch
import triton
import triton.language as tl


def cdiv(x: int, y: int) -> int:
    """Ceiling division for non-negative integers."""
    return (x + y - 1) // y


@triton.autotune(
    configs=[
        # ---- BLOCK_D=128, NUM_D_TILES=1 (dim=128) ----
        triton.Config(
            {"BLOCK_KV": 16, "BLOCK_D": 128, "NUM_D_TILES": 1, "HEADS_UNROLL": 4},
            num_warps=4,
            num_stages=3,
        ),
        triton.Config(
            {"BLOCK_KV": 16, "BLOCK_D": 128, "NUM_D_TILES": 1, "HEADS_UNROLL": 8},
            num_warps=4,
            num_stages=3,
        ),
        triton.Config(
            {"BLOCK_KV": 16, "BLOCK_D": 128, "NUM_D_TILES": 1, "HEADS_UNROLL": 4},
            num_warps=8,
            num_stages=3,
        ),
        triton.Config(
            {"BLOCK_KV": 32, "BLOCK_D": 128, "NUM_D_TILES": 1, "HEADS_UNROLL": 4},
            num_warps=4,
            num_stages=3,
        ),
        triton.Config(
            {"BLOCK_KV": 32, "BLOCK_D": 128, "NUM_D_TILES": 1, "HEADS_UNROLL": 8},
            num_warps=4,
            num_stages=3,
        ),
        triton.Config(
            {"BLOCK_KV": 32, "BLOCK_D": 128, "NUM_D_TILES": 1, "HEADS_UNROLL": 4},
            num_warps=8,
            num_stages=3,
        ),
        triton.Config(
            {"BLOCK_KV": 64, "BLOCK_D": 128, "NUM_D_TILES": 1, "HEADS_UNROLL": 4},
            num_warps=4,
            num_stages=2,
        ),
        triton.Config(
            {"BLOCK_KV": 64, "BLOCK_D": 128, "NUM_D_TILES": 1, "HEADS_UNROLL": 8},
            num_warps=8,
            num_stages=2,
        ),
        # ---- BLOCK_D=64, NUM_D_TILES=2 ----
        triton.Config(
            {"BLOCK_KV": 16, "BLOCK_D": 64, "NUM_D_TILES": 2, "HEADS_UNROLL": 4},
            num_warps=4,
            num_stages=3,
        ),
        triton.Config(
            {"BLOCK_KV": 32, "BLOCK_D": 64, "NUM_D_TILES": 2, "HEADS_UNROLL": 4},
            num_warps=4,
            num_stages=3,
        ),
        triton.Config(
            {"BLOCK_KV": 32, "BLOCK_D": 64, "NUM_D_TILES": 2, "HEADS_UNROLL": 8},
            num_warps=4,
            num_stages=3,
        ),
        triton.Config(
            {"BLOCK_KV": 64, "BLOCK_D": 64, "NUM_D_TILES": 2, "HEADS_UNROLL": 4},
            num_warps=4,
            num_stages=2,
        ),
    ],
    key=["heads", "dim", "block_size"],
)
@triton.jit
def fp8_paged_mqa_logits_kernel(
    q_ptr,
    kv_ptr,
    weights_ptr,
    logits_ptr,
    block_tables_ptr,
    context_lens_ptr,
    stride_qb,
    stride_qn,
    stride_qh,
    stride_qd,
    stride_kvblk,
    stride_kvpos,
    stride_kvone,
    stride_kvbyte,
    stride_wrow,
    stride_wh,
    stride_lrow,
    stride_lcol,
    stride_btb,
    stride_bts,
    next_n: tl.constexpr,
    heads: tl.constexpr,
    dim: tl.constexpr,
    block_size: tl.constexpr,
    max_model_len,
    dim_plus_4: tl.constexpr,
    BLOCK_KV: tl.constexpr,
    BLOCK_D: tl.constexpr,
    NUM_D_TILES: tl.constexpr,
    HEADS_UNROLL: tl.constexpr,
):
    # One program handles one (batch, next_n) query row and one BLOCK_KV-wide
    # tile of key/value positions.
    #
    # Bug fix vs. the previous version: when NUM_D_TILES > 1, the old code
    # first ran a full d-tile/head loop that loaded every KV tile and Q vector
    # and computed dot products, only to discard all of it (the accumulate was
    # guarded by `if NUM_D_TILES == 1`) before a second loop recomputed
    # everything.  The two paths are now disjoint, so each does its work once.
    pid_row = tl.program_id(0)
    pid_kv_tile = tl.program_id(1)

    batch_idx = pid_row // next_n
    next_n_idx = pid_row % next_n

    context_len = tl.load(context_lens_ptr + batch_idx)
    # Sequence position of this query row; later KV positions are masked.
    query_seq_pos = context_len - next_n + next_n_idx

    kv_start = pid_kv_tile * BLOCK_KV
    if kv_start >= context_len:
        # Nothing to compute; the host wrapper pre-filled logits with -inf.
        return

    offs_kv = tl.arange(0, BLOCK_KV)
    kv_global_pos = kv_start + offs_kv

    context_mask = kv_global_pos < context_len
    causal_mask = kv_global_pos <= query_seq_pos
    valid_mask = context_mask & causal_mask

    # Translate logical KV positions to physical cache locations.
    phys_block_idx = kv_global_pos // block_size
    intra_block_pos = kv_global_pos % block_size

    phys_block_ids = tl.load(
        block_tables_ptr + batch_idx * stride_btb + phys_block_idx * stride_bts,
        mask=valid_mask,
        other=0,
    )

    kv_base = phys_block_ids * stride_kvblk + intra_block_pos * stride_kvpos

    # Reassemble each position's little-endian fp32 scale from the 4 trailing
    # bytes stored after the `dim` fp8 payload bytes.
    scale_addr = kv_base + dim * stride_kvbyte
    b0 = tl.load(kv_ptr + scale_addr, mask=valid_mask, other=0).to(tl.uint32)
    b1 = tl.load(kv_ptr + scale_addr + stride_kvbyte, mask=valid_mask, other=0).to(
        tl.uint32
    )
    b2 = tl.load(kv_ptr + scale_addr + 2 * stride_kvbyte, mask=valid_mask, other=0).to(
        tl.uint32
    )
    b3 = tl.load(kv_ptr + scale_addr + 3 * stride_kvbyte, mask=valid_mask, other=0).to(
        tl.uint32
    )
    scale_u32 = b0 | (b1 << 8) | (b2 << 16) | (b3 << 24)
    scale_f32 = scale_u32.to(tl.float32, bitcast=True)  # [BLOCK_KV]

    logit_accum = tl.zeros([BLOCK_KV], dtype=tl.float32)
    offs_d = tl.arange(0, BLOCK_D)
    q_row = q_ptr + batch_idx * stride_qb + next_n_idx * stride_qn

    if NUM_D_TILES == 1:
        # Fast path: the whole head_dim fits one tile, so the (pre-scaled)
        # KV tile is loaded once and reused across all heads.
        d_mask = offs_d < dim
        kv_byte_ptrs = kv_ptr + kv_base[:, None] + offs_d[None, :] * stride_kvbyte
        load_mask = valid_mask[:, None] & d_mask[None, :]
        kv_u8 = tl.load(kv_byte_ptrs, mask=load_mask, other=0)
        kv_scaled = (
            kv_u8.to(tl.float8e4nv, bitcast=True).to(tl.float32)
            * scale_f32[:, None]
        )

        for h in range(heads):
            q_vals = tl.load(
                q_row + h * stride_qh + offs_d * stride_qd, mask=d_mask, other=0.0
            ).to(tl.float32)
            w = tl.load(weights_ptr + pid_row * stride_wrow + h * stride_wh)
            dot = tl.sum(kv_scaled * q_vals[None, :], axis=1)
            # relu before the per-head weight, matching the reference op.
            logit_accum += tl.maximum(dot, 0.0) * w
    else:
        # Multi-tile path: accumulate the unscaled dot over d-tiles for each
        # head, then apply the per-position scale once (same math as above).
        for h in range(heads):
            w = tl.load(weights_ptr + pid_row * stride_wrow + h * stride_wh)
            dot = tl.zeros([BLOCK_KV], dtype=tl.float32)

            for d_tile in tl.static_range(0, NUM_D_TILES):
                d_offs = d_tile * BLOCK_D + offs_d
                d_mask = d_offs < dim

                q_vals = tl.load(
                    q_row + h * stride_qh + d_offs * stride_qd,
                    mask=d_mask,
                    other=0.0,
                ).to(tl.float32)

                kv_byte_ptrs = (
                    kv_ptr + kv_base[:, None] + d_offs[None, :] * stride_kvbyte
                )
                load_mask = valid_mask[:, None] & d_mask[None, :]
                kv_u8 = tl.load(kv_byte_ptrs, mask=load_mask, other=0)
                kv_f32 = kv_u8.to(tl.float8e4nv, bitcast=True).to(tl.float32)

                dot += tl.sum(kv_f32 * q_vals[None, :], axis=1)

            dot = dot * scale_f32
            logit_accum += tl.maximum(dot, 0.0) * w

    out_vals = tl.where(valid_mask, logit_accum, float("-inf"))
    out_ptrs = logits_ptr + pid_row * stride_lrow + kv_global_pos * stride_lcol
    out_mask = valid_mask & (kv_global_pos < max_model_len)
    tl.store(out_ptrs, out_vals, mask=out_mask)


def fp8_paged_mqa_logits(
    q: torch.Tensor,
    kv_cache: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    max_model_len: int,
) -> torch.Tensor:
    """Compute paged-MQA logits from an fp8-packed KV cache.

    Args:
        q: ``(batch, next_n, heads, dim)`` fp8 query tensor.
        kv_cache: ``(num_blocks, block_size, 1, dim + 4)`` uint8 cache; per
            position, ``dim`` fp8 bytes followed by a 4-byte fp32 scale.
        weights: ``(batch * next_n, heads)`` fp32 per-head mixing weights.
        context_lens: ``(batch,)`` int32 context lengths.
        block_tables: ``(batch, max_blocks_per_seq)`` int32 logical->physical
            block mapping.
        max_model_len: number of logits columns to produce per row.

    Returns:
        ``(batch * next_n, max_model_len)`` fp32 logits; invalid positions
        hold ``-inf``.
    """
    assert q.is_cuda and kv_cache.is_cuda and weights.is_cuda
    assert context_lens.is_cuda and block_tables.is_cuda

    batch_size, next_n, heads, dim = q.size()
    num_blocks, block_size, one, dim_plus_4 = kv_cache.size()

    assert one == 1, "KV cache must have num_heads=1 (MQA)"
    assert dim_plus_4 == dim + 4, f"KV dim error: {dim_plus_4} != {dim}+4"
    assert weights.shape == (batch_size * next_n, heads), "Weights shape mismatch"
    assert kv_cache.dtype == torch.uint8, "KV cache must be uint8 (packed FP8+scale)"
    assert context_lens.dtype == torch.int32, "Context lens must be int32"
    assert block_tables.dtype == torch.int32, "Block tables must be int32"

    q_contig = q.contiguous()
    kv_contig = kv_cache.contiguous()
    weights_contig = weights.contiguous()
    context_lens_contig = context_lens.contiguous()
    block_tables_contig = block_tables.contiguous()

    # Pre-fill with -inf so tiles past the context can early-return.
    logits = torch.full(
        (batch_size * next_n, max_model_len),
        float("-inf"),
        device=q.device,
        dtype=torch.float32,
    )

    # NOTE(review): .item() forces a device sync per call; acceptable here,
    # but a host-visible max length could avoid it.
    max_context = int(context_lens.max().item())

    def grid(meta):
        BLOCK_KV = meta["BLOCK_KV"]
        num_kv_tiles = cdiv(max_context, BLOCK_KV)
        return (batch_size * next_n, num_kv_tiles)

    fp8_paged_mqa_logits_kernel[grid](
        q_contig,
        kv_contig,
        weights_contig,
        logits,
        block_tables_contig,
        context_lens_contig,
        q_contig.stride(0),
        q_contig.stride(1),
        q_contig.stride(2),
        q_contig.stride(3),
        kv_contig.stride(0),
        kv_contig.stride(1),
        kv_contig.stride(2),
        kv_contig.stride(3),
        weights_contig.stride(0),
        weights_contig.stride(1),
        logits.stride(0),
        logits.stride(1),
        block_tables_contig.stride(0),
        block_tables_contig.stride(1),
        next_n,
        heads,
        dim,
        block_size,
        max_model_len,
        dim_plus_4,
    )

    return logits
"""Accuracy tests: FlagGems Triton fp8_paged_mqa_logits vs. DeepGEMM.

Both tests previously duplicated ~80 lines of identical setup and
comparison code; that is now factored into ``_make_inputs`` /
``_check_against_deepgemm`` (RNG call order preserved so seeded results
are unchanged).  Block-table writes use ``block_tables[i, j]`` indexing,
consistent with the benchmark module.
"""
import random

import pytest
import torch
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import calc_diff
from vllm.utils.deep_gemm import fp8_paged_mqa_logits as fp8_paged_mqa_logits_deepgemm
from vllm.utils.deep_gemm import get_num_sms, get_paged_mqa_logits_metadata
from vllm.utils.import_utils import has_deep_gemm

import flag_gems
from flag_gems.ops.fp8_paged_mqa_logits import fp8_paged_mqa_logits

from .accuracy_utils import gems_assert_close, to_reference


def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor:
    """Pack a bf16 KV cache into DeepGEMM's layout: an fp8 payload block
    followed by per-position fp32 scales, viewed back as
    ``(num_blocks, block_size, 1, head_dim + 4)`` uint8.
    """
    num_blocks, block_size, num_heads, head_dim = x.shape
    assert num_heads == 1

    # Per-position absmax scale, clamped to avoid division by ~0.
    x_amax = x.abs().float().amax(dim=3, keepdim=True).clamp(1e-4)
    sf = x_amax / 448.0  # 448 = max finite magnitude of float8_e4m3fn
    x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)

    x_fp8 = torch.empty(
        (num_blocks, block_size * (head_dim + 4)),
        device=x.device,
        dtype=torch.uint8,
    )
    x_fp8[:, : block_size * head_dim] = x_scaled.view(
        num_blocks, block_size * head_dim
    ).view(torch.uint8)

    sf_scaled = sf.squeeze(-1).squeeze(-1)
    sf_bytes = sf_scaled.view(torch.int32).view(torch.uint8)
    x_fp8[:, block_size * head_dim :] = sf_bytes

    return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4)


def kv_cache_cast_to_fp8_triton(x: torch.Tensor) -> torch.Tensor:
    """Pack a bf16 KV cache into the interleaved layout the Triton kernel
    reads: ``head_dim`` fp8 bytes followed by the position's own 4-byte
    fp32 scale.
    """
    num_blocks, block_size, num_heads, head_dim = x.shape
    assert num_heads == 1

    x_amax = x.abs().float().amax(dim=3, keepdim=True).clamp(1e-4)
    sf = x_amax / 448.0
    x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)

    out = torch.empty(
        (num_blocks, block_size, num_heads, head_dim + 4),
        device=x.device,
        dtype=torch.uint8,
    )
    out[..., :head_dim] = x_scaled.view(torch.uint8)

    sf_scaled = sf.squeeze(-1).squeeze(-1)  # [num_blocks, block_size]
    sf_bytes = sf_scaled.view(torch.int32).view(torch.uint8)
    out[..., head_dim:] = sf_bytes.view(num_blocks, block_size, num_heads, 4)

    return out


def _build_mask(context_lens, batch_size, next_n, max_model_len, device):
    """Boolean validity mask per logits row: position ``p`` is visible to
    row ``r`` iff ``p <= context_len - next_n + (r % next_n)``.
    """
    positions = (
        torch.arange(max_model_len, device=device)
        .unsqueeze(0)
        .expand(batch_size * next_n, -1)
    )
    row_indices = torch.arange(batch_size * next_n, device=device) // next_n
    next_n_offset = torch.arange(batch_size * next_n, device=device) % next_n
    return positions <= (context_lens[row_indices] - next_n + next_n_offset).unsqueeze(
        1
    )


def _make_inputs(batch_size, next_n, heads, index_dim, avg_kv, max_model_len):
    """Build one randomized test case (previously duplicated in both tests)."""
    num_blocks, blocksize = max_model_len * 2, 64

    q = torch.randn(
        (batch_size, next_n, heads, index_dim),
        device=flag_gems.device,
        dtype=torch.bfloat16,
    )
    kv_cache = torch.randn(
        (num_blocks, blocksize, 1, index_dim),
        device=flag_gems.device,
        dtype=torch.bfloat16,
    )
    weights = torch.randn(
        (batch_size * next_n, heads), device=flag_gems.device, dtype=torch.float32
    )

    context_lens = torch.randint(
        int(0.8 * avg_kv), int(1.2 * avg_kv), (batch_size,), device=flag_gems.device
    ).to(torch.int32)
    max_num_blocks_per_seq = (
        int(context_lens.max().item()) + blocksize - 1
    ) // blocksize
    block_tables = torch.zeros(
        (batch_size, max_num_blocks_per_seq), device=flag_gems.device, dtype=torch.int32
    )

    # Map each sequence's logical blocks onto distinct random physical blocks.
    counter = 0
    block_idx_pool = list(range(num_blocks))
    random.shuffle(block_idx_pool)
    for i in range(batch_size):
        ctx_len = int(context_lens[i].item())
        for j in range((ctx_len + blocksize - 1) // blocksize):
            block_tables[i, j] = block_idx_pool[counter]
            counter += 1

    q_fp8 = q.to(torch.float8_e4m3fn)
    return q_fp8, kv_cache, weights, context_lens, block_tables, blocksize


def _check_against_deepgemm(
    batch_size, next_n, heads, index_dim, max_model_len, clean_logits
):
    """Run both implementations on identical inputs and compare masked logits."""
    avg_kv = 2048
    q_fp8, kv_cache, weights, context_lens, block_tables, blocksize = _make_inputs(
        batch_size, next_n, heads, index_dim, avg_kv, max_model_len
    )

    kv_cache_fp8_deepgemm = kv_cache_cast_to_fp8(kv_cache)
    kv_cache_fp8_triton = kv_cache_cast_to_fp8_triton(kv_cache)

    schedule_metadata = get_paged_mqa_logits_metadata(
        context_lens, blocksize, get_num_sms()
    )
    ref_out = fp8_paged_mqa_logits_deepgemm(
        q_fp8,
        kv_cache_fp8_deepgemm,
        weights,
        context_lens,
        block_tables,
        schedule_metadata,
        max_model_len,
        clean_logits=clean_logits,
    )
    ref_out = to_reference(ref_out)

    with flag_gems.use_gems():
        res_out = fp8_paged_mqa_logits(
            q_fp8,
            kv_cache_fp8_triton,
            weights,
            context_lens,
            block_tables,
            max_model_len,
        )

    # Compare only positions both implementations are required to produce.
    mask = _build_mask(
        context_lens, batch_size, next_n, max_model_len, flag_gems.device
    )
    res_out_masked = torch.nan_to_num(res_out.masked_fill(~mask, 0), 0.0)
    ref_out_masked = torch.nan_to_num(ref_out.masked_fill(~mask, 0), 0.0)

    gems_assert_close(
        res_out_masked,
        ref_out_masked,
        res_out_masked.dtype,
        equal_nan=True,
        atol=5e-2,
        reduce_dim=1,
    )

    diff = calc_diff(res_out_masked, ref_out_masked)
    assert diff < 1e-3, f"Triton 与 DeepGEMM 版本差异过大: {diff=} (要求 < 1e-3)"


@pytest.mark.fp8
@pytest.mark.paged_mqa_logits
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only")
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
@pytest.mark.skipif(
    not current_platform.has_device_capability(90), reason="SM90 and SM100 only"
)
@pytest.mark.parametrize("clean_logits", [True, False])
def test_accuracy_fp8_paged_mqa_logits(clean_logits: bool):
    torch.manual_seed(0)
    random.seed(0)
    _check_against_deepgemm(
        batch_size=4,
        next_n=1,
        heads=32,
        index_dim=128,
        max_model_len=4096,
        clean_logits=clean_logits,
    )


@pytest.mark.fp8
@pytest.mark.paged_mqa_logits
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only")
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
@pytest.mark.skipif(
    not current_platform.has_device_capability(90), reason="SM90 and SM100 only"
)
@pytest.mark.parametrize("batch_size, next_n", [(4, 1), (2, 2)])
@pytest.mark.parametrize("heads, index_dim", [(32, 128)])
def test_accuracy_fp8_paged_mqa_logits_param(batch_size, next_n, heads, index_dim):
    torch.manual_seed(0)
    random.seed(0)
    _check_against_deepgemm(
        batch_size, next_n, heads, index_dim, max_model_len=4096, clean_logits=True
    )
"BLOCK_D": 128, "NUM_D_TILES": 1, "HEADS_UNROLL": 4}, + {"BLOCK_KV": 64, "BLOCK_D": 128, "NUM_D_TILES": 1, "BLOCK_H": 8}, num_warps=4, - num_stages=3, + num_stages=2, ), triton.Config( - {"BLOCK_KV": 32, "BLOCK_D": 128, "NUM_D_TILES": 1, "HEADS_UNROLL": 8}, - num_warps=4, - num_stages=3, + {"BLOCK_KV": 128, "BLOCK_D": 128, "NUM_D_TILES": 1, "BLOCK_H": 32}, + num_warps=8, + num_stages=2, ), triton.Config( - {"BLOCK_KV": 32, "BLOCK_D": 128, "NUM_D_TILES": 1, "HEADS_UNROLL": 4}, + {"BLOCK_KV": 128, "BLOCK_D": 128, "NUM_D_TILES": 1, "BLOCK_H": 16}, num_warps=8, - num_stages=3, + num_stages=2, ), triton.Config( - {"BLOCK_KV": 64, "BLOCK_D": 128, "NUM_D_TILES": 1, "HEADS_UNROLL": 4}, + {"BLOCK_KV": 32, "BLOCK_D": 128, "NUM_D_TILES": 1, "BLOCK_H": 32}, num_warps=4, num_stages=2, ), triton.Config( - {"BLOCK_KV": 64, "BLOCK_D": 128, "NUM_D_TILES": 1, "HEADS_UNROLL": 8}, + {"BLOCK_KV": 64, "BLOCK_D": 64, "NUM_D_TILES": 2, "BLOCK_H": 16}, num_warps=8, num_stages=2, ), - # ---- BLOCK_D=64, NUM_D_TILES=2 ---- triton.Config( - {"BLOCK_KV": 16, "BLOCK_D": 64, "NUM_D_TILES": 2, "HEADS_UNROLL": 4}, + {"BLOCK_KV": 64, "BLOCK_D": 64, "NUM_D_TILES": 2, "BLOCK_H": 8}, num_warps=4, - num_stages=3, + num_stages=2, ), triton.Config( - {"BLOCK_KV": 32, "BLOCK_D": 64, "NUM_D_TILES": 2, "HEADS_UNROLL": 4}, - num_warps=4, - num_stages=3, + {"BLOCK_KV": 128, "BLOCK_D": 64, "NUM_D_TILES": 2, "BLOCK_H": 16}, + num_warps=8, + num_stages=2, ), triton.Config( - {"BLOCK_KV": 32, "BLOCK_D": 64, "NUM_D_TILES": 2, "HEADS_UNROLL": 8}, + {"BLOCK_KV": 64, "BLOCK_D": 64, "NUM_D_TILES": 1, "BLOCK_H": 16}, num_warps=4, - num_stages=3, + num_stages=2, ), triton.Config( - {"BLOCK_KV": 64, "BLOCK_D": 64, "NUM_D_TILES": 2, "HEADS_UNROLL": 4}, + {"BLOCK_KV": 64, "BLOCK_D": 64, "NUM_D_TILES": 1, "BLOCK_H": 8}, num_warps=4, num_stages=2, ), + triton.Config( + {"BLOCK_KV": 128, "BLOCK_D": 64, "NUM_D_TILES": 1, "BLOCK_H": 16}, + num_warps=8, + num_stages=2, + ), ], key=["heads", "dim", "block_size"], ) @@ -105,7 
+103,7 @@ def fp8_paged_mqa_logits_kernel( BLOCK_KV: tl.constexpr, BLOCK_D: tl.constexpr, NUM_D_TILES: tl.constexpr, - HEADS_UNROLL: tl.constexpr, + BLOCK_H: tl.constexpr, ): pid_row = tl.program_id(0) pid_kv_tile = tl.program_id(1) @@ -118,6 +116,11 @@ def fp8_paged_mqa_logits_kernel( kv_start = pid_kv_tile * BLOCK_KV if kv_start >= context_len: + offs_kv = tl.arange(0, BLOCK_KV) + kv_pos = kv_start + offs_kv + out_mask = kv_pos < max_model_len + out_ptrs = logits_ptr + pid_row * stride_lrow + kv_pos * stride_lcol + tl.store(out_ptrs, float("-inf"), mask=out_mask) return offs_kv = tl.arange(0, BLOCK_KV) @@ -139,86 +142,94 @@ def fp8_paged_mqa_logits_kernel( kv_base = phys_block_ids * stride_kvblk + intra_block_pos * stride_kvpos scale_addr = kv_base + dim * stride_kvbyte - b0 = tl.load(kv_ptr + scale_addr, mask=valid_mask, other=0).to(tl.uint32) - b1 = tl.load(kv_ptr + scale_addr + stride_kvbyte, mask=valid_mask, other=0).to( - tl.uint32 - ) - b2 = tl.load(kv_ptr + scale_addr + 2 * stride_kvbyte, mask=valid_mask, other=0).to( - tl.uint32 - ) - b3 = tl.load(kv_ptr + scale_addr + 3 * stride_kvbyte, mask=valid_mask, other=0).to( - tl.uint32 - ) - scale_u32 = b0 | (b1 << 8) | (b2 << 16) | (b3 << 24) - scale_f32 = scale_u32.to(tl.float32, bitcast=True) # [BLOCK_KV] + scale_ptr = (kv_ptr + scale_addr).to(tl.pointer_type(tl.uint32, 1), bitcast=True) + scale_u32 = tl.load(scale_ptr, mask=valid_mask, other=0) + scale_f32 = scale_u32.to(tl.float32, bitcast=True) logit_accum = tl.zeros([BLOCK_KV], dtype=tl.float32) offs_d = tl.arange(0, BLOCK_D) + q_base = q_ptr + batch_idx * stride_qb + next_n_idx * stride_qn - for d_tile in tl.static_range(0, NUM_D_TILES): - d_offs = d_tile * BLOCK_D + offs_d - d_mask = d_offs < dim + if NUM_D_TILES == 1: + d_mask = offs_d < dim - kv_byte_ptrs = kv_ptr + kv_base[:, None] + d_offs[None, :] * stride_kvbyte + kv_byte_ptrs = kv_ptr + kv_base[:, None] + offs_d[None, :] * stride_kvbyte load_mask = valid_mask[:, None] & d_mask[None, :] kv_u8 = 
tl.load(kv_byte_ptrs, mask=load_mask, other=0) kv_fp8 = kv_u8.to(tl.float8e4nv, bitcast=True) kv_f32 = kv_fp8.to(tl.float32) - kv_scaled = kv_f32 * scale_f32[:, None] - - q_base = ( - q_ptr + batch_idx * stride_qb + next_n_idx * stride_qn + d_offs * stride_qd - ) - - for h in range(heads): - q_vals = tl.load(q_base + h * stride_qh, mask=d_mask, other=0.0).to( - tl.float32 + for h_tile in tl.static_range(0, heads, BLOCK_H): + offs_h = h_tile + tl.arange(0, BLOCK_H) + h_mask = offs_h < heads + + q_ptrs = q_base + offs_h[:, None] * stride_qh + offs_d[None, :] * stride_qd + q_vals = tl.load( + q_ptrs, mask=h_mask[:, None] & d_mask[None, :], other=0.0 + ).to(tl.float32) + weights = tl.load( + weights_ptr + pid_row * stride_wrow + offs_h * stride_wh, + mask=h_mask, + other=0.0, ) - w = tl.load(weights_ptr + pid_row * stride_wrow + h * stride_wh) - - partial_dot = tl.sum(kv_scaled * q_vals[None, :], axis=1) - - if NUM_D_TILES == 1: - dot_relu = tl.maximum(partial_dot, 0.0) - logit_accum += dot_relu * w - - if NUM_D_TILES > 1: - logit_accum2 = tl.zeros([BLOCK_KV], dtype=tl.float32) - - for h in range(heads): - w = tl.load(weights_ptr + pid_row * stride_wrow + h * stride_wh) - dot = tl.zeros([BLOCK_KV], dtype=tl.float32) - - for d_tile2 in tl.static_range(0, NUM_D_TILES): - d_offs2 = d_tile2 * BLOCK_D + offs_d - d_mask2 = d_offs2 < dim - - q_ptrs2 = ( - q_ptr - + batch_idx * stride_qb - + next_n_idx * stride_qn - + h * stride_qh - + d_offs2 * stride_qd - ) - q_vals2 = tl.load(q_ptrs2, mask=d_mask2, other=0.0).to(tl.float32) + q_tile = tl.trans(q_vals) + partial_dot = tl.dot(kv_f32, q_tile, out_dtype=tl.float32) + partial_dot = partial_dot * scale_f32[:, None] + partial_dot = tl.maximum(partial_dot, 0.0) + logit_accum += tl.sum(partial_dot * weights[None, :], axis=1) + + else: + d_offs0 = offs_d + d_mask0 = d_offs0 < dim + d_offs1 = BLOCK_D + offs_d + d_mask1 = d_offs1 < dim + + kv_byte_ptrs0 = kv_ptr + kv_base[:, None] + d_offs0[None, :] * stride_kvbyte + load_mask0 = 
valid_mask[:, None] & d_mask0[None, :] + kv_u80 = tl.load(kv_byte_ptrs0, mask=load_mask0, other=0) + kv_fp80 = kv_u80.to(tl.float8e4nv, bitcast=True) + kv_f320 = kv_fp80.to(tl.float32) + + kv_byte_ptrs1 = kv_ptr + kv_base[:, None] + d_offs1[None, :] * stride_kvbyte + load_mask1 = valid_mask[:, None] & d_mask1[None, :] + kv_u81 = tl.load(kv_byte_ptrs1, mask=load_mask1, other=0) + kv_fp81 = kv_u81.to(tl.float8e4nv, bitcast=True) + kv_f321 = kv_fp81.to(tl.float32) + + for h_tile in tl.static_range(0, heads, BLOCK_H): + offs_h = h_tile + tl.arange(0, BLOCK_H) + h_mask = offs_h < heads + + q_ptrs0 = ( + q_base + offs_h[:, None] * stride_qh + d_offs0[None, :] * stride_qd + ) + q_vals0 = tl.load( + q_ptrs0, mask=h_mask[:, None] & d_mask0[None, :], other=0.0 + ).to(tl.float32) - kv_byte_ptrs2 = ( - kv_ptr + kv_base[:, None] + d_offs2[None, :] * stride_kvbyte - ) - load_mask2 = valid_mask[:, None] & d_mask2[None, :] - kv_u82 = tl.load(kv_byte_ptrs2, mask=load_mask2, other=0) - kv_fp82 = kv_u82.to(tl.float8e4nv, bitcast=True) - kv_f322 = kv_fp82.to(tl.float32) + q_ptrs1 = ( + q_base + offs_h[:, None] * stride_qh + d_offs1[None, :] * stride_qd + ) + q_vals1 = tl.load( + q_ptrs1, mask=h_mask[:, None] & d_mask1[None, :], other=0.0 + ).to(tl.float32) + + weights = tl.load( + weights_ptr + pid_row * stride_wrow + offs_h * stride_wh, + mask=h_mask, + other=0.0, + ) - dot += tl.sum(kv_f322 * q_vals2[None, :], axis=1) + q_T0 = tl.trans(q_vals0) + q_T1 = tl.trans(q_vals1) - dot = dot * scale_f32 - dot = tl.maximum(dot, 0.0) - logit_accum2 += dot * w + partial_dot = tl.dot(kv_f320, q_T0, out_dtype=tl.float32) + partial_dot = tl.dot(kv_f321, q_T1, acc=partial_dot, out_dtype=tl.float32) - logit_accum = logit_accum2 + partial_dot = partial_dot * scale_f32[:, None] + partial_dot = tl.maximum(partial_dot, 0.0) + logit_accum += tl.sum(partial_dot * weights[None, :], axis=1) out_vals = tl.where(valid_mask, logit_accum, float("-inf")) out_ptrs = logits_ptr + pid_row * stride_lrow + 
kv_global_pos * stride_lcol @@ -226,6 +237,18 @@ def fp8_paged_mqa_logits_kernel( tl.store(out_ptrs, out_vals, mask=out_mask) +@triton.jit +def fill_neg_inf_kernel( + out_ptr, + n_elements, + BLOCK: tl.constexpr, +): + pid = tl.program_id(0) + offs = pid * BLOCK + tl.arange(0, BLOCK) + mask = offs < n_elements + tl.store(out_ptr + offs, float("-inf"), mask=mask) + + def fp8_paged_mqa_logits( q: torch.Tensor, kv_cache: torch.Tensor, @@ -240,12 +263,12 @@ def fp8_paged_mqa_logits( batch_size, next_n, heads, dim = q.size() num_blocks, block_size, one, dim_plus_4 = kv_cache.size() - assert one == 1, "KV cache must have num_heads=1 (MQA)" - assert dim_plus_4 == dim + 4, f"KV dim error: {dim_plus_4} != {dim}+4" - assert weights.shape == (batch_size * next_n, heads), "Weights shape mismatch" - assert kv_cache.dtype == torch.uint8, "KV cache must be uint8 (packed FP8+scale)" - assert context_lens.dtype == torch.int32, "Context lens must be int32" - assert block_tables.dtype == torch.int32, "Block tables must be int32" + assert one == 1 + assert dim_plus_4 == dim + 4 + assert weights.shape == (batch_size * next_n, heads) + assert kv_cache.dtype == torch.uint8 + assert context_lens.dtype == torch.int32 + assert block_tables.dtype == torch.int32 q_contig = q.contiguous() kv_contig = kv_cache.contiguous() @@ -253,19 +276,24 @@ def fp8_paged_mqa_logits( context_lens_contig = context_lens.contiguous() block_tables_contig = block_tables.contiguous() - logits = torch.full( - (batch_size * next_n, max_model_len), - float("-inf"), + total_rows = batch_size * next_n + + logits = torch.empty( + (total_rows, max_model_len), device=q.device, dtype=torch.float32, ) + n_elements = total_rows * max_model_len + FILL_BLOCK = 1024 + fill_grid = (cdiv(n_elements, FILL_BLOCK),) + fill_neg_inf_kernel[fill_grid](logits, n_elements, BLOCK=FILL_BLOCK) - max_context = int(context_lens.max().item()) + max_context = block_tables_contig.shape[1] * block_size def grid(meta): BLOCK_KV = meta["BLOCK_KV"] 
num_kv_tiles = cdiv(max_context, BLOCK_KV) - return (batch_size * next_n, num_kv_tiles) + return (total_rows, num_kv_tiles) fp8_paged_mqa_logits_kernel[grid]( q_contig, diff --git a/tests/test_fp8_paged_mqa_logits_ops.py b/tests/test_fp8_paged_mqa_logits_ops.py index 599a1f2feb..1f3e6bea56 100644 --- a/tests/test_fp8_paged_mqa_logits_ops.py +++ b/tests/test_fp8_paged_mqa_logits_ops.py @@ -53,7 +53,7 @@ def kv_cache_cast_to_fp8_triton(x: torch.Tensor) -> torch.Tensor: ) out[..., :head_dim] = x_scaled.view(torch.uint8) - sf_scaled = sf.squeeze(-1).squeeze(-1) # [num_blocks, block_size] + sf_scaled = sf.squeeze(-1).squeeze(-1) sf_bytes = sf_scaled.view(torch.int32).view(torch.uint8) out[..., head_dim:] = sf_bytes.view(num_blocks, block_size, num_heads, 4) @@ -168,7 +168,9 @@ def test_accuracy_fp8_paged_mqa_logits(clean_logits: bool): ) diff = calc_diff(res_out_masked, ref_out_masked) - assert diff < 1e-3, f"Triton 与 DeepGEMM 版本差异过大: {diff=} (要求 < 1e-3)" + assert ( + diff < 1e-3 + ), f"Large discrepancy between Triton and DeepGEMM: {diff=} (expected < 1e-3)" @pytest.mark.fp8 @@ -178,8 +180,26 @@ def test_accuracy_fp8_paged_mqa_logits(clean_logits: bool): @pytest.mark.skipif( not current_platform.has_device_capability(90), reason="SM90 and SM100 only" ) -@pytest.mark.parametrize("batch_size, next_n", [(4, 1), (2, 2)]) -@pytest.mark.parametrize("heads, index_dim", [(32, 128)]) +@pytest.mark.parametrize( + "batch_size, next_n", + [ + (1, 1), + (2, 1), + (4, 1), + (8, 1), + (16, 1), + (32, 1), + (2, 2), + (4, 2), + ], +) +@pytest.mark.parametrize( + "heads, index_dim", + [ + (16, 64), + (32, 128), + ], +) def test_accuracy_fp8_paged_mqa_logits_param(batch_size, next_n, heads, index_dim): torch.manual_seed(0) random.seed(0) @@ -203,8 +223,12 @@ def test_accuracy_fp8_paged_mqa_logits_param(batch_size, next_n, heads, index_di ) context_lens = torch.randint( - int(0.8 * avg_kv), int(1.2 * avg_kv), (batch_size,), device=flag_gems.device - ).to(torch.int32) + int(0.8 * avg_kv), 
+ int(1.2 * avg_kv), + (batch_size,), + device=flag_gems.device, + dtype=torch.int32, + ) max_num_blocks_per_seq = (context_lens.max().item() + blocksize - 1) // blocksize block_tables = torch.zeros( (batch_size, max_num_blocks_per_seq), device=flag_gems.device, dtype=torch.int32 @@ -265,4 +289,6 @@ def test_accuracy_fp8_paged_mqa_logits_param(batch_size, next_n, heads, index_di ) diff = calc_diff(res_out_masked, ref_out_masked) - assert diff < 1e-3, f"Triton 与 DeepGEMM 版本差异过大: {diff=} (要求 < 1e-3)" + assert ( + diff < 1e-3 + ), f"Large discrepancy between Triton and DeepGEMM: {diff=} (expected < 1e-3)" From 6e8de9a026f13f82508fa62f9107df7b901b398d Mon Sep 17 00:00:00 2001 From: yy33min <2811552420@qq.com> Date: Thu, 2 Apr 2026 15:05:27 +0800 Subject: [PATCH 3/4] style: fix isort import order for fp8_paged_mqa_logits --- src/flag_gems/__init__.py | 2 +- src/flag_gems/ops/__init__.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/flag_gems/__init__.py b/src/flag_gems/__init__.py index 7cb13c79ca..11d2f7d0bd 100644 --- a/src/flag_gems/__init__.py +++ b/src/flag_gems/__init__.py @@ -189,9 +189,9 @@ def torch_ge(v): ("floor_divide.Scalar", floor_divide), ("floor_divide_.Scalar", floor_divide_), ("floor_divide_.Tensor", floor_divide_), - ("fp8_paged_mqa_logits", fp8_paged_mqa_logits), ("fmin", fmin), ("fmin.out", fmin_out), + ("fp8_paged_mqa_logits", fp8_paged_mqa_logits), ("full", full), ("full_like", full_like), ("gather", gather), diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py index bb309f6b96..e988f8b779 100644 --- a/src/flag_gems/ops/__init__.py +++ b/src/flag_gems/ops/__init__.py @@ -113,9 +113,9 @@ fill_tensor_out, ) from flag_gems.ops.flip import flip -from flag_gems.ops.fp8_paged_mqa_logits import fp8_paged_mqa_logits from flag_gems.ops.floor_ import floor_ from flag_gems.ops.fmin import fmin, fmin_out +from flag_gems.ops.fp8_paged_mqa_logits import fp8_paged_mqa_logits from flag_gems.ops.full import full 
from flag_gems.ops.full_like import full_like from flag_gems.ops.gather import gather, gather_backward @@ -432,9 +432,9 @@ "floor_", "floor_divide", "floor_divide_", - "fp8_paged_mqa_logits", "fmin", "fmin_out", + "fp8_paged_mqa_logits", "full", "full_like", "gather", From c3ef3cfa3b63cb93c927ba666ebff5e6447256f3 Mon Sep 17 00:00:00 2001 From: yy33min <2811552420@qq.com> Date: Thu, 2 Apr 2026 20:35:45 +0800 Subject: [PATCH 4/4] test: integrate fp8_paged_mqa_logits vllm tests --- benchmark/test_fp8_paged_mqa_logits_perf.py | 254 ----------------- benchmark/test_vllm_perf.py | 234 +++++++++++++++ tests/test_fp8_paged_mqa_logits_ops.py | 294 ------------------- tests/test_vllm_ops.py | 300 ++++++++++++++++++++ 4 files changed, 534 insertions(+), 548 deletions(-) delete mode 100644 benchmark/test_fp8_paged_mqa_logits_perf.py delete mode 100644 tests/test_fp8_paged_mqa_logits_ops.py diff --git a/benchmark/test_fp8_paged_mqa_logits_perf.py b/benchmark/test_fp8_paged_mqa_logits_perf.py deleted file mode 100644 index 3f01a01514..0000000000 --- a/benchmark/test_fp8_paged_mqa_logits_perf.py +++ /dev/null @@ -1,254 +0,0 @@ -import random -from itertools import product - -import pytest -import torch -from vllm.utils.deep_gemm import fp8_paged_mqa_logits as vllm_fp8_paged_mqa_logits -from vllm.utils.deep_gemm import get_num_sms, get_paged_mqa_logits_metadata -from vllm.utils.import_utils import has_deep_gemm - -import flag_gems -from benchmark.performance_utils import Benchmark -from flag_gems.ops.fp8_paged_mqa_logits import ( - fp8_paged_mqa_logits as gems_fp8_paged_mqa_logits, -) - -random.seed(42) - - -def is_vllm_available(): - try: - return True - except Exception: - return False - - -def is_hopper_available(): - if flag_gems.device != "cuda": - return False - major, minor = torch.cuda.get_device_capability() - return (major * 10 + minor) >= 90 - - -VLLM_AVAILABLE = is_vllm_available() -DEEPGEMM_AVAILABLE = has_deep_gemm() -HOPPER_AVAILABLE = is_hopper_available() - - -def 
kv_cache_cast_to_fp8_deepgemm(x: torch.Tensor) -> torch.Tensor: - num_blocks, block_size, num_heads, head_dim = x.shape - assert num_heads == 1 - x_amax = x.abs().float().amax(dim=3, keepdim=True).clamp(1e-4) - sf = x_amax / 448.0 - x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn) - - x_fp8 = torch.empty( - (num_blocks, block_size * (head_dim + 4)), - device=x.device, - dtype=torch.uint8, - ) - x_fp8[:, : block_size * head_dim] = x_scaled.view( - num_blocks, block_size * head_dim - ).view(torch.uint8) - - sf_scaled = sf.squeeze(-1).squeeze(-1) - sf_bytes = sf_scaled.view(torch.int32).view(torch.uint8) - x_fp8[:, block_size * head_dim :] = sf_bytes - return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4) - - -def kv_cache_cast_to_fp8_triton(x: torch.Tensor) -> torch.Tensor: - num_blocks, block_size, num_heads, head_dim = x.shape - assert num_heads == 1 - x_amax = x.abs().float().amax(dim=3, keepdim=True).clamp(1e-4) - sf = x_amax / 448.0 - x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn) - - out = torch.empty( - (num_blocks, block_size, num_heads, head_dim + 4), - device=x.device, - dtype=torch.uint8, - ) - out[..., :head_dim] = x_scaled.view(torch.uint8) - - sf_scaled = sf.squeeze(-1).squeeze(-1) - sf_bytes = sf_scaled.view(torch.int32).view(torch.uint8) - out[..., head_dim:] = sf_bytes.view(num_blocks, block_size, num_heads, 4) - return out - - -def _build_case( - batch_size, next_n, heads, head_dim, avg_kv, blocksize, q_dtype, max_model_len=4096 -): - num_blocks = max_model_len * 20 - - q = torch.randn( - (batch_size, next_n, heads, head_dim), - device=flag_gems.device, - dtype=q_dtype, - ) - q_fp8 = q.to(torch.float8_e4m3fn) - - kv_cache = torch.randn( - (num_blocks, blocksize, 1, head_dim), - device=flag_gems.device, - dtype=torch.bfloat16, - ) - - weights = torch.randn( - (batch_size * next_n, heads), - device=flag_gems.device, - dtype=torch.float32, - ) - - context_lens = torch.randint( - int(0.8 * avg_kv), - int(1.2 * avg_kv), - 
(batch_size,), - device=flag_gems.device, - dtype=torch.int32, - ) - - max_num_blocks_per_seq = ( - int(context_lens.max().item()) + blocksize - 1 - ) // blocksize - block_tables = torch.zeros( - (batch_size, max_num_blocks_per_seq), - device=flag_gems.device, - dtype=torch.int32, - ) - - counter = 0 - block_idx_pool = list(range(num_blocks)) - random.shuffle(block_idx_pool) - for i in range(batch_size): - ctx_len = int(context_lens[i].item()) - for j in range((ctx_len + blocksize - 1) // blocksize): - block_tables[i, j] = block_idx_pool[counter] - counter += 1 - - kv_cache_fp8_deepgemm = kv_cache_cast_to_fp8_deepgemm(kv_cache) - kv_cache_fp8_triton = kv_cache_cast_to_fp8_triton(kv_cache) - - return ( - q_fp8, - kv_cache_fp8_deepgemm, - kv_cache_fp8_triton, - weights, - context_lens, - block_tables, - max_model_len, - ) - - -class FP8PagedMQACompareBenchmark(Benchmark): - def __init__(self): - super().__init__( - "fp8_paged_mqa_logits_gems_vs_deepgemm", - self._vllm_wrapper, - [torch.bfloat16], - ) - self.set_gems(self._gems_wrapper) - - def set_shapes(self, shape_file_path=None): - self.shapes = [] - - def get_input_iter(self, _dtype): - compare_shapes = [ - (1, 1, 16, 64, 1024), - (2, 1, 32, 128, 2048), - (4, 1, 32, 128, 2048), - (2, 2, 32, 128, 2048), - (8, 1, 32, 128, 3072), - ] - q_dtypes = [torch.bfloat16, torch.float16] - blocksize = 64 - - for (bs, nn, h, d, avg_kv), q_dtype in product(compare_shapes, q_dtypes): - case = _build_case(bs, nn, h, d, avg_kv, blocksize, q_dtype) - ( - q_fp8, - kv_dg, - kv_tr, - weights, - context_lens, - block_tables, - max_model_len, - ) = case - schedule_metadata = get_paged_mqa_logits_metadata( - context_lens, blocksize, get_num_sms() - ) - yield ( - q_fp8, - kv_dg, - kv_tr, - weights, - context_lens, - block_tables, - schedule_metadata, - max_model_len, - q_dtype, - blocksize, - ) - - @staticmethod - def _vllm_wrapper( - q_fp8, - kv_cache_fp8_deepgemm, - kv_cache_fp8_triton, - weights, - context_lens, - block_tables, - 
schedule_metadata, - max_model_len, - q_dtype, - blocksize, - ): - return vllm_fp8_paged_mqa_logits( - q_fp8, - kv_cache_fp8_deepgemm, - weights, - context_lens, - block_tables, - schedule_metadata, - max_model_len, - clean_logits=True, - ) - - @staticmethod - def _gems_wrapper( - q_fp8, - kv_cache_fp8_deepgemm, - kv_cache_fp8_triton, - weights, - context_lens, - block_tables, - schedule_metadata, - max_model_len, - q_dtype, - blocksize, - ): - return gems_fp8_paged_mqa_logits( - q_fp8, - kv_cache_fp8_triton, - weights, - context_lens, - block_tables, - max_model_len, - ) - - -@pytest.mark.skipif( - not ( - torch.cuda.is_available() - and VLLM_AVAILABLE - and DEEPGEMM_AVAILABLE - and HOPPER_AVAILABLE - ), - reason="requires CUDA + vLLM + DeepGEMM + Hopper", -) -@pytest.mark.performance -@pytest.mark.fp8_paged_mqa_logits -def test_perf_fp8_paged_mqa_logits_gems_vs_deepgemm(): - bench = FP8PagedMQACompareBenchmark() - bench.run() diff --git a/benchmark/test_vllm_perf.py b/benchmark/test_vllm_perf.py index 888db878fc..4bd0ee024a 100644 --- a/benchmark/test_vllm_perf.py +++ b/benchmark/test_vllm_perf.py @@ -18,6 +18,21 @@ def is_vllm_available(): return False +try: + from vllm.utils.deep_gemm import fp8_paged_mqa_logits as vllm_fp8_paged_mqa_logits + from vllm.utils.deep_gemm import ( + get_num_sms, + get_paged_mqa_logits_metadata, + has_deep_gemm, + ) + + DEEPGEMM_AVAILABLE = has_deep_gemm() +except ImportError: + DEEPGEMM_AVAILABLE = False + vllm_fp8_paged_mqa_logits = None + get_num_sms = None + get_paged_mqa_logits_metadata = None + VLLM_AVAILABLE = is_vllm_available() @@ -1058,3 +1073,222 @@ def test_get_paged_mqa_logits_metadata_benchmark(): ) bench.set_gems(flag_gems.get_paged_mqa_logits_metadata) bench.run() + + +def kv_cache_cast_to_fp8_deepgemm(x: torch.Tensor) -> torch.Tensor: + num_blocks, block_size, num_heads, head_dim = x.shape + assert num_heads == 1 + x_amax = x.abs().float().amax(dim=3, keepdim=True).clamp(1e-4) + sf = x_amax / 448.0 + x_scaled = (x * 
(1.0 / sf)).to(torch.float8_e4m3fn) + + x_fp8 = torch.empty( + (num_blocks, block_size * (head_dim + 4)), + device=x.device, + dtype=torch.uint8, + ) + x_fp8[:, : block_size * head_dim] = x_scaled.view( + num_blocks, block_size * head_dim + ).view(torch.uint8) + + sf_scaled = sf.squeeze(-1).squeeze(-1) + sf_bytes = sf_scaled.view(torch.int32).view(torch.uint8) + x_fp8[:, block_size * head_dim :] = sf_bytes + return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4) + + +def kv_cache_cast_to_fp8_triton(x: torch.Tensor) -> torch.Tensor: + num_blocks, block_size, num_heads, head_dim = x.shape + assert num_heads == 1 + x_amax = x.abs().float().amax(dim=3, keepdim=True).clamp(1e-4) + sf = x_amax / 448.0 + x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn) + + out = torch.empty( + (num_blocks, block_size, num_heads, head_dim + 4), + device=x.device, + dtype=torch.uint8, + ) + out[..., :head_dim] = x_scaled.view(torch.uint8) + + sf_scaled = sf.squeeze(-1).squeeze(-1) + sf_bytes = sf_scaled.view(torch.int32).view(torch.uint8) + out[..., head_dim:] = sf_bytes.view(num_blocks, block_size, num_heads, 4) + return out + + +def _build_case( + batch_size, next_n, heads, head_dim, avg_kv, blocksize, q_dtype, max_model_len=4096 +): + num_blocks = max_model_len * 20 + + q = torch.randn( + (batch_size, next_n, heads, head_dim), + device=flag_gems.device, + dtype=q_dtype, + ) + q_fp8 = q.to(torch.float8_e4m3fn) + + kv_cache = torch.randn( + (num_blocks, blocksize, 1, head_dim), + device=flag_gems.device, + dtype=torch.bfloat16, + ) + + weights = torch.randn( + (batch_size * next_n, heads), + device=flag_gems.device, + dtype=torch.float32, + ) + + context_lens = torch.randint( + int(0.8 * avg_kv), + int(1.2 * avg_kv), + (batch_size,), + device=flag_gems.device, + dtype=torch.int32, + ) + + max_num_blocks_per_seq = ( + int(context_lens.max().item()) + blocksize - 1 + ) // blocksize + block_tables = torch.zeros( + (batch_size, max_num_blocks_per_seq), + device=flag_gems.device, 
+ dtype=torch.int32, + ) + + counter = 0 + block_idx_pool = list(range(num_blocks)) + random.shuffle(block_idx_pool) + for i in range(batch_size): + ctx_len = int(context_lens[i].item()) + for j in range((ctx_len + blocksize - 1) // blocksize): + block_tables[i, j] = block_idx_pool[counter] + counter += 1 + + kv_cache_fp8_deepgemm = kv_cache_cast_to_fp8_deepgemm(kv_cache) + kv_cache_fp8_triton = kv_cache_cast_to_fp8_triton(kv_cache) + + return ( + q_fp8, + kv_cache_fp8_deepgemm, + kv_cache_fp8_triton, + weights, + context_lens, + block_tables, + max_model_len, + ) + + +class FP8PagedMQACompareBenchmark(Benchmark): + def __init__(self): + super().__init__( + "fp8_paged_mqa_logits_gems_vs_deepgemm", + self._vllm_wrapper, + [torch.bfloat16], + ) + self.set_gems(self._gems_wrapper) + + def set_shapes(self, shape_file_path=None): + self.shapes = [] + + def get_input_iter(self, _dtype): + compare_shapes = [ + (1, 1, 16, 64, 1024), + (2, 1, 32, 128, 2048), + (4, 1, 32, 128, 2048), + (2, 2, 32, 128, 2048), + (8, 1, 32, 128, 3072), + ] + q_dtypes = [torch.bfloat16, torch.float16] + blocksize = 64 + + for (bs, nn, h, d, avg_kv), q_dtype in product(compare_shapes, q_dtypes): + case = _build_case(bs, nn, h, d, avg_kv, blocksize, q_dtype) + ( + q_fp8, + kv_dg, + kv_tr, + weights, + context_lens, + block_tables, + max_model_len, + ) = case + schedule_metadata = get_paged_mqa_logits_metadata( + context_lens, blocksize, get_num_sms() + ) + yield ( + q_fp8, + kv_dg, + kv_tr, + weights, + context_lens, + block_tables, + schedule_metadata, + max_model_len, + q_dtype, + blocksize, + ) + + @staticmethod + def _vllm_wrapper( + q_fp8, + kv_cache_fp8_deepgemm, + kv_cache_fp8_triton, + weights, + context_lens, + block_tables, + schedule_metadata, + max_model_len, + q_dtype, + blocksize, + ): + return vllm_fp8_paged_mqa_logits( + q_fp8, + kv_cache_fp8_deepgemm, + weights, + context_lens, + block_tables, + schedule_metadata, + max_model_len, + clean_logits=True, + ) + + @staticmethod + def 
_gems_wrapper( + q_fp8, + kv_cache_fp8_deepgemm, + kv_cache_fp8_triton, + weights, + context_lens, + block_tables, + schedule_metadata, + max_model_len, + q_dtype, + blocksize, + ): + return flag_gems.fp8_paged_mqa_logits( + q_fp8, + kv_cache_fp8_triton, + weights, + context_lens, + block_tables, + max_model_len, + ) + + +@pytest.mark.skipif( + not ( + torch.cuda.is_available() + and VLLM_AVAILABLE + and DEEPGEMM_AVAILABLE + and CUDA_AVAILABLE + ), + reason="requires CUDA + vLLM + Hopper", +) +@pytest.mark.performance +@pytest.mark.fp8_paged_mqa_logits +def test_perf_fp8_paged_mqa_logits_gems_vs_deepgemm(): + bench = FP8PagedMQACompareBenchmark() + bench.run() diff --git a/tests/test_fp8_paged_mqa_logits_ops.py b/tests/test_fp8_paged_mqa_logits_ops.py deleted file mode 100644 index 1f3e6bea56..0000000000 --- a/tests/test_fp8_paged_mqa_logits_ops.py +++ /dev/null @@ -1,294 +0,0 @@ -import random - -import pytest -import torch -from vllm.platforms import current_platform -from vllm.utils.deep_gemm import calc_diff -from vllm.utils.deep_gemm import fp8_paged_mqa_logits as fp8_paged_mqa_logits_deepgemm -from vllm.utils.deep_gemm import get_num_sms, get_paged_mqa_logits_metadata -from vllm.utils.import_utils import has_deep_gemm - -import flag_gems -from flag_gems.ops.fp8_paged_mqa_logits import fp8_paged_mqa_logits - -from .accuracy_utils import gems_assert_close, to_reference - - -def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor: - num_blocks, block_size, num_heads, head_dim = x.shape - assert num_heads == 1 - - x_amax = x.abs().float().amax(dim=3, keepdim=True).clamp(1e-4) - sf = x_amax / 448.0 - x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn) - - x_fp8 = torch.empty( - (num_blocks, block_size * (head_dim + 4)), - device=x.device, - dtype=torch.uint8, - ) - x_fp8[:, : block_size * head_dim] = x_scaled.view( - num_blocks, block_size * head_dim - ).view(torch.uint8) - - sf_scaled = sf.squeeze(-1).squeeze(-1) - sf_bytes = 
sf_scaled.view(torch.int32).view(torch.uint8) - x_fp8[:, block_size * head_dim :] = sf_bytes - - return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4) - - -def kv_cache_cast_to_fp8_triton(x: torch.Tensor) -> torch.Tensor: - num_blocks, block_size, num_heads, head_dim = x.shape - assert num_heads == 1 - - x_amax = x.abs().float().amax(dim=3, keepdim=True).clamp(1e-4) - sf = x_amax / 448.0 - x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn) - - out = torch.empty( - (num_blocks, block_size, num_heads, head_dim + 4), - device=x.device, - dtype=torch.uint8, - ) - out[..., :head_dim] = x_scaled.view(torch.uint8) - - sf_scaled = sf.squeeze(-1).squeeze(-1) - sf_bytes = sf_scaled.view(torch.int32).view(torch.uint8) - out[..., head_dim:] = sf_bytes.view(num_blocks, block_size, num_heads, 4) - - return out - - -def _build_mask(context_lens, batch_size, next_n, max_model_len, device): - positions = ( - torch.arange(max_model_len, device=device) - .unsqueeze(0) - .expand(batch_size * next_n, -1) - ) - row_indices = torch.arange(batch_size * next_n, device=device) // next_n - next_n_offset = torch.arange(batch_size * next_n, device=device) % next_n - return positions <= (context_lens[row_indices] - next_n + next_n_offset).unsqueeze( - 1 - ) - - -@pytest.mark.fp8 -@pytest.mark.paged_mqa_logits -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") -@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available") -@pytest.mark.skipif( - not current_platform.has_device_capability(90), reason="SM90 and SM100 only" -) -@pytest.mark.parametrize("clean_logits", [True, False]) -def test_accuracy_fp8_paged_mqa_logits(clean_logits: bool): - torch.manual_seed(0) - random.seed(0) - - max_model_len = 4096 - batch_size, next_n = 4, 1 - heads, index_dim = 32, 128 - avg_kv = 2048 - num_blocks, blocksize = max_model_len * 2, 64 - - q = torch.randn( - (batch_size, next_n, heads, index_dim), - device=flag_gems.device, - dtype=torch.bfloat16, - ) - kv_cache 
= torch.randn( - (num_blocks, blocksize, 1, index_dim), - device=flag_gems.device, - dtype=torch.bfloat16, - ) - weights = torch.randn( - (batch_size * next_n, heads), device=flag_gems.device, dtype=torch.float32 - ) - - context_lens = torch.randint( - int(0.8 * avg_kv), int(1.2 * avg_kv), (batch_size,), device=flag_gems.device - ).to(torch.int32) - max_num_blocks_per_seq = (context_lens.max().item() + blocksize - 1) // blocksize - block_tables = torch.zeros( - (batch_size, max_num_blocks_per_seq), device=flag_gems.device, dtype=torch.int32 - ) - - counter = 0 - block_idx_pool = list(range(num_blocks)) - random.shuffle(block_idx_pool) - for i in range(batch_size): - ctx_len = int(context_lens[i].item()) - for j in range((ctx_len + blocksize - 1) // blocksize): - block_tables[i][j] = block_idx_pool[counter] - counter += 1 - - q_fp8 = q.to(torch.float8_e4m3fn) - - kv_cache_fp8_deepgemm = kv_cache_cast_to_fp8(kv_cache) - kv_cache_fp8_triton = kv_cache_cast_to_fp8_triton(kv_cache) - - schedule_metadata = get_paged_mqa_logits_metadata( - context_lens, blocksize, get_num_sms() - ) - ref_out = fp8_paged_mqa_logits_deepgemm( - q_fp8, - kv_cache_fp8_deepgemm, - weights, - context_lens, - block_tables, - schedule_metadata, - max_model_len, - clean_logits=clean_logits, - ) - ref_out = to_reference(ref_out) - - with flag_gems.use_gems(): - res_out = fp8_paged_mqa_logits( - q_fp8, - kv_cache_fp8_triton, - weights, - context_lens, - block_tables, - max_model_len, - ) - - mask = _build_mask( - context_lens, batch_size, next_n, max_model_len, flag_gems.device - ) - res_out_masked = torch.nan_to_num(res_out.masked_fill(~mask, 0), 0.0) - ref_out_masked = torch.nan_to_num(ref_out.masked_fill(~mask, 0), 0.0) - - gems_assert_close( - res_out_masked, - ref_out_masked, - res_out_masked.dtype, - equal_nan=True, - atol=5e-2, - reduce_dim=1, - ) - - diff = calc_diff(res_out_masked, ref_out_masked) - assert ( - diff < 1e-3 - ), f"Large discrepancy between Triton and DeepGEMM: {diff=} 
(expected < 1e-3)" - - -@pytest.mark.fp8 -@pytest.mark.paged_mqa_logits -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") -@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available") -@pytest.mark.skipif( - not current_platform.has_device_capability(90), reason="SM90 and SM100 only" -) -@pytest.mark.parametrize( - "batch_size, next_n", - [ - (1, 1), - (2, 1), - (4, 1), - (8, 1), - (16, 1), - (32, 1), - (2, 2), - (4, 2), - ], -) -@pytest.mark.parametrize( - "heads, index_dim", - [ - (16, 64), - (32, 128), - ], -) -def test_accuracy_fp8_paged_mqa_logits_param(batch_size, next_n, heads, index_dim): - torch.manual_seed(0) - random.seed(0) - - max_model_len = 4096 - avg_kv = 2048 - num_blocks, blocksize = max_model_len * 2, 64 - - q = torch.randn( - (batch_size, next_n, heads, index_dim), - device=flag_gems.device, - dtype=torch.bfloat16, - ) - kv_cache = torch.randn( - (num_blocks, blocksize, 1, index_dim), - device=flag_gems.device, - dtype=torch.bfloat16, - ) - weights = torch.randn( - (batch_size * next_n, heads), device=flag_gems.device, dtype=torch.float32 - ) - - context_lens = torch.randint( - int(0.8 * avg_kv), - int(1.2 * avg_kv), - (batch_size,), - device=flag_gems.device, - dtype=torch.int32, - ) - max_num_blocks_per_seq = (context_lens.max().item() + blocksize - 1) // blocksize - block_tables = torch.zeros( - (batch_size, max_num_blocks_per_seq), device=flag_gems.device, dtype=torch.int32 - ) - - counter = 0 - block_idx_pool = list(range(num_blocks)) - random.shuffle(block_idx_pool) - for i in range(batch_size): - ctx_len = int(context_lens[i].item()) - for j in range((ctx_len + blocksize - 1) // blocksize): - block_tables[i][j] = block_idx_pool[counter] - counter += 1 - - q_fp8 = q.to(torch.float8_e4m3fn) - - kv_cache_fp8_deepgemm = kv_cache_cast_to_fp8(kv_cache) - kv_cache_fp8_triton = kv_cache_cast_to_fp8_triton(kv_cache) - - schedule_metadata = get_paged_mqa_logits_metadata( - context_lens, blocksize, get_num_sms() - 
) - ref_out = fp8_paged_mqa_logits_deepgemm( - q_fp8, - kv_cache_fp8_deepgemm, - weights, - context_lens, - block_tables, - schedule_metadata, - max_model_len, - clean_logits=True, - ) - ref_out = to_reference(ref_out) - - with flag_gems.use_gems(): - res_out = fp8_paged_mqa_logits( - q_fp8, - kv_cache_fp8_triton, - weights, - context_lens, - block_tables, - max_model_len, - ) - - mask = _build_mask( - context_lens, batch_size, next_n, max_model_len, flag_gems.device - ) - res_out_masked = torch.nan_to_num(res_out.masked_fill(~mask, 0), 0.0) - ref_out_masked = torch.nan_to_num(ref_out.masked_fill(~mask, 0), 0.0) - - gems_assert_close( - res_out_masked, - ref_out_masked, - res_out_masked.dtype, - equal_nan=True, - atol=5e-2, - reduce_dim=1, - ) - - diff = calc_diff(res_out_masked, ref_out_masked) - assert ( - diff < 1e-3 - ), f"Large discrepancy between Triton and DeepGEMM: {diff=} (expected < 1e-3)" diff --git a/tests/test_vllm_ops.py b/tests/test_vllm_ops.py index 6d58a61fbc..194ed47595 100644 --- a/tests/test_vllm_ops.py +++ b/tests/test_vllm_ops.py @@ -8,8 +8,30 @@ import flag_gems +from .accuracy_utils import gems_assert_close, to_reference from .conftest import QUICK_MODE +try: + import vllm + from vllm.platforms import current_platform + from vllm.utils.deep_gemm import calc_diff + from vllm.utils.deep_gemm import ( + fp8_paged_mqa_logits as fp8_paged_mqa_logits_deepgemm, + ) + from vllm.utils.deep_gemm import get_num_sms, get_paged_mqa_logits_metadata + from vllm.utils.import_utils import has_deep_gemm + + VLLM_AVAILABLE = True + DEEPGEMM_AVAILABLE = has_deep_gemm() +except ImportError: + VLLM_AVAILABLE = False + DEEPGEMM_AVAILABLE = False + get_num_sms = None + get_paged_mqa_logits_metadata = None + calc_diff = None + fp8_paged_mqa_logits_deepgemm = None + + random.seed(42) @@ -1017,3 +1039,281 @@ def test_get_paged_mqa_logits_metadata(batch_size, next_n, avg_ctx_len): res = flag_gems.get_paged_mqa_logits_metadata(context_lens_2d, 64, get_num_sms()) assert 
torch.equal(ref, res) + + +def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor: + num_blocks, block_size, num_heads, head_dim = x.shape + assert num_heads == 1 + + x_amax = x.abs().float().amax(dim=3, keepdim=True).clamp(1e-4) + sf = x_amax / 448.0 + x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn) + + x_fp8 = torch.empty( + (num_blocks, block_size * (head_dim + 4)), + device=x.device, + dtype=torch.uint8, + ) + x_fp8[:, : block_size * head_dim] = x_scaled.view( + num_blocks, block_size * head_dim + ).view(torch.uint8) + + sf_scaled = sf.squeeze(-1).squeeze(-1) + sf_bytes = sf_scaled.view(torch.int32).view(torch.uint8) + x_fp8[:, block_size * head_dim :] = sf_bytes + + return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4) + + +def kv_cache_cast_to_fp8_triton(x: torch.Tensor) -> torch.Tensor: + num_blocks, block_size, num_heads, head_dim = x.shape + assert num_heads == 1 + + x_amax = x.abs().float().amax(dim=3, keepdim=True).clamp(1e-4) + sf = x_amax / 448.0 + x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn) + + out = torch.empty( + (num_blocks, block_size, num_heads, head_dim + 4), + device=x.device, + dtype=torch.uint8, + ) + out[..., :head_dim] = x_scaled.view(torch.uint8) + + sf_scaled = sf.squeeze(-1).squeeze(-1) + sf_bytes = sf_scaled.view(torch.int32).view(torch.uint8) + out[..., head_dim:] = sf_bytes.view(num_blocks, block_size, num_heads, 4) + + return out + + +def _build_mask(context_lens, batch_size, next_n, max_model_len, device): + positions = ( + torch.arange(max_model_len, device=device) + .unsqueeze(0) + .expand(batch_size * next_n, -1) + ) + row_indices = torch.arange(batch_size * next_n, device=device) // next_n + next_n_offset = torch.arange(batch_size * next_n, device=device) % next_n + return positions <= (context_lens[row_indices] - next_n + next_n_offset).unsqueeze( + 1 + ) + + +@pytest.mark.fp8 +@pytest.mark.paged_mqa_logits +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") 
+@pytest.mark.skipif(not DEEPGEMM_AVAILABLE, reason="DeepGEMM not available") +@pytest.mark.skipif( + not current_platform.has_device_capability(90), reason="SM90 and SM100 only" +) +@pytest.mark.parametrize("clean_logits", [True, False]) +def test_accuracy_fp8_paged_mqa_logits(clean_logits: bool): + torch.manual_seed(0) + random.seed(0) + + max_model_len = 4096 + batch_size, next_n = 4, 1 + heads, index_dim = 32, 128 + avg_kv = 2048 + num_blocks, blocksize = max_model_len * 2, 64 + + q = torch.randn( + (batch_size, next_n, heads, index_dim), + device=flag_gems.device, + dtype=torch.bfloat16, + ) + kv_cache = torch.randn( + (num_blocks, blocksize, 1, index_dim), + device=flag_gems.device, + dtype=torch.bfloat16, + ) + weights = torch.randn( + (batch_size * next_n, heads), device=flag_gems.device, dtype=torch.float32 + ) + + context_lens = torch.randint( + int(0.8 * avg_kv), int(1.2 * avg_kv), (batch_size,), device=flag_gems.device + ).to(torch.int32) + max_num_blocks_per_seq = (context_lens.max().item() + blocksize - 1) // blocksize + block_tables = torch.zeros( + (batch_size, max_num_blocks_per_seq), device=flag_gems.device, dtype=torch.int32 + ) + + counter = 0 + block_idx_pool = list(range(num_blocks)) + random.shuffle(block_idx_pool) + for i in range(batch_size): + ctx_len = int(context_lens[i].item()) + for j in range((ctx_len + blocksize - 1) // blocksize): + block_tables[i][j] = block_idx_pool[counter] + counter += 1 + + q_fp8 = q.to(torch.float8_e4m3fn) + + kv_cache_fp8_deepgemm = kv_cache_cast_to_fp8(kv_cache) + kv_cache_fp8_triton = kv_cache_cast_to_fp8_triton(kv_cache) + + schedule_metadata = get_paged_mqa_logits_metadata( + context_lens, blocksize, get_num_sms() + ) + ref_out = fp8_paged_mqa_logits_deepgemm( + q_fp8, + kv_cache_fp8_deepgemm, + weights, + context_lens, + block_tables, + schedule_metadata, + max_model_len, + clean_logits=clean_logits, + ) + ref_out = to_reference(ref_out) + + with flag_gems.use_gems(): + res_out = 
flag_gems.fp8_paged_mqa_logits( + q_fp8, + kv_cache_fp8_triton, + weights, + context_lens, + block_tables, + max_model_len, + ) + + mask = _build_mask( + context_lens, batch_size, next_n, max_model_len, flag_gems.device + ) + + ref_out = to_reference(ref_out) + res_out = to_reference(res_out) + mask = to_reference(mask) + + res_out_masked = torch.nan_to_num(res_out.masked_fill(~mask, 0), 0.0) + ref_out_masked = torch.nan_to_num(ref_out.masked_fill(~mask, 0), 0.0) + + gems_assert_close( + res_out_masked, + ref_out_masked, + res_out_masked.dtype, + equal_nan=True, + atol=5e-2, + reduce_dim=1, + ) + + diff = calc_diff(res_out_masked, ref_out_masked) + assert ( + diff < 1e-3 + ), f"Large discrepancy between Triton and DeepGEMM: {diff=} (expected < 1e-3)" + + +@pytest.mark.fp8 +@pytest.mark.paged_mqa_logits +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +@pytest.mark.skipif(not DEEPGEMM_AVAILABLE, reason="DeepGEMM not available") +@pytest.mark.skipif( + not current_platform.has_device_capability(90), reason="SM90 and SM100 only" +) +@pytest.mark.parametrize( + "batch_size, next_n", + [(1, 1), (2, 1), (4, 1), (8, 1), (16, 1), (32, 1), (2, 2), (4, 2)], +) +@pytest.mark.parametrize( + "heads, index_dim", + [(16, 64), (32, 128)], +) +def test_accuracy_fp8_paged_mqa_logits_param(batch_size, next_n, heads, index_dim): + torch.manual_seed(0) + random.seed(0) + + max_model_len = 4096 + avg_kv = 2048 + num_blocks, blocksize = max_model_len * 2, 64 + + q = torch.randn( + (batch_size, next_n, heads, index_dim), + device=flag_gems.device, + dtype=torch.bfloat16, + ) + kv_cache = torch.randn( + (num_blocks, blocksize, 1, index_dim), + device=flag_gems.device, + dtype=torch.bfloat16, + ) + weights = torch.randn( + (batch_size * next_n, heads), device=flag_gems.device, dtype=torch.float32 + ) + + context_lens = torch.randint( + int(0.8 * avg_kv), + int(1.2 * avg_kv), + (batch_size,), + device=flag_gems.device, + dtype=torch.int32, + ) + max_num_blocks_per_seq 
= (context_lens.max().item() + blocksize - 1) // blocksize + block_tables = torch.zeros( + (batch_size, max_num_blocks_per_seq), device=flag_gems.device, dtype=torch.int32 + ) + + counter = 0 + block_idx_pool = list(range(num_blocks)) + random.shuffle(block_idx_pool) + for i in range(batch_size): + ctx_len = int(context_lens[i].item()) + for j in range((ctx_len + blocksize - 1) // blocksize): + block_tables[i][j] = block_idx_pool[counter] + counter += 1 + + q_fp8 = q.to(torch.float8_e4m3fn) + + kv_cache_fp8_deepgemm = kv_cache_cast_to_fp8(kv_cache) + kv_cache_fp8_triton = kv_cache_cast_to_fp8_triton(kv_cache) + + schedule_metadata = get_paged_mqa_logits_metadata( + context_lens, blocksize, get_num_sms() + ) + ref_out = fp8_paged_mqa_logits_deepgemm( + q_fp8, + kv_cache_fp8_deepgemm, + weights, + context_lens, + block_tables, + schedule_metadata, + max_model_len, + clean_logits=True, + ) + ref_out = to_reference(ref_out) + + with flag_gems.use_gems(): + res_out = flag_gems.fp8_paged_mqa_logits( + q_fp8, + kv_cache_fp8_triton, + weights, + context_lens, + block_tables, + max_model_len, + ) + + mask = _build_mask( + context_lens, batch_size, next_n, max_model_len, flag_gems.device + ) + + ref_out = to_reference(ref_out) + res_out = to_reference(res_out) + mask = to_reference(mask) + + res_out_masked = torch.nan_to_num(res_out.masked_fill(~mask, 0), 0.0) + ref_out_masked = torch.nan_to_num(ref_out.masked_fill(~mask, 0), 0.0) + + gems_assert_close( + res_out_masked, + ref_out_masked, + res_out_masked.dtype, + equal_nan=True, + atol=5e-2, + reduce_dim=1, + ) + + diff = calc_diff(res_out_masked, ref_out_masked) + assert ( + diff < 1e-3 + ), f"Large discrepancy between Triton and DeepGEMM: {diff=} (expected < 1e-3)"