diff --git a/benchmarks/bench_fa3_comparison.py b/benchmarks/bench_fa3_comparison.py new file mode 100644 index 0000000000..2e63323536 --- /dev/null +++ b/benchmarks/bench_fa3_comparison.py @@ -0,0 +1,718 @@ +import torch +import random + +torch.manual_seed(42) +random.seed(42) +device = "cuda" +dtype = torch.float16 + +from flash_attn_interface import flash_attn_varlen_func as fa3_varlen_func +from flash_attn_interface import flash_attn_with_kvcache as fa3_kvcache_func +import flashinfer +from flashinfer.testing import ( + bench_gpu_time_with_cuda_event as bench_gpu_time_with_cupti, +) + +head_dim = 128 + + +def calc_tflops(batch_size, seq_len, num_qo_heads, head_dim, time_ms, causal=True): + """Calculate TFLOPS for attention. + + FLOPs = 4 * batch_size * seq_len^2 * num_heads * head_dim (for non-causal) + For causal, multiply by 0.5 + """ + flops = 4 * batch_size * seq_len * seq_len * num_qo_heads * head_dim + if causal: + flops = flops * 0.5 + tflops = flops / (time_ms / 1000) / 1e12 + return tflops + + +def calc_tflops_varlen(seq_lens, num_qo_heads, head_dim, time_ms, causal=True): + """Calculate TFLOPS for variable length attention.""" + total_flops = sum(4 * s * s * num_qo_heads * head_dim for s in seq_lens) + if causal: + total_flops = total_flops * 0.5 + tflops = total_flops / (time_ms / 1000) / 1e12 + return tflops + + +def bench_fn(fn): + """Benchmark a function and return median time in ms.""" + times = bench_gpu_time_with_cupti(fn, l2_flush=True) + return sorted(times)[len(times) // 2] # median + + +print("Comprehensive benchmark: FlashInfer vs FA3 (using CUPTI)") +print("=" * 115) + +# bs=1 tests +print("\n--- bs=1 Single Prefill ---") +print( + f"{'seq_len':<10} {'heads':<12} {'FlashInfer (ms)':<18} {'FA3 (ms)':<15} {'diff':<10} {'FI TFLOPS':<12} {'FA3 TFLOPS':<12}" +) +print("-" * 100) + +for seq_len in [512, 1024, 2048, 4096, 8192, 16384, 32768]: + for num_qo_heads, num_kv_heads in [(32, 8), (32, 32)]: + q = torch.randn(seq_len, num_qo_heads, head_dim, dtype=dtype, device=device) + k = torch.randn(seq_len, num_kv_heads, head_dim, dtype=dtype, device=device) + v = torch.randn(seq_len, num_kv_heads, head_dim, dtype=dtype, device=device) + cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=device) + + wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper( + torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device), + backend="fa3", + ) + wrapper.plan( + cu_seqlens, + cu_seqlens, + num_qo_heads, + num_kv_heads, + head_dim, + head_dim, + causal=True, + ) + + fi_time = bench_fn(lambda: wrapper.run(q, k, v)) + fa3_time = bench_fn( + lambda: fa3_varlen_func( + q, k, v, cu_seqlens, cu_seqlens, seq_len, seq_len, causal=True + ) + ) + + diff = (fi_time - fa3_time) / fa3_time * 100 + fi_tflops = calc_tflops( + 1, seq_len, num_qo_heads, head_dim, fi_time, causal=True + ) + fa3_tflops = calc_tflops( + 1, seq_len, num_qo_heads, head_dim, fa3_time, causal=True + ) + heads_str = f"{num_qo_heads}/{num_kv_heads}" + print( + f"{seq_len:<10} {heads_str:<12} {fi_time:<18.3f} {fa3_time:<15.3f} {diff:+.1f}%{'':5} {fi_tflops:<12.1f} {fa3_tflops:<12.1f}" + ) + +# Batch prefill tests +print("\n--- Batch Prefill ---") +print( + f"{'Config':<35} {'FlashInfer (ms)':<18} {'FA3 (ms)':<15} {'diff':<10} {'FI TFLOPS':<12} {'FA3 TFLOPS':<12}" +) +print("-" * 115) + +batch_configs = [ + (8, 512, 32, 8), + (8, 1024, 32, 8), + (8, 2048, 32, 8), + (8, 4096, 32, 8), + (8, 8192, 32, 8), + (8, 512, 32, 32), + (8, 1024, 32, 32), + (8, 2048, 32, 32), + (8, 4096, 32, 32), + (8, 8192, 32, 32), 
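+    # Longer-context configs below shrink the batch so the total token count stays at
+    # ~64K (matching the 8 x 8192 case); head pairs are (num_qo_heads, num_kv_heads),
+    # i.e. 32/8 exercises GQA and 32/32 plain MHA.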
+ (4, 16384, 32, 8), + (4, 16384, 32, 32), + (2, 32768, 32, 8), + (2, 32768, 32, 32), +] + +for batch_size, seq_len, num_qo_heads, num_kv_heads in batch_configs: + qo_lens = [seq_len] * batch_size + total_q = sum(qo_lens) + + q = torch.randn(total_q, num_qo_heads, head_dim, dtype=dtype, device=device) + k = torch.randn(total_q, num_kv_heads, head_dim, dtype=dtype, device=device) + v = torch.randn(total_q, num_kv_heads, head_dim, dtype=dtype, device=device) + cu_seqlens = torch.tensor( + [0] + list(torch.cumsum(torch.tensor(qo_lens), 0).numpy()), + dtype=torch.int32, + device=device, + ) + + wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper( + torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device), backend="fa3" + ) + wrapper.plan( + cu_seqlens, + cu_seqlens, + num_qo_heads, + num_kv_heads, + head_dim, + head_dim, + causal=True, + ) + + fi_time = bench_fn(lambda: wrapper.run(q, k, v)) + fa3_time = bench_fn( + lambda: fa3_varlen_func( + q, k, v, cu_seqlens, cu_seqlens, seq_len, seq_len, causal=True + ) + ) + + config_str = f"bs={batch_size}, seq={seq_len}, h={num_qo_heads}/{num_kv_heads}" + diff = (fi_time - fa3_time) / fa3_time * 100 + fi_tflops = calc_tflops( + batch_size, seq_len, num_qo_heads, head_dim, fi_time, causal=True + ) + fa3_tflops = calc_tflops( + batch_size, seq_len, num_qo_heads, head_dim, fa3_time, causal=True + ) + print( + f"{config_str:<35} {fi_time:<18.3f} {fa3_time:<15.3f} {diff:+.1f}%{'':5} {fi_tflops:<12.1f} {fa3_tflops:<12.1f}" + ) + + +# Variable sequence length tests +print("\n--- Variable Sequence Length Batch Prefill ---") +print( + f"{'Config':<40} {'FlashInfer (ms)':<18} {'FA3 (ms)':<15} {'diff':<10} {'FI TFLOPS':<12} {'FA3 TFLOPS':<12}" +) +print("-" * 120) + +varlen_configs = [ + # (batch_size, min_len, max_len, num_qo_heads, num_kv_heads) + (16, 64, 512, 32, 8), + (16, 128, 1024, 32, 8), + (16, 256, 2048, 32, 8), + (8, 512, 4096, 32, 8), + (4, 1024, 8192, 32, 8), + (4, 2048, 16384, 32, 8), + (2, 4096, 32768, 32, 8), + (16, 64, 512, 32, 32), + (16, 128, 1024, 32, 32), + (16, 256, 2048, 32, 32), + (8, 512, 4096, 32, 32), + (4, 1024, 8192, 32, 32), + (4, 2048, 16384, 32, 32), +] + +for batch_size, min_len, max_len, num_qo_heads, num_kv_heads in varlen_configs: + seq_lens = [random.randint(min_len, max_len) for _ in range(batch_size)] + total_tokens = sum(seq_lens) + max_seqlen = max(seq_lens) + + q = torch.randn(total_tokens, num_qo_heads, head_dim, dtype=dtype, device=device) + k = torch.randn(total_tokens, num_kv_heads, head_dim, dtype=dtype, device=device) + v = torch.randn(total_tokens, num_kv_heads, head_dim, dtype=dtype, device=device) + cu_seqlens = torch.tensor( + [0] + list(torch.cumsum(torch.tensor(seq_lens), 0).numpy()), + dtype=torch.int32, + device=device, + ) + + wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper( + torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device), backend="fa3" + ) + wrapper.plan( + cu_seqlens, + cu_seqlens, + num_qo_heads, + num_kv_heads, + head_dim, + head_dim, + causal=True, + ) + + fi_time = bench_fn(lambda: wrapper.run(q, k, v)) + fa3_time = bench_fn( + lambda: fa3_varlen_func( + q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, causal=True + ) + ) + + fi_tflops = calc_tflops_varlen( + seq_lens, num_qo_heads, head_dim, fi_time, causal=True + ) + fa3_tflops = calc_tflops_varlen( + seq_lens, num_qo_heads, head_dim, fa3_time, causal=True + ) + + config_str = ( + f"bs={batch_size}, len=[{min_len}-{max_len}], h={num_qo_heads}/{num_kv_heads}" + ) + diff = (fi_time - fa3_time) / 
fa3_time * 100 + print( + f"{config_str:<40} {fi_time:<18.3f} {fa3_time:<15.3f} {diff:+.1f}%{'':5} {fi_tflops:<12.1f} {fa3_tflops:<12.1f}" + ) + +# FP8 tests +print("\n--- FP8 Batch Prefill ---") +print( + f"{'Config':<35} {'FlashInfer (ms)':<18} {'FA3 (ms)':<15} {'diff':<10} {'FI TFLOPS':<12} {'FA3 TFLOPS':<12}" +) +print("-" * 115) + +fp8_dtype = torch.float8_e4m3fn + + +def per_head_symmetric_quant(x, quant_dtype): + """Per-head symmetric quantization to FP8.""" + o_min_val, o_max_val = ( + (-448.0, 448.0) if quant_dtype == torch.float8_e4m3fn else (-57344, 57344) + ) + x_max_val = x.abs().amax(dim=(0, 2)).to(dtype=torch.float32) + s_out = torch.clamp(x_max_val / o_max_val, min=1e-6) + s_out_broadcast = s_out.view(1, -1, 1) + q_x_out = torch.clamp(x / s_out_broadcast, min=o_min_val, max=o_max_val).to( + dtype=quant_dtype + ) + return q_x_out, s_out + + +fp8_configs = [ + (8, 2048, 32, 8), + (8, 4096, 32, 8), + (8, 8192, 32, 8), + (4, 16384, 32, 8), +] + +for batch_size, seq_len, num_qo_heads, num_kv_heads in fp8_configs: + qo_lens = [seq_len] * batch_size + total_q = sum(qo_lens) + + # Create FP16 tensors first + q_fp16 = torch.randn( + total_q, num_qo_heads, head_dim, dtype=torch.float16, device=device + ) + k_fp16 = torch.randn( + total_q, num_kv_heads, head_dim, dtype=torch.float16, device=device + ) + v_fp16 = torch.randn( + total_q, num_kv_heads, head_dim, dtype=torch.float16, device=device + ) + + # Quantize to FP8 with proper scaling + q_fp8, s_q = per_head_symmetric_quant(q_fp16, fp8_dtype) + k_fp8, s_k = per_head_symmetric_quant(k_fp16, fp8_dtype) + v_fp8, s_v = per_head_symmetric_quant(v_fp16, fp8_dtype) + + cu_seqlens = torch.tensor( + [0] + list(torch.cumsum(torch.tensor(qo_lens), 0).numpy()), + dtype=torch.int32, + device=device, + ) + + config_str = f"bs={batch_size}, seq={seq_len}, h={num_qo_heads}/{num_kv_heads}" + + # Benchmark FlashInfer FP8 + fi_time = None + fi_tflops = None + try: + wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper( + torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device), + backend="fa3", + ) + wrapper.plan( + cu_seqlens, + cu_seqlens, + num_qo_heads, + num_kv_heads, + head_dim, + head_dim, + causal=True, + q_data_type=fp8_dtype, + kv_data_type=fp8_dtype, + o_data_type=torch.float16, # Output is FP16 + ) + fi_time = bench_fn(lambda: wrapper.run(q_fp8, k_fp8, v_fp8, s_q, s_k, s_v)) + fi_tflops = calc_tflops( + batch_size, seq_len, num_qo_heads, head_dim, fi_time, causal=True + ) + except Exception: + fi_time = None + fi_tflops = None + + # Benchmark FA3 FP8 + fa3_time = None + fa3_tflops = None + try: + fa3_time = bench_fn( + lambda: fa3_varlen_func( + q_fp8, + k_fp8, + v_fp8, + cu_seqlens, + cu_seqlens, + seq_len, + seq_len, + causal=True, + ) + ) + fa3_tflops = calc_tflops( + batch_size, seq_len, num_qo_heads, head_dim, fa3_time, causal=True + ) + except Exception: + fa3_time = None + fa3_tflops = None + + if fi_time is not None and fa3_time is not None: + diff = (fi_time - fa3_time) / fa3_time * 100 + print( + f"{config_str:<35} {fi_time:<18.3f} {fa3_time:<15.3f} {diff:+.1f}%{'':5} {fi_tflops:<12.1f} {fa3_tflops:<12.1f}" + ) + elif fi_time is not None: + print( + f"{config_str:<35} {fi_time:<18.3f} {'N/A':<15} {'N/A':<10} {fi_tflops:<12.1f} {'N/A':<12}" + ) + elif fa3_time is not None: + print( + f"{config_str:<35} {'N/A':<18} {fa3_time:<15.3f} {'N/A':<10} {'N/A':<12} {fa3_tflops:<12.1f}" + ) + else: + print( + f"{config_str:<35} {'N/A':<18} {'N/A':<15} {'N/A':<10} {'N/A':<12} {'N/A':<12}" + ) + +# FP16 Paged KV Cache tests 
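+# The two paged-KV sections below build their page-table metadata inline. The hypothetical
+# helper sketched here (not called by the benchmark) shows the same derivation in one place:
+# a request of length kv_len owns ceil(kv_len / page_size) pages, pages are laid out
+# contiguously across requests, and the last page of a request holds
+# ((kv_len - 1) % page_size) + 1 valid slots.
+def _paged_kv_metadata_sketch(kv_lens, page_size):
+    pages_per_req = [(kv_len + page_size - 1) // page_size for kv_len in kv_lens]
+    kv_indptr = [0]
+    for n in pages_per_req:
+        kv_indptr.append(kv_indptr[-1] + n)
+    kv_last_page_len = [((kv_len - 1) % page_size) + 1 for kv_len in kv_lens]
+    return (
+        torch.tensor(kv_indptr, dtype=torch.int32, device=device),
+        torch.arange(kv_indptr[-1], dtype=torch.int32, device=device),  # kv_indices
+        torch.tensor(kv_last_page_len, dtype=torch.int32, device=device),
+    )
+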
+print("\n--- FP16 Paged KV Cache Prefill ---") +print( + f"{'Config':<45} {'FlashInfer (ms)':<18} {'FA3 (ms)':<15} {'diff':<10} {'FI TFLOPS':<12} {'FA3 TFLOPS':<12}" +) +print("-" * 125) + +fp16_paged_configs = [ + # (batch_size, seq_len, num_qo_heads, num_kv_heads, page_size) + # page_size=1 + (8, 2048, 32, 8, 1), + (8, 4096, 32, 8, 1), + (8, 8192, 32, 8, 1), + (4, 16384, 32, 8, 1), + # page_size=16 + (8, 2048, 32, 8, 16), + (8, 4096, 32, 8, 16), + (8, 8192, 32, 8, 16), + (4, 16384, 32, 8, 16), +] + +for batch_size, seq_len, num_qo_heads, num_kv_heads, page_size in fp16_paged_configs: + qo_lens = [seq_len] * batch_size + kv_lens = [seq_len] * batch_size + total_q = sum(qo_lens) + total_kv_pages = sum((kv_len + page_size - 1) // page_size for kv_len in kv_lens) + + # FP16 tensors + q_fp16 = torch.randn( + total_q, num_qo_heads, head_dim, dtype=torch.float16, device=device + ) + + # Paged KV cache: (num_pages, 2, page_size, num_kv_heads, head_dim) + kv_data_fp16 = torch.randn( + total_kv_pages, + 2, + page_size, + num_kv_heads, + head_dim, + dtype=torch.float16, + device=device, + ) + + # Page indices for each request + kv_indptr = torch.tensor( + [0] + + [ + sum((kv_lens[i] + page_size - 1) // page_size for i in range(j + 1)) + for j in range(batch_size) + ], + dtype=torch.int32, + device=device, + ) + kv_indices = torch.arange(total_kv_pages, dtype=torch.int32, device=device) + kv_last_page_len = torch.tensor( + [((kv_len - 1) % page_size) + 1 for kv_len in kv_lens], + dtype=torch.int32, + device=device, + ) + qo_indptr = torch.tensor( + [0] + list(torch.cumsum(torch.tensor(qo_lens), 0).numpy()), + dtype=torch.int32, + device=device, + ) + + config_str = f"bs={batch_size}, seq={seq_len}, h={num_qo_heads}/{num_kv_heads}, page={page_size}" + + # Benchmark FlashInfer FP16 Paged + fi_time = None + fi_tflops = None + try: + wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device), + "NHD", + backend="fa3", + ) + wrapper.plan( + qo_indptr, + kv_indptr, + kv_indices, + kv_last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + causal=True, + q_data_type=torch.float16, + ) + fi_time = bench_fn(lambda: wrapper.run(q_fp16, kv_data_fp16)) + fi_tflops = calc_tflops( + batch_size, seq_len, num_qo_heads, head_dim, fi_time, causal=True + ) + except Exception as e: + print(f"FlashInfer error: {e}") + fi_time = None + fi_tflops = None + + # FA3 paged attention + fa3_time = None + fa3_tflops = None + try: + num_pages_per_seq = (seq_len + page_size - 1) // page_size + max_num_blocks_per_seq = num_pages_per_seq + + # Create FA3 paged KV cache: (num_blocks, page_size, num_kv_heads, head_dim) + k_cache_fa3 = kv_data_fp16[ + :, 0, :, :, : + ] # (num_pages, page_size, num_kv_heads, head_dim) + v_cache_fa3 = kv_data_fp16[:, 1, :, :, :] + + # Create page table: (batch_size, max_num_blocks_per_seq) + page_table = torch.zeros( + batch_size, max_num_blocks_per_seq, dtype=torch.int32, device=device + ) + for b in range(batch_size): + start_page = b * num_pages_per_seq + for p in range(num_pages_per_seq): + page_table[b, p] = start_page + p + + # Q for FA3: (batch_size, seq_len, num_qo_heads, head_dim) + q_fa3 = q_fp16.reshape(batch_size, seq_len, num_qo_heads, head_dim) + + # cache_seqlens + cache_seqlens = torch.full( + (batch_size,), seq_len, dtype=torch.int32, device=device + ) + + fa3_time = bench_fn( + lambda: fa3_kvcache_func( + q_fa3, + k_cache_fa3, + v_cache_fa3, + cache_seqlens=cache_seqlens, + page_table=page_table, + 
causal=True, + ) + ) + fa3_tflops = calc_tflops( + batch_size, seq_len, num_qo_heads, head_dim, fa3_time, causal=True + ) + except Exception as e: + print(f"FA3 paged error: {e}") + fa3_time = None + fa3_tflops = None + + if fi_time is not None and fa3_time is not None: + diff = (fi_time - fa3_time) / fa3_time * 100 + print( + f"{config_str:<45} {fi_time:<18.3f} {fa3_time:<15.3f} {diff:>+.1f}%{'':<4} {fi_tflops:<12.1f} {fa3_tflops:<12.1f}" + ) + elif fi_time is not None: + print( + f"{config_str:<45} {fi_time:<18.3f} {'N/A':<15} {'N/A':<10} {fi_tflops:<12.1f} {'N/A':<12}" + ) + else: + print( + f"{config_str:<45} {'N/A':<18} {'N/A':<15} {'N/A':<10} {'N/A':<12} {'N/A':<12}" + ) + +# FP8 Paged KV Cache tests +print("\n--- FP8 Paged KV Cache Prefill ---") +print( + f"{'Config':<45} {'FlashInfer (ms)':<18} {'FA3 (ms)':<15} {'diff':<10} {'FI TFLOPS':<12} {'FA3 TFLOPS':<12}" +) +print("-" * 125) + +fp8_paged_configs = [ + # (batch_size, seq_len, num_qo_heads, num_kv_heads, page_size) + # page_size=1 + (8, 2048, 32, 8, 1), + (8, 4096, 32, 8, 1), + (8, 8192, 32, 8, 1), + (4, 16384, 32, 8, 1), + # page_size=16 + (8, 2048, 32, 8, 16), + (8, 4096, 32, 8, 16), + (8, 8192, 32, 8, 16), + (4, 16384, 32, 8, 16), +] + +for batch_size, seq_len, num_qo_heads, num_kv_heads, page_size in fp8_paged_configs: + qo_lens = [seq_len] * batch_size + kv_lens = [seq_len] * batch_size + total_q = sum(qo_lens) + total_kv_pages = sum((kv_len + page_size - 1) // page_size for kv_len in kv_lens) + + # Create FP16 tensors first + q_fp16 = torch.randn( + total_q, num_qo_heads, head_dim, dtype=torch.float16, device=device + ) + + # Paged KV cache: (num_pages, 2, page_size, num_kv_heads, head_dim) + kv_data_fp16 = torch.randn( + total_kv_pages, + 2, + page_size, + num_kv_heads, + head_dim, + dtype=torch.float16, + device=device, + ) + + # Page indices for each request + kv_indptr = torch.tensor( + [0] + + [ + sum((kv_lens[i] + page_size - 1) // page_size for i in range(j + 1)) + for j in range(batch_size) + ], + dtype=torch.int32, + device=device, + ) + kv_indices = torch.arange(total_kv_pages, dtype=torch.int32, device=device) + kv_last_page_len = torch.tensor( + [((kv_len - 1) % page_size) + 1 for kv_len in kv_lens], + dtype=torch.int32, + device=device, + ) + qo_indptr = torch.tensor( + [0] + list(torch.cumsum(torch.tensor(qo_lens), 0).numpy()), + dtype=torch.int32, + device=device, + ) + + # Quantize Q to FP8 + q_fp8, s_q = per_head_symmetric_quant(q_fp16, fp8_dtype) + + # For paged KV, we need to quantize differently + k_fp16 = kv_data_fp16[:, 0, :, :, :].reshape(-1, num_kv_heads, head_dim) + v_fp16 = kv_data_fp16[:, 1, :, :, :].reshape(-1, num_kv_heads, head_dim) + k_fp8, s_k = per_head_symmetric_quant(k_fp16, fp8_dtype) + v_fp8, s_v = per_head_symmetric_quant(v_fp16, fp8_dtype) + + # Reshape back to paged format + kv_data_fp8 = torch.stack( + [ + k_fp8.reshape(total_kv_pages, page_size, num_kv_heads, head_dim), + v_fp8.reshape(total_kv_pages, page_size, num_kv_heads, head_dim), + ], + dim=1, + ) + + config_str = f"bs={batch_size}, seq={seq_len}, h={num_qo_heads}/{num_kv_heads}, page={page_size}" + + # Benchmark FlashInfer FP8 Paged + fi_time = None + fi_tflops = None + try: + wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device), + "NHD", + backend="fa3", + ) + wrapper.plan( + qo_indptr, + kv_indptr, + kv_indices, + kv_last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + causal=True, + q_data_type=fp8_dtype, + 
kv_data_type=fp8_dtype, + o_data_type=torch.float16, + ) + fi_time = bench_fn(lambda: wrapper.run(q_fp8, kv_data_fp8, s_q, s_k, s_v)) + fi_tflops = calc_tflops( + batch_size, seq_len, num_qo_heads, head_dim, fi_time, causal=True + ) + except Exception as e: + print(f"FlashInfer error: {e}") + fi_time = None + fi_tflops = None + + # FA3 paged attention + fa3_time = None + fa3_tflops = None + try: + # FA3 paged format: (num_blocks, page_size, num_kv_heads, head_dim) + num_pages_per_seq = (seq_len + page_size - 1) // page_size + max_num_blocks_per_seq = num_pages_per_seq + + # Create FA3 paged KV cache + k_cache_fa3 = k_fp8.reshape(total_kv_pages, page_size, num_kv_heads, head_dim) + v_cache_fa3 = v_fp8.reshape(total_kv_pages, page_size, num_kv_heads, head_dim) + + # Create page table: (batch_size, max_num_blocks_per_seq) + page_table = torch.zeros( + batch_size, max_num_blocks_per_seq, dtype=torch.int32, device=device + ) + for b in range(batch_size): + start_page = b * num_pages_per_seq + for p in range(num_pages_per_seq): + page_table[b, p] = start_page + p + + # Q for FA3: (batch_size, seq_len, num_qo_heads, head_dim) + q_fa3 = q_fp8.reshape(batch_size, seq_len, num_qo_heads, head_dim) + + # cache_seqlens: actual sequence lengths + cache_seqlens = torch.full( + (batch_size,), seq_len, dtype=torch.int32, device=device + ) + + # descale tensors for FP8 + # FA3 expects per-head descale: shape (batch_size, num_kv_heads) for GQA + k_descale_fa3 = s_k.squeeze().unsqueeze(0).expand(batch_size, -1).contiguous() + v_descale_fa3 = s_v.squeeze().unsqueeze(0).expand(batch_size, -1).contiguous() + # q_descale should also be (batch_size, num_kv_heads) - one scale per kv head group + q_descale_fa3 = ( + s_q.squeeze() + .reshape(num_kv_heads, num_qo_heads // num_kv_heads) + .mean(dim=1) + ) + q_descale_fa3 = q_descale_fa3.unsqueeze(0).expand(batch_size, -1).contiguous() + + fa3_time = bench_fn( + lambda: fa3_kvcache_func( + q_fa3, + k_cache_fa3, + v_cache_fa3, + cache_seqlens=cache_seqlens, + page_table=page_table, + q_descale=q_descale_fa3, + k_descale=k_descale_fa3, + v_descale=v_descale_fa3, + causal=True, + ) + ) + fa3_tflops = calc_tflops( + batch_size, seq_len, num_qo_heads, head_dim, fa3_time, causal=True + ) + except Exception as e: + print(f"FA3 paged error: {e}") + fa3_time = None + fa3_tflops = None + + if fi_time is not None and fa3_time is not None: + diff = (fi_time - fa3_time) / fa3_time * 100 + print( + f"{config_str:<45} {fi_time:<18.3f} {fa3_time:<15.3f} {diff:>+.1f}%{'':<4} {fi_tflops:<12.1f} {fa3_tflops:<12.1f}" + ) + elif fi_time is not None: + print( + f"{config_str:<45} {fi_time:<18.3f} {'N/A':<15} {'N/A':<10} {fi_tflops:<12.1f} {'N/A':<12}" + ) + else: + print( + f"{config_str:<45} {'N/A':<18} {'N/A':<15} {'N/A':<10} {'N/A':<12} {'N/A':<12}" + ) diff --git a/include/flashinfer/attention/hopper/kernel_traits.cuh b/include/flashinfer/attention/hopper/kernel_traits.cuh index abf164f61d..fd2ded0b6d 100644 --- a/include/flashinfer/attention/hopper/kernel_traits.cuh +++ b/include/flashinfer/attention/hopper/kernel_traits.cuh @@ -17,6 +17,7 @@ #include "cutlass/layout/layout.h" #include "cutlass/numeric_types.h" #include "cutlass/pipeline/pipeline.hpp" +#include "sm90_pipeline_no_cluster.cuh" namespace flashinfer { @@ -110,8 +111,10 @@ struct AttentionKernelTraits { GMMA::Major::K, DTypeO, decltype(cute::get<0>(TileShape_PDV{})), decltype(cute::get<1>(TileShape_PDV{}))>()); using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 1>(TileShape_PDV{}))); + // Use 
PipelineTmaAsyncNoCluster for TMA loads to avoid perf regression in Cutlass 3.6+ + // Only 1 out of 128 threads signals the barrier (instead of all threads) using MainloopPipeline = - std::conditional_t, + std::conditional_t, typename cutlass::PipelineAsync>; using PipelineState = typename cutlass::PipelineState; diff --git a/include/flashinfer/attention/hopper/mainloop.cuh b/include/flashinfer/attention/hopper/mainloop.cuh index e5bf4ffb9f..e390d312ec 100644 --- a/include/flashinfer/attention/hopper/mainloop.cuh +++ b/include/flashinfer/attention/hopper/mainloop.cuh @@ -202,7 +202,7 @@ struct CollectiveMainloop { if (lane_predicate) { pipeline_k.producer_acquire(smem_pipe_write_k); copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), - /*mcast_mask=*/0), + /*mcast_mask=*/0, cute::TMA::CacheHintSm90::EVICT_LAST), tKgK(_, kv_tile_idx), tKsK(_, smem_pipe_write_k.index())); ++smem_pipe_write_k; } @@ -230,14 +230,16 @@ struct CollectiveMainloop { #pragma unroll 2 for (; kv_tile_idx > swa_begin_kv_tile_idx; --kv_tile_idx) { pipeline_k.producer_acquire(smem_pipe_write_k); - copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), - /*mcast_mask=*/0), - tKgK(_, kv_tile_idx - 1), tKsK(_, smem_pipe_write_k.index())); + copy( + mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), + /*mcast_mask=*/0, cute::TMA::CacheHintSm90::EVICT_LAST), + tKgK(_, kv_tile_idx - 1), tKsK(_, smem_pipe_write_k.index())); ++smem_pipe_write_k; pipeline_v.producer_acquire(smem_pipe_write_v); - copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), - /*mcast_mask=*/0), - tVgV(_, kv_tile_idx), tVsV(_, smem_pipe_write_v.index())); + copy( + mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), + /*mcast_mask=*/0, cute::TMA::CacheHintSm90::EVICT_LAST), + tVgV(_, kv_tile_idx), tVsV(_, smem_pipe_write_v.index())); ++smem_pipe_write_v; } } @@ -245,7 +247,7 @@ struct CollectiveMainloop { if (lane_predicate) { pipeline_v.producer_acquire(smem_pipe_write_v); copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), - /*mcast_mask=*/0), + /*mcast_mask=*/0, cute::TMA::CacheHintSm90::EVICT_LAST), tVgV(_, kv_tile_idx), tVsV(_, smem_pipe_write_v.index())); ++smem_pipe_write_v; } diff --git a/include/flashinfer/attention/hopper/mainloop_mma.cuh b/include/flashinfer/attention/hopper/mainloop_mma.cuh index 27522f3187..89dc4581eb 100644 --- a/include/flashinfer/attention/hopper/mainloop_mma.cuh +++ b/include/flashinfer/attention/hopper/mainloop_mma.cuh @@ -56,6 +56,10 @@ CUTLASS_DEVICE void mma_f16( Tensor tSrK = threadMmaQK.partition_fragment_B(sK); Tensor tOrV = threadMmaPV.partition_fragment_B(sVt); + // Create identity tensor once, outside loops + Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_QKD{})); + Tensor tScS = threadMmaQK.partition_C(cS); + auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) { auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); pipeline.consumer_wait(smem_pipe_read, barrier_token); @@ -95,44 +99,43 @@ CUTLASS_DEVICE void mma_f16( auto col_limit_left = [&](int qo_idx) { return qo_idx + kv_len - qo_len - mainloop_params.window_left; }; - auto mask_multi_item_scoring = [&](decltype(tSrS)& tSrS, int i, int qo_idx, int kv_idx) { + + // Multi-item scoring mask functions + auto mask_multi_item_scoring = [&](auto& tSrS_ref, int i, int qo_idx, int kv_idx) { const uint32_t 
idx_in_original_seq = qo_idx + kv_len - qo_len; const bool out_of_boundary = kv_idx > idx_in_original_seq || (kv_idx >= std::min(kv_len, col_limit_right(qo_idx))); const bool is_prefix = idx_in_original_seq < prefix_len; uint16_t token_pos_in_items_regs = 0; - // Only access idx_in_original_seq >= prefix_len && idx_in_original_seq < kv_len to avoid - // out-of-bounds memory access if (idx_in_original_seq >= prefix_len & idx_in_original_seq < kv_len) { token_pos_in_items_regs = __ldca(token_pos_in_items + idx_in_original_seq - prefix_len); } if (out_of_boundary || is_prefix) { - tSrS(i) = out_of_boundary ? (AttentionUpdater::fill_value) : tSrS(i); + tSrS_ref(i) = out_of_boundary ? (AttentionUpdater::fill_value) : tSrS_ref(i); } else { - tSrS(i) = (kv_idx < prefix_len | (idx_in_original_seq < kv_idx + token_pos_in_items_regs)) - ? tSrS(i) - : (AttentionUpdater::fill_value); + tSrS_ref(i) = (kv_idx < prefix_len | (idx_in_original_seq < kv_idx + token_pos_in_items_regs)) + ? tSrS_ref(i) + : (AttentionUpdater::fill_value); } }; - auto mask_multi_item_scoring_assume_in_bound = [&](decltype(tSrS)& tSrS, int i, int qo_idx, + + auto mask_multi_item_scoring_assume_in_bound = [&](auto& tSrS_ref, int i, int qo_idx, int kv_idx) { const uint32_t idx_in_original_seq = qo_idx + kv_len - qo_len; const bool is_prefix = idx_in_original_seq < prefix_len; if (is_prefix) { - tSrS(i) = AttentionUpdater::fill_value; + tSrS_ref(i) = AttentionUpdater::fill_value; } else { uint16_t token_pos_in_items_regs = 0; - // Only access idx_in_original_seq >= prefix_len && idx_in_original_seq < kv_len to avoid - // out-of-bounds memory access if (idx_in_original_seq >= prefix_len & idx_in_original_seq < kv_len) { token_pos_in_items_regs = __ldca(token_pos_in_items + idx_in_original_seq - prefix_len); } - - tSrS(i) = (kv_idx < prefix_len | (idx_in_original_seq < kv_idx + token_pos_in_items_regs)) - ? tSrS(i) - : (AttentionUpdater::fill_value); + tSrS_ref(i) = (kv_idx < prefix_len | (idx_in_original_seq < kv_idx + token_pos_in_items_regs)) + ? 
tSrS_ref(i) + : (AttentionUpdater::fill_value); } }; + auto kv_tile_idx_decrement = [&](int kv_tile_idx) { int result = kv_tile_idx - 1; if constexpr (MULTIITEMSCORING) { @@ -143,175 +146,172 @@ CUTLASS_DEVICE void mma_f16( } return result; }; - { - Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_QKD{})); - Tensor tScS = threadMmaQK.partition_C(cS); + + // ============================================================================ + // Compile-time specialized mask functions (FA3 style) + // ============================================================================ + + // Causal mask with seqlen check (first iteration) + auto causal_mask_with_seqlen_fn = [&](auto& tSrS_local, int kv_tile) { #pragma unroll - for (int i = 0; i < size(tSrS); ++i) { + for (int i = 0; i < size(tSrS_local); ++i) { int qo_idx = get<0>(tScS(i)) + q_tile_idx * CTA_Q; - int kv_idx = get<1>(tScS(i)) + kv_tile_idx * CTA_KV; - tSrS(i) = variant.LogitsTransform(mainloop_params, tSrS(i), /*batch_idx=*/0, qo_idx, kv_idx, - qo_head_idx, kv_head_idx); + int kv_idx = get<1>(tScS(i)) + kv_tile * CTA_KV; + tSrS_local(i) = variant.LogitsTransform(mainloop_params, tSrS_local(i), /*batch_idx=*/0, + qo_idx, kv_idx, qo_head_idx, kv_head_idx); if constexpr (MULTIITEMSCORING) { - mask_multi_item_scoring(tSrS, i, qo_idx, kv_idx); - } else if constexpr (!CAUSAL) { // Just masking based on col + mask_multi_item_scoring(tSrS_local, i, qo_idx, kv_idx); + } else if constexpr (!CAUSAL) { if (kv_idx >= kv_len) { - tSrS(i) = AttentionUpdater::fill_value; + tSrS_local(i) = AttentionUpdater::fill_value; } } else { if (kv_idx >= std::min(kv_len, col_limit_right(qo_idx))) { - tSrS(i) = AttentionUpdater::fill_value; + tSrS_local(i) = AttentionUpdater::fill_value; } } if constexpr (LEFT_SLIDING_WINDOW) { if (kv_idx < col_limit_left(qo_idx)) { - tSrS(i) = AttentionUpdater::fill_value; + tSrS_local(i) = AttentionUpdater::fill_value; } } } - } - - attention_updater.update(tSrS); - Tensor tOrP = make_tensor(convert_type(tSrS).data(), - convert_layout_acc_Aregs(tSrS.layout())); + }; - constexpr int n_masking_steps = MULTIITEMSCORING ? (cute::ceil_div(CTA_Q, CTA_KV) + 1) - : (CAUSAL ? 
cute::ceil_div(CTA_Q, CTA_KV) : 0); - // masking loops - // ziangl@nvidia.com: for multi item scoring, we use this loop only to mask along the diagonal + // Causal mask without seqlen check (masking iterations) + auto causal_mask_fn = [&](auto& tSrS_local, int kv_tile) { #pragma unroll - for (int masking_step = 0; masking_step < n_masking_steps && kv_tile_idx > swa_begin_kv_tile_idx; - ++masking_step, kv_tile_idx = kv_tile_idx_decrement(kv_tile_idx)) { - Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_QKD{})); - consumer_wait(pipeline_k, smem_pipe_read_k); - WarpScheduler::barrier_sync(); - gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), - tSrS); - if (masking_step > 0) { - attention_updater.rescale_o(tOrO); - } - consumer_wait(pipeline_v, smem_pipe_read_v); - gemm(tiled_mma_pv, tOrP, - tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); - WarpScheduler::barrier_arrive(); - warpgroup_wait<1>(); - pipeline_k.consumer_release(smem_pipe_read_k); // release K - Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_QKD{})); - Tensor tScS = threadMmaQK.partition_C(cS); -#pragma unroll - for (int i = 0; i < size(tSrS); ++i) { + for (int i = 0; i < size(tSrS_local); ++i) { int qo_idx = get<0>(tScS(i)) + q_tile_idx * CTA_Q; - int kv_idx = get<1>(tScS(i)) + kv_tile_idx_decrement(kv_tile_idx) * CTA_KV; - tSrS(i) = variant.LogitsTransform(mainloop_params, tSrS(i), /*batch_idx=*/0, qo_idx, kv_idx, - qo_head_idx, kv_head_idx); - if (MULTIITEMSCORING) { - mask_multi_item_scoring(tSrS, i, qo_idx, kv_idx); + int kv_idx = get<1>(tScS(i)) + kv_tile * CTA_KV; + tSrS_local(i) = variant.LogitsTransform(mainloop_params, tSrS_local(i), /*batch_idx=*/0, + qo_idx, kv_idx, qo_head_idx, kv_head_idx); + if constexpr (MULTIITEMSCORING) { + mask_multi_item_scoring(tSrS_local, i, qo_idx, kv_idx); } else { if (kv_idx >= col_limit_right(qo_idx)) { - tSrS(i) = AttentionUpdater::fill_value; + tSrS_local(i) = AttentionUpdater::fill_value; } } if constexpr (LEFT_SLIDING_WINDOW) { if (kv_idx < col_limit_left(qo_idx)) { - tSrS(i) = AttentionUpdater::fill_value; + tSrS_local(i) = AttentionUpdater::fill_value; } } } - attention_updater.update(tSrS); - warpgroup_wait<0>(); - pipeline_v.consumer_release(smem_pipe_read_v); // release V - ++smem_pipe_read_k; - ++smem_pipe_read_v; - cute::copy(make_tensor(convert_type(tSrS).data(), - convert_layout_acc_Aregs(tSrS.layout())), - tOrP); - } + }; -#pragma unroll 1 - for (; kv_tile_idx > swa_end_kv_tile_idx + 1; kv_tile_idx = kv_tile_idx_decrement(kv_tile_idx)) { - Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_QKD{})); - consumer_wait(pipeline_k, smem_pipe_read_k); - WarpScheduler::barrier_sync(); - gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), - tSrS); - attention_updater.rescale_o(tOrO); - consumer_wait(pipeline_v, smem_pipe_read_v); - gemm(tiled_mma_pv, tOrP, - tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); - WarpScheduler::barrier_arrive(); - warpgroup_wait<1>(); - pipeline_k.consumer_release(smem_pipe_read_k); // release K - // #pragma unroll - Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_QKD{})); - Tensor tScS = threadMmaQK.partition_C(cS); + // No mask function (main loop - no causal boundary) + auto no_mask_fn = [&](auto& tSrS_local, int kv_tile) { #pragma unroll - for (int i = 0; i < size(tSrS); ++i) { + for (int i = 0; i < size(tSrS_local); ++i) { int qo_idx = get<0>(tScS(i)) + q_tile_idx * CTA_Q; - int kv_idx = get<1>(tScS(i)) + kv_tile_idx_decrement(kv_tile_idx) * 
CTA_KV; - tSrS(i) = variant.LogitsTransform(mainloop_params, tSrS(i), /*batch_idx=*/0, qo_idx, kv_idx, - qo_head_idx, kv_head_idx); + int kv_idx = get<1>(tScS(i)) + kv_tile * CTA_KV; + tSrS_local(i) = variant.LogitsTransform(mainloop_params, tSrS_local(i), /*batch_idx=*/0, + qo_idx, kv_idx, qo_head_idx, kv_head_idx); } if constexpr (MULTIITEMSCORING) { - // auto nums_tiles_outside_causal_diagonal = kv_tile_idx_count - cute::ceil_div(CTA_Q, - // CTA_KV); - if (kv_tile_idx >= num_kv_tiles_prefix - 1) { + if (kv_tile >= num_kv_tiles_prefix - 1) { #pragma unroll - for (int i = 0; i < size(tSrS); ++i) { + for (int i = 0; i < size(tSrS_local); ++i) { int qo_idx = get<0>(tScS(i)) + q_tile_idx * CTA_Q; - int kv_idx = get<1>(tScS(i)) + kv_tile_idx_decrement(kv_tile_idx) * CTA_KV; - mask_multi_item_scoring_assume_in_bound(tSrS, i, qo_idx, kv_idx); + int kv_idx = get<1>(tScS(i)) + kv_tile * CTA_KV; + mask_multi_item_scoring_assume_in_bound(tSrS_local, i, qo_idx, kv_idx); } } } - attention_updater.update(tSrS); + }; + + // Sliding window left mask function + auto swa_left_mask_fn = [&](auto& tSrS_local, int kv_tile) { +#pragma unroll + for (int i = 0; i < size(tSrS_local); ++i) { + int qo_idx = get<0>(tScS(i)) + q_tile_idx * CTA_Q; + int kv_idx = get<1>(tScS(i)) + kv_tile * CTA_KV; + tSrS_local(i) = variant.LogitsTransform(mainloop_params, tSrS_local(i), /*batch_idx=*/0, + qo_idx, kv_idx, qo_head_idx, kv_head_idx); + if (kv_idx < col_limit_left(qo_idx)) { + tSrS_local(i) = AttentionUpdater::fill_value; + } + } + }; + + // ============================================================================ + // First iteration (with seqlen check) + // ============================================================================ + causal_mask_with_seqlen_fn(tSrS, kv_tile_idx); + attention_updater.update(tSrS); + Tensor tOrP = make_tensor(convert_type(tSrS).data(), + convert_layout_acc_Aregs(tSrS.layout())); + + // ============================================================================ + // Forward step with compile-time specialized mask function + // ============================================================================ + auto fwd_step = [&](int kv_tile, auto mask_fn, auto is_first_type) { + static constexpr bool Is_first = decltype(is_first_type)::value; + + Tensor tSrS_local = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_QKD{})); + consumer_wait(pipeline_k, smem_pipe_read_k); + WarpScheduler::barrier_sync(); + gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), + tSrS_local); + if constexpr (!Is_first) { + attention_updater.rescale_o(tOrO); + } + consumer_wait(pipeline_v, smem_pipe_read_v); + gemm(tiled_mma_pv, tOrP, + tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); + WarpScheduler::barrier_arrive(); + warpgroup_wait<1>(); + pipeline_k.consumer_release(smem_pipe_read_k); + + // Apply mask function (compile-time specialized) + mask_fn(tSrS_local, kv_tile); + + attention_updater.template update(tSrS_local); warpgroup_wait<0>(); - pipeline_v.consumer_release(smem_pipe_read_v); // release V + pipeline_v.consumer_release(smem_pipe_read_v); ++smem_pipe_read_k; ++smem_pipe_read_v; - cute::copy(make_tensor(convert_type(tSrS).data(), - convert_layout_acc_Aregs(tSrS.layout())), - tOrP); + cute::copy( + make_tensor(convert_type(tSrS_local).data(), + convert_layout_acc_Aregs(tSrS_local.layout())), + tOrP); + }; + + constexpr int n_masking_steps = MULTIITEMSCORING ? (cute::ceil_div(CTA_Q, CTA_KV) + 1) + : (CAUSAL ? 
cute::ceil_div(CTA_Q, CTA_KV) : 0); + + // ============================================================================ + // Masking loop (causal boundary iterations) + // ============================================================================ +#pragma unroll 1 + for (int masking_step = 0; masking_step < n_masking_steps && kv_tile_idx > swa_begin_kv_tile_idx; + ++masking_step, kv_tile_idx = kv_tile_idx_decrement(kv_tile_idx)) { + fwd_step(kv_tile_idx_decrement(kv_tile_idx), causal_mask_fn, cute::false_type{}); + } + + // ============================================================================ + // Main loop (no causal masking needed) + // ============================================================================ +#pragma unroll 1 + for (; kv_tile_idx > swa_end_kv_tile_idx + 1; kv_tile_idx = kv_tile_idx_decrement(kv_tile_idx)) { + fwd_step(kv_tile_idx_decrement(kv_tile_idx), no_mask_fn, cute::false_type{}); } + // ============================================================================ + // Sliding window left mask loop (if enabled) + // ============================================================================ if constexpr (LEFT_SLIDING_WINDOW) { #pragma unroll 1 for (; kv_tile_idx > swa_begin_kv_tile_idx; --kv_tile_idx) { - Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_QKD{})); - consumer_wait(pipeline_k, smem_pipe_read_k); - WarpScheduler::barrier_sync(); - gemm(tiled_mma_qk, tSrQ, - tSrK(_, _, _, smem_pipe_read_k.index()), tSrS); - attention_updater.rescale_o(tOrO); - consumer_wait(pipeline_v, smem_pipe_read_v); - gemm(tiled_mma_pv, tOrP, - tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); - WarpScheduler::barrier_arrive(); - warpgroup_wait<1>(); - pipeline_k.consumer_release(smem_pipe_read_k); // release K - Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_QKD{})); - Tensor tScS = threadMmaQK.partition_C(cS); -#pragma unroll - for (int i = 0; i < size(tSrS); ++i) { - int qo_idx = get<0>(tScS(i)) + q_tile_idx * CTA_Q; - int kv_idx = get<1>(tScS(i)) + (kv_tile_idx - 1) * CTA_KV; - tSrS(i) = variant.LogitsTransform(mainloop_params, tSrS(i), /*batch_idx=*/0, qo_idx, kv_idx, - qo_head_idx, kv_head_idx); - if (kv_idx < col_limit_left(qo_idx)) { - tSrS(i) = AttentionUpdater::fill_value; - } - } - attention_updater.update(tSrS); - warpgroup_wait<0>(); - pipeline_v.consumer_release(smem_pipe_read_v); // release V - ++smem_pipe_read_k; - ++smem_pipe_read_v; - cute::copy(make_tensor(convert_type(tSrS).data(), - convert_layout_acc_Aregs(tSrS.layout())), - tOrP); + fwd_step(kv_tile_idx - 1, swa_left_mask_fn, cute::false_type{}); } } - // Tell warp 0 that smem_q is ready + // ============================================================================ + // Epilogue: final V gemm + // ============================================================================ cutlass::arch::NamedBarrier::arrive(NUM_MMA_THREADS + Ktraits::NUM_PRODUCER_THREADS, /*id=*/static_cast(NamedBarriers::kQueryEmpty)); attention_updater.rescale_o(tOrO); @@ -320,7 +320,7 @@ CUTLASS_DEVICE void mma_f16( tOrO); attention_updater.finalize(tSrS, get_variant_scale_pv(variant)); warpgroup_wait<0>(); - pipeline_v.consumer_release(smem_pipe_read_v); // release V, otherwise producers will hang + pipeline_v.consumer_release(smem_pipe_read_v); ++smem_pipe_read_v; attention_updater.rescale_o(tOrO); diff --git a/include/flashinfer/attention/hopper/prefill_sm90.cuh b/include/flashinfer/attention/hopper/prefill_sm90.cuh index f1e441a53b..eb42879646 100644 --- 
a/include/flashinfer/attention/hopper/prefill_sm90.cuh +++ b/include/flashinfer/attention/hopper/prefill_sm90.cuh @@ -298,7 +298,8 @@ cudaError_t SinglePrefillWithKVCacheKernelTraitsDispatched(Params& params, cudaS using CollectiveMainloop = CollectiveMainloop; using CollectiveEpilogue = CollectiveEpilogue; - using Scheduler = SingleTileScheduler; + // Use LPT scheduling for causal attention for better load balancing + using Scheduler = SingleTileScheduler; typename CollectiveMainloop::Params mainloop_params = CollectiveMainloop::to_underlying_arguments( {params.q_ptr, get_gmem_layout(params.qo_len, params.num_qo_heads, KernelTraits::HEAD_DIM_QK, diff --git a/include/flashinfer/attention/hopper/quantization/epilogue.cuh b/include/flashinfer/attention/hopper/quantization/epilogue.cuh index 8bf5098d45..d0bf4969be 100644 --- a/include/flashinfer/attention/hopper/quantization/epilogue.cuh +++ b/include/flashinfer/attention/hopper/quantization/epilogue.cuh @@ -14,6 +14,7 @@ #include "../named_barrier.cuh" #include "../utils.cuh" #include "cute/tensor.hpp" +#include "cutlass/epilogue/collective/detail.hpp" #include "cutlass/gemm/collective/collective_builder.hpp" namespace flashinfer { @@ -39,7 +40,13 @@ struct FP8CollectiveEpilogue { decltype(cute::get<2>(TileShape_QKD{}))>()); using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_QKD{}))); - using SmemCopyAtomO = Copy_Atom; + using StrideO = cute::Shape; + using EpilogueTile_MN = decltype(select<0, 2>(TileShape_QKD{})); + // Use sm90_get_smem_store_op_for_accumulator to get the correct copy op for FP8 accumulators + using CopyOpR2S = + decltype(cutlass::epilogue::collective::detail::sm90_get_smem_store_op_for_accumulator< + StrideO, DTypeO, EpilogueTile_MN>()); + using SmemCopyAtomO = Copy_Atom; using SharedStorage = cute::array_aligned>; using ShapeT = cute::Shape; diff --git a/include/flashinfer/attention/hopper/quantization/kernel_traits.cuh b/include/flashinfer/attention/hopper/quantization/kernel_traits.cuh index da5b3da964..bdb25ee783 100644 --- a/include/flashinfer/attention/hopper/quantization/kernel_traits.cuh +++ b/include/flashinfer/attention/hopper/quantization/kernel_traits.cuh @@ -47,8 +47,16 @@ struct SharedStorageQKVOVt { }; /* - In-kernel FP8 transpose adopted from FlashAttention-3 template - https://github.com/Dao-AILab/flash-attention/blob/c7f32a8409e52a84bd8046afe7060da33036f9a5/hopper/kernel_traits.h#L217 + FA3-style FP8 transpose: Same-shape transpose with MN-major TMA load and K-major MMA + Reference: + https://github.com/Dao-AILab/flash-attention/blob/main/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp + + Key insight: TMA loads V with transposed gmem strides into MN-major smem layout. + Then we transpose in-place from MN-major to K-major within the same-shape buffer. + Both SmemLayoutVtTma and SmemLayoutVtMma have shape (HEAD_DIM, CTA_KV, STAGES). + + For sparse path (cp.async loading), we keep the original (CTA_KV, HEAD_DIM, STAGES) layout + since cp.async loads V directly in its original gmem layout (N, D). 
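+  In the code below, SmemLayoutVtTma is tiled with Step<_2, _1, _3> (MN-major ordering)
+  and SmemLayoutVtMma with Step<_1, _2, _3> (K-major ordering); both cover the same
+  (HEAD_DIM, CTA_KV, STAGES) extent and differ only in ordering and swizzle.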
*/ template struct TranposeTraits_64x64 { @@ -56,49 +64,192 @@ struct TranposeTraits_64x64 { using TransElement = Element; static_assert(cutlass::sizeof_bits_v == 8); - using SmemShapeLDSM = Shape, Shape<_16, _4>>; - using SmemShapeSTSM = Shape, Shape<_16, _4>>; + static constexpr int kHeadDim = get<2>(TileShape_QKD{}); + static constexpr int kBlockN = get<1>(TileShape_QKD{}); + + // MN-major for TMA loading (V is loaded with transposed gmem strides) + static constexpr cute::GMMA::Major TmaMajorV = GMMA::Major::MN; + // K-major for MMA consumption (required for FP8) + static constexpr cute::GMMA::Major MmaMajorV = GMMA::Major::K; + + // ==================== TMA Path Layouts (FA3-style same-shape) ==================== + // SmemLayoutVtTma: MN-major layout for TMA load, shape (HEAD_DIM, CTA_KV, STAGES) + using SmemLayoutAtomVtTma = + decltype(cutlass::gemm::collective::detail::ss_smem_selector, Int>()); + using SmemLayoutVtTma = decltype(tile_to_shape( + SmemLayoutAtomVtTma{}, make_shape(Int{}, Int{}, Int{}), + cute::Step<_2, _1, _3>{})); // MN-major ordering + + // SmemLayoutVtMma: K-major layout for MMA, same shape (HEAD_DIM, CTA_KV, STAGES) + using SmemLayoutAtomVtMma = + decltype(cutlass::gemm::collective::detail::ss_smem_selector, Int>()); + using SmemLayoutVtMma = decltype(tile_to_shape( + SmemLayoutAtomVtMma{}, make_shape(Int{}, Int{}, Int{}), + cute::Step<_1, _2, _3>{})); // K-major ordering + + // For TMA path: SmemLayoutV = SmemLayoutVtTma (MN-major, for TMA load) + using SmemLayoutV = SmemLayoutVtTma; + using SmemLayoutVt = SmemLayoutVtMma; + + // FA3-style LDSM/STSM tiled copies for TMA path transpose + static constexpr bool kHeadDimMultiple64 = kHeadDim % 64 == 0; + static_assert(kHeadDimMultiple64 || kBlockN % 64 == 0, + "Either kHeadDim or kBlockN must be multiple of 64"); + + using LDSM_thread_shape = + std::conditional_t, Shape<_16, _4, _1, _2>>; + using LDSM_thread_stride = + std::conditional_t, Stride<_4, _1, _0, _64>>; + using LDSM_value_shape = Shape<_2, _2, _1, _4>; + using LDSM_value_stride = Stride<_1, _2, _16, _4>; + using LDSM_divide_shape = std::conditional_t, Shape<_32, _8>>; + + using S2RTiledCopyVt = decltype(make_tiled_copy(Copy_Atom{}, + Layout{}, + Layout{})); + + using STSM_thread_shape = + std::conditional_t, Shape<_8, _4, _2, _2>>; + using STSM_thread_stride = + std::conditional_t, Stride<_4, _1, _32, _64>>; + using STSM_value_shape = Shape<_1, _4, _2, _2>; + using STSM_value_stride = Stride<_0, _1, _4, _8>; + using STSM_divide_shape = Shape<_8, _16>; + + using R2STiledCopyV = decltype(make_tiled_copy(Copy_Atom{}, + Layout{}, + Layout{})); + + // TMA path transpose layouts + using SmemLayoutVTransposeSrc = SmemLayoutVtTma; + using SmemLayoutVtTransposeTgt = SmemLayoutVtMma; + + // ==================== Sparse Path Layouts (Original different-shape) ==================== + // For sparse path, cp.async loads V in original (N, D) layout, so we need (CTA_KV, HEAD_DIM, + // STAGES) + using SmemLayoutAtomVSparse = + decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom{}, TransposeShapeAtom_{})); + using SmemLayoutVSparse = decltype(tile_to_shape( + SmemLayoutAtomVSparse{}, make_shape(Int{}, Int{}, Int{}))); - using SmemLayoutAtomV = - decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom{}, TransposeShapeAtom_{})); - using SmemLayoutV = decltype(tile_to_shape( - SmemLayoutAtomV{}, - make_shape(get<1>(TileShape_QKD{}), get<2>(TileShape_QKD{}), Int{}))); - using SmemLayoutDivideV = decltype(tiled_divide(SmemLayoutV{}, TransposeShapeAtom_{})); - using FactoringShapeV = - 
decltype(make_shape(SmemShapeLDSM{}, shape<1>(SmemLayoutDivideV{}), - shape<2>(SmemLayoutDivideV{}), shape<3>(SmemLayoutDivideV{}))); - using SmemLayoutVTransposeSrc = - decltype(composition(SmemLayoutDivideV{}, make_layout(FactoringShapeV{}))); - - using SmemLayoutAtomVt = + // Sparse path transpose source layout (from SmemLayoutVSparse) + using SmemShapeLDSM = Shape, Shape<_16, _4>>; + using SmemLayoutDivideVSparse = + decltype(tiled_divide(SmemLayoutVSparse{}, TransposeShapeAtom_{})); + using FactoringShapeVSparse = decltype(make_shape( + SmemShapeLDSM{}, shape<1>(SmemLayoutDivideVSparse{}), shape<2>(SmemLayoutDivideVSparse{}), + shape<3>(SmemLayoutDivideVSparse{}))); + using SmemLayoutVSparseTransposeSrc = + decltype(composition(SmemLayoutDivideVSparse{}, make_layout(FactoringShapeVSparse{}))); + + // Sparse path transpose target layout (same SmemLayoutVt as TMA path for MMA) + using SmemLayoutAtomVtSparse = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom{}, TransposeShapeAtom_{})); - // k-major Vt as target layout. this changes the memory - using SmemLayoutVt = decltype(tile_to_shape( - SmemLayoutAtomVt{}, - make_shape(get<2>(TileShape_QKD{}), get<1>(TileShape_QKD{}), Int{}))); - using SmemLayoutVtTrans = decltype(composition( - SmemLayoutVt{}, make_ordered_layout(product_each(shape(SmemLayoutV{})), Step<_2, _1, _3>{}))); - using SmemLayoutDivideVt = decltype(tiled_divide(SmemLayoutVtTrans{}, TransposeShapeAtom_{})); - using FactoringShapeVt = - decltype(make_shape(SmemShapeSTSM{}, shape<1>(SmemLayoutDivideVt{}), - shape<2>(SmemLayoutDivideVt{}), shape<3>(SmemLayoutDivideVt{}))); - using SmemLayoutVtTransposeTgt = - decltype(composition(SmemLayoutDivideVt{}, make_layout(FactoringShapeVt{}))); + using SmemLayoutVtSparse = decltype(tile_to_shape( + SmemLayoutAtomVtSparse{}, make_shape(Int{}, Int{}, Int{}))); + using SmemLayoutVtSparseTrans = decltype(composition( + SmemLayoutVtSparse{}, + make_ordered_layout(product_each(shape(SmemLayoutVSparse{})), Step<_2, _1, _3>{}))); + using SmemLayoutDivideVtSparse = + decltype(tiled_divide(SmemLayoutVtSparseTrans{}, TransposeShapeAtom_{})); + using SmemShapeSTSM = Shape, Shape<_16, _4>>; + using FactoringShapeVtSparse = decltype(make_shape( + SmemShapeSTSM{}, shape<1>(SmemLayoutDivideVtSparse{}), shape<2>(SmemLayoutDivideVtSparse{}), + shape<3>(SmemLayoutDivideVtSparse{}))); + using SmemLayoutVtSparseTransposeTgt = + decltype(composition(SmemLayoutDivideVtSparse{}, make_layout(FactoringShapeVtSparse{}))); }; /* - In-kernel Transpose of smemV into smemVt with ldmatrix.trans & stmatrix. - Note that all magic number corresponds to the /quantization/kernel_traits.cuh setup. - This transpose is not a general transpose, but a specific one for the FP8 MMA_PV: - 1. K-dimension: (2,2,4,4):(1,8,2,16), which adheres to the accum_P's layout - 2. N-dimension: (8,2,4):(2,1,16), which needs repermutation when rmemO -> smemO + FA3-style in-kernel transpose of smemV (MN-major) into smemVt (K-major) using LDSM.T & STSM. + Both tensors have the same shape (HEAD_DIM, CTA_KV, STAGES), only different swizzle patterns. + Reference: + https://github.com/Dao-AILab/flash-attention/blob/main/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp + + This is used for TMA path where V is loaded with transposed gmem strides. 
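+  Note: ldmatrix.trans moves 16-bit elements, so after the LDSM.T load adjacent 8-bit V
+  elements arrive interleaved across register halves; the __byte_perm step in
+  do_transpose() (selectors 0x6420 / 0x7531) splits even- and odd-indexed bytes so the
+  FP8 elements are contiguous again before the STSM store.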
*/ template struct SmemTransposeFP8_64x64 { using Element = typename Ktraits::DTypeKV; - using SmemLayoutVTransposeSrc = typename Ktraits::SmemLayoutVTransposeSrc; - using SmemLayoutVtTransposeTgt = typename Ktraits::SmemLayoutVtTransposeTgt; + using VTranposeTraits = typename Ktraits::VTranposeTraits; + using SmemLayoutVtTma = typename Ktraits::SmemLayoutV; + using SmemLayoutVtMma = typename Ktraits::SmemLayoutVt; + static_assert(cutlass::sizeof_bits_v == 8); + + using S2RTiledCopyVt = typename VTranposeTraits::S2RTiledCopyVt; + using R2STiledCopyV = typename VTranposeTraits::R2STiledCopyV; + using LDSM_divide_shape = typename VTranposeTraits::LDSM_divide_shape; + using STSM_divide_shape = typename VTranposeTraits::STSM_divide_shape; + + S2RTiledCopyVt s2r_tiled_copy_vt; + R2STiledCopyV r2s_tiled_copy_v; + + template + CUTLASS_DEVICE void do_transpose(SmemTensorVt& sVt, SmemTensorV& sV, int stage_idx) { + using namespace cute; + + auto s2r_thr_copy_vt = s2r_tiled_copy_vt.get_thread_slice(threadIdx.x); + auto r2s_thr_copy_v = r2s_tiled_copy_v.get_thread_slice(threadIdx.x); + + // flat_divide sVt (source, MN-major) and sV (target, K-major) for transpose + // sVt shape: (HEAD_DIM, CTA_KV, STAGES) + // After flat_divide: (LDSM_divide_shape, HEAD_DIM / LDSM_divide_shape[0], CTA_KV / + // LDSM_divide_shape[1], STAGES) + Tensor tTranssVt_ = s2r_thr_copy_vt.partition_S(flat_divide(sVt, LDSM_divide_shape{})); + Tensor tTranssV_ = r2s_thr_copy_v.partition_D(flat_divide(sV, STSM_divide_shape{})); + + // Use ILP=2 for better instruction-level parallelism + static constexpr int Transpose_ILP = + (size<2>(tTranssVt_) * size<3>(tTranssVt_)) % 2 == 0 ? 2 : 1; + Tensor tTranssVt = logical_divide(group_modes<1, rank(tTranssVt_) - 1>(tTranssVt_), + Shape>{}); + Tensor tTranssV = logical_divide(group_modes<1, rank(tTranssV_) - 1>(tTranssV_), + Shape>{}); + +#pragma unroll + for (int i = 0; i < size<1, 1>(tTranssVt); ++i) { + Tensor tTransrV = make_fragment_like(tTranssV(_, make_coord(_, _0{}), _0{})); + static_assert(size<0>(tTransrV) == 16); + Tensor tTransrV_64 = recast(tTransrV); + + // Load from MN-major smem using LDSM.T + cute::copy(s2r_tiled_copy_vt, tTranssVt(_, make_coord(_, i), stage_idx), tTransrV); + +// Byte permutation for FP8 element reordering +#pragma unroll + for (int j = 0; j < size(tTransrV_64); ++j) { + uint32_t upper = tTransrV_64[j].x; + uint32_t lower = tTransrV_64[j].y; + tTransrV_64[j].x = __byte_perm(upper, lower, 0x6420); + tTransrV_64[j].y = __byte_perm(upper, lower, 0x7531); + } + + // Store to K-major smem using STSM + cute::copy(r2s_tiled_copy_v, tTransrV, tTranssV(_, make_coord(_, i), stage_idx)); + } + + // Sync all WG threads for ldmatrix completion + cutlass::arch::NamedBarrier::sync(Ktraits::NUM_PRODUCER_THREADS, + static_cast(NamedBarriers::kProducerWG) /*id*/); + } + + // Legacy interface for backward compatibility + using SmemLayoutVTransposeSrc = SmemLayoutVtTma; + using SmemLayoutVtTransposeTgt = SmemLayoutVtMma; +}; + +/* + Original FP8 transpose for sparse path (cp.async loading). + V is loaded in original (N, D) layout via cp.async, so smemV has shape (CTA_KV, HEAD_DIM, STAGES). + Transpose to smemVt with shape (HEAD_DIM, CTA_KV, STAGES) for MMA consumption. 
+*/ +template +struct SmemTransposeFP8_64x64_Sparse { + using Element = typename Ktraits::DTypeKV; + using SmemLayoutVSparseTransposeSrc = typename Ktraits::SmemLayoutVSparseTransposeSrc; + using SmemLayoutVtSparseTransposeTgt = typename Ktraits::SmemLayoutVtSparseTransposeTgt; static_assert(cutlass::sizeof_bits_v == 8); using ldsm_thread_shape = Shape<_4, _1, _8, _4>; @@ -151,9 +302,9 @@ struct SmemTransposeFP8_64x64 { template CUTLASS_DEVICE void do_transpose(SmemTensor& s_in, SmemTensorOut& s_out, int stage_idx) { CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < shape<2>(SmemLayoutVTransposeSrc{}); ++j) { + for (int j = 0; j < shape<2>(SmemLayoutVSparseTransposeSrc{}); ++j) { CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < shape<1>(SmemLayoutVTransposeSrc{}); ++i) { + for (int i = 0; i < shape<1>(SmemLayoutVSparseTransposeSrc{}); ++i) { this->_tranpose(flatten(s_in(_, i, j, stage_idx)), flatten(s_out(_, i, j, stage_idx))); } } @@ -220,11 +371,18 @@ struct FP8AttentionKernelTraits { make_shape(shape<1>(TileShape_QKD{}), shape<2>(TileShape_QKD{}), Int{}))); using VTranposeTraits = TranposeTraits_64x64; + // TMA path layouts (FA3-style same-shape transpose) using SmemLayoutV = typename VTranposeTraits::SmemLayoutV; using SmemLayoutVt = typename VTranposeTraits::SmemLayoutVt; using SmemLayoutVTransposeSrc = typename VTranposeTraits::SmemLayoutVTransposeSrc; using SmemLayoutVtTransposeTgt = typename VTranposeTraits::SmemLayoutVtTransposeTgt; + // Sparse path layouts (original different-shape transpose) + using SmemLayoutVSparse = typename VTranposeTraits::SmemLayoutVSparse; + using SmemLayoutVtSparse = typename VTranposeTraits::SmemLayoutVtSparse; + using SmemLayoutVSparseTransposeSrc = typename VTranposeTraits::SmemLayoutVSparseTransposeSrc; + using SmemLayoutVtSparseTransposeTgt = typename VTranposeTraits::SmemLayoutVtSparseTransposeTgt; + using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector< GMMA::Major::K, DTypeO, decltype(cute::get<0>(TileShape_QKD{})), decltype(cute::get<2>(TileShape_QKD{}))>()); @@ -236,9 +394,13 @@ struct FP8AttentionKernelTraits { using PipelineState = typename cutlass::PipelineState; // Modify SharedStorage + // NOTE: Use SmemLayoutVSparse for SharedStorage to ensure sparse (paged KV) path works correctly. + // SmemLayoutVSparse has shape (CTA_KV, HEAD_DIM, STAGES) which matches cp.async loading pattern. + // For TMA path, we create the tensor with SmemLayoutV (FA3-style) layout in mainloop_load.cuh. + // Both layouts have the same cosize, so memory allocation is identical. 
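+  // As a sanity check, the "same cosize" claim above could be asserted directly
+  // (sketch only, not part of this change):
+  //   static_assert(cute::cosize_v<SmemLayoutV> == cute::cosize_v<SmemLayoutVSparse>,
+  //                 "TMA and sparse V smem layouts must have the same footprint");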
using SharedStorage = SharedStorageQKVOVt; + SmemLayoutK, SmemLayoutVSparse, SmemLayoutO>; }; } // namespace flashinfer diff --git a/include/flashinfer/attention/hopper/quantization/mainloop_load.cuh b/include/flashinfer/attention/hopper/quantization/mainloop_load.cuh index 988f7e9aca..ed90c64eb9 100644 --- a/include/flashinfer/attention/hopper/quantization/mainloop_load.cuh +++ b/include/flashinfer/attention/hopper/quantization/mainloop_load.cuh @@ -48,6 +48,10 @@ struct FP8CollectiveMainloop { using StrideT = cute::Shape; // (N, D, H) using LayoutT = cute::Layout; + // Transposed stride for V TMA loading: (D, N, H) instead of (N, D, H) + // This loads V^T directly into MN-major smem layout + using StrideVTransposed = cute::Shape<_1, int64_t, int64_t>; // (D, N, H) + using ShapeLseT = cute::Shape; using StrideLseT = cute::Shape<_1, int64_t>; using LayoutLseT = cute::Layout; @@ -64,11 +68,15 @@ struct FP8CollectiveMainloop { repeat_like(StrideT{}, int32_t(0)), StrideT{}), take<0, 2>(SmemLayoutK{}), select<1, 2>(TileShape_QKD{}), _1{})); // no mcast + // FA3-style: TMA loads V with transposed gmem strides into SmemLayoutV (MN-major) + // Gmem V has shape (N, D, H), we load with transposed strides to get V^T into (D, N) smem tiles + // SmemLayoutV now has shape (HEAD_DIM, CTA_KV, STAGES) = (D, N, STAGES) + // Tile shape for V TMA: select<2, 1>(TileShape_QKD{}) = (HEAD_DIM, CTA_KV) using TMA_V = decltype(make_tma_copy( GmemTiledCopyKV{}, make_tensor(make_gmem_ptr(static_cast(nullptr)), - repeat_like(StrideT{}, int32_t(0)), StrideT{}), - take<0, 2>(SmemLayoutV{}), select<1, 2>(TileShape_QKD{}), _1{})); // no mcast + repeat_like(StrideVTransposed{}, int32_t(0)), StrideVTransposed{}), + take<0, 2>(SmemLayoutV{}), select<2, 1>(TileShape_QKD{}), _1{})); // no mcast static constexpr bool USE_TMA_LOAD_KV = true; using MainloopPipeline = typename Ktraits::MainloopPipeline; @@ -121,9 +129,19 @@ struct FP8CollectiveMainloop { Tensor mK = make_tensor(make_gmem_ptr(args.K_ptr), args.layout_K); TMA_K tma_load_K = make_tma_copy(GmemTiledCopyKV{}, mK, SmemLayoutK{}(_, _, _0{}), select<1, 2>(TileShape_QKD{}), _1{}); // no mcast - Tensor mV = make_tensor(make_gmem_ptr(args.V_ptr), args.layout_V); + + // FA3-style: Create V tensor with transposed strides for TMA loading + // Original V layout: (N, D, H) with strides (stride_N, 1, stride_H) + // Transposed V layout for TMA: (D, N, H) with strides (1, stride_N, stride_H) + auto [shape_N, shape_D, shape_H] = args.layout_V.shape(); + auto [stride_N, stride_D, stride_H] = args.layout_V.stride(); + auto shape_V_transposed = make_shape(shape_D, shape_N, shape_H); + auto stride_V_transposed = make_stride(stride_D, stride_N, stride_H); + Tensor mV = make_tensor(make_gmem_ptr(args.V_ptr), + make_layout(shape_V_transposed, stride_V_transposed)); TMA_V tma_load_V = make_tma_copy(GmemTiledCopyKV{}, mV, SmemLayoutV{}(_, _, _0{}), - select<1, 2>(TileShape_QKD{}), _1{}); // no mcast + select<2, 1>(TileShape_QKD{}), _1{}); // no mcast + return {args.layout_Q, args.layout_K, args.layout_V, tma_load_Q, tma_load_K, tma_load_V, args.window_left, args.additional_params}; } @@ -161,20 +179,24 @@ struct FP8CollectiveMainloop { BlockCoord const& block_coord, int work_idx) { Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{}); Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{}); + // sV now uses SmemLayoutV which is SmemLayoutVtTma (MN-major, shape (HEAD_DIM, CTA_KV, STAGES)) Tensor sV = 
@@ -161,20 +179,24 @@ struct FP8CollectiveMainloop {
                             BlockCoord const& block_coord, int work_idx) {
     Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
     Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
+    // sV now uses SmemLayoutV which is SmemLayoutVtTma (MN-major, shape (HEAD_DIM, CTA_KV, STAGES))
     Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{});

     Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape());
     Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.layout_K.shape());
-    Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.layout_V.shape());
-    // *** Prepare In-kernel V Transpose ***
-    using SmemLayoutVTransposeSrc = typename Ktraits::SmemLayoutVTransposeSrc;
-    using SmemLayoutVtTransposeTgt = typename Ktraits::SmemLayoutVtTransposeTgt;
+    // FA3-style: mV uses transposed shape (D, N, H) instead of (N, D, H)
+    auto [shape_N, shape_D, shape_H] = mainloop_params.layout_V.shape();
+    auto shape_V_transposed = make_shape(shape_D, shape_N, shape_H);
+    Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(shape_V_transposed);

-    Tensor sV_src = as_position_independent_swizzle_tensor(
-        make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutVTransposeSrc{}));
+    // *** Prepare In-kernel V Transpose ***
+    // FA3-style: sVt_src (MN-major) is the TMA destination, sVt_tgt (K-major) is the MMA source
+    // Both have the same shape (HEAD_DIM, CTA_KV, STAGES), only different swizzle patterns
+    Tensor sVt_src = as_position_independent_swizzle_tensor(
+        make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{}));
     Tensor sVt_tgt = as_position_independent_swizzle_tensor(
-        make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), SmemLayoutVtTransposeTgt{}));
+        make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), SmemLayoutVt{}));
     auto v_tranposer = SmemTransposeFP8_64x64();

     auto [q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len, batch_idx] =
@@ -185,8 +207,9 @@ struct FP8CollectiveMainloop {
                              qo_len)(_, _, q_tile_idx);  // (Q, D)
     Tensor gK = get_local_tile_tensor(mK, select<1, 2>(TileShape_QKD{}), kv_head_idx, kv_indptr,
                                       kv_len);  // (K, D, _)
-    Tensor gV = get_local_tile_tensor(mV, select<1, 2>(TileShape_QKD{}), kv_head_idx, kv_indptr,
-                                      kv_len);  // (K, D, _)
+    // FA3-style: gV uses transposed tile shape (HEAD_DIM, CTA_KV) = select<2, 1>(TileShape_QKD{})
+    Tensor gV = get_local_tile_tensor(mV, select<2, 1>(TileShape_QKD{}), kv_head_idx, kv_indptr,
+                                      kv_len);  // (D, K, _)

     Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{}));
     Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{}));
@@ -246,7 +269,7 @@ struct FP8CollectiveMainloop {
         pipeline_v.consumer_wait(smem_pipe_read);
         pipeline_vt.producer_acquire(smem_pipe_write);
-        v_tranposer.do_transpose(sV_src, sVt_tgt, smem_pipe_read.index());
+        v_tranposer.do_transpose(sVt_src, sVt_tgt, smem_pipe_read.index());
         pipeline_vt.producer_commit(smem_pipe_write);
         pipeline_v.consumer_release(smem_pipe_read);
         ++smem_pipe_read;
@@ -271,7 +294,7 @@ struct FP8CollectiveMainloop {
         pipeline_v.consumer_wait(smem_pipe_read);
         pipeline_vt.producer_acquire(smem_pipe_write);
-        v_tranposer.do_transpose(sV_src, sVt_tgt, smem_pipe_read.index());
+        v_tranposer.do_transpose(sVt_src, sVt_tgt, smem_pipe_read.index());
         pipeline_vt.producer_commit(smem_pipe_write);
         pipeline_v.consumer_release(smem_pipe_read);
         ++smem_pipe_read;
@@ -293,7 +316,7 @@ struct FP8CollectiveMainloop {
       }
       pipeline_v.consumer_wait(smem_pipe_read);
       pipeline_vt.producer_acquire(smem_pipe_write);
-      v_tranposer.do_transpose(sV_src, sVt_tgt, smem_pipe_read.index());
+      v_tranposer.do_transpose(sVt_src, sVt_tgt, smem_pipe_read.index());
       pipeline_vt.producer_commit(smem_pipe_write);
       pipeline_v.consumer_release(smem_pipe_read);
       ++smem_pipe_read;
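
The only difference between the K tile and the transposed V tile above is which modes of TileShape_QKD are selected. A tiny sketch with illustrative (CTA_Q, CTA_KV, HEAD_DIM) values, not the kernel's actual tile configuration:

#include <cute/tensor.hpp>

int main() {
  using namespace cute;
  auto tile_shape_qkd = make_shape(Int<128>{}, Int<64>{}, Int<128>{});  // (CTA_Q, CTA_KV, HEAD_DIM), illustrative
  auto kv_tile = select<1, 2>(tile_shape_qkd);  // (CTA_KV, HEAD_DIM): K and the sparse-path V
  auto vt_tile = select<2, 1>(tile_shape_qkd);  // (HEAD_DIM, CTA_KV): the transposed V loaded by TMA
  print(kv_tile); print("\n");
  print(vt_tile); print("\n");
  return 0;
}
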
diff --git a/include/flashinfer/attention/hopper/quantization/mainloop_mma.cuh b/include/flashinfer/attention/hopper/quantization/mainloop_mma.cuh
index 9720af575e..4ddf792d97 100644
--- a/include/flashinfer/attention/hopper/quantization/mainloop_mma.cuh
+++ b/include/flashinfer/attention/hopper/quantization/mainloop_mma.cuh
@@ -14,10 +14,12 @@
 namespace flashinfer {

-template
+// SmemLayoutVt_ template parameter allows mainloop to specify the correct Vt layout
+// TMA path uses SmemLayoutVtMma (FA3-style), sparse path uses SmemLayoutVtSparse (original)
+template
 CUTLASS_DEVICE void mma_fp8(const Params& mainloop_params, AttentionVariant& variant,
                             MainloopPipeline pipeline_k, MainloopPipelineVt pipeline_vt,
                             PipelineState& smem_pipe_read_k, PipelineState& smem_pipe_read_v,
@@ -35,7 +37,7 @@ CUTLASS_DEVICE void mma_fp8(const Params& mainloop_params, AttentionVariant& var
   using SmemLayoutQ = typename Ktraits::SmemLayoutQ;
   using SmemLayoutK = typename Ktraits::SmemLayoutK;
   using SmemLayoutV = typename Ktraits::SmemLayoutV;
-  using SmemLayoutVt = typename Ktraits::SmemLayoutVt;
+  using SmemLayoutVt = SmemLayoutVt_;  // Use the layout passed from mainloop

   static_assert(is_rmem::value, "O tensor must be rmem resident.");
   static constexpr int CTA_Q = get<0>(TileShape_QKD{});
diff --git a/include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh b/include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh
index cc7f6cd0ad..2ca7f63dc6 100644
--- a/include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh
+++ b/include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh
@@ -61,8 +61,10 @@ struct FP8SparseCollectiveMainloop {
   using SmemLayoutQ = typename Ktraits::SmemLayoutQ;
   using SmemLayoutK = typename Ktraits::SmemLayoutK;
-  using SmemLayoutV = typename Ktraits::SmemLayoutV;
-  using SmemLayoutVt = typename Ktraits::SmemLayoutVt;
+  // Sparse path: use SmemLayoutVSparse which has shape (CTA_KV, HEAD_DIM, STAGES) for cp.async
+  // loading
+  using SmemLayoutV = typename Ktraits::SmemLayoutVSparse;
+  using SmemLayoutVt = typename Ktraits::SmemLayoutVtSparse;

   using ShapeT = cute::Shape;
   using StrideT = cute::Shape;  // (N, D, H)
@@ -189,14 +191,15 @@ struct FP8SparseCollectiveMainloop {
     Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape());

     // *** Prepare In-kernel V Transpose ***
-    using SmemLayoutVTransposeSrc = typename Ktraits::SmemLayoutVTransposeSrc;
-    using SmemLayoutVtTransposeTgt = typename Ktraits::SmemLayoutVtTransposeTgt;
+    // Sparse path: use SmemLayoutVSparseTransposeSrc/Tgt for original (N, D) -> (D, N) transpose
+    using SmemLayoutVSparseTransposeSrc = typename Ktraits::SmemLayoutVSparseTransposeSrc;
+    using SmemLayoutVtSparseTransposeTgt = typename Ktraits::SmemLayoutVtSparseTransposeTgt;
     Tensor sV_src = as_position_independent_swizzle_tensor(
-        make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutVTransposeSrc{}));
-    Tensor sVt_tgt = as_position_independent_swizzle_tensor(
-        make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), SmemLayoutVtTransposeTgt{}));
-    auto v_tranposer = SmemTransposeFP8_64x64();
+        make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutVSparseTransposeSrc{}));
+    Tensor sVt_tgt = as_position_independent_swizzle_tensor(make_tensor(
+        make_smem_ptr(shared_storage.smem_vt.data()), SmemLayoutVtSparseTransposeTgt{}));
+    auto v_tranposer = SmemTransposeFP8_64x64_Sparse();

     /* ----- V Transpose ---- */
     auto [q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len, batch_idx] =
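
A schematic of the layout-dispatch idea behind the new SmemLayoutVt_ template parameter, using stand-in types and hypothetical names rather than the kernel's real traits: the caller decides which smem V^T layout the MMA body is compiled against.

#include <cstdio>
#include <typeinfo>

struct LayoutVtTma {};     // stand-in for the FA3-style (TMA) V^T layout
struct LayoutVtSparse {};  // stand-in for the original sparse-path V^T layout

template <class SmemLayoutVt_>
void mma_body() {
  using SmemLayoutVt = SmemLayoutVt_;  // previously hard-wired to the kernel traits
  printf("mma compiled against %s\n", typeid(SmemLayoutVt).name());  // mangled name, illustration only
}

int main() {
  mma_body<LayoutVtTma>();     // dense/TMA mainloop passes the FA3-style layout
  mma_body<LayoutVtSparse>();  // sparse (paged-KV) mainloop passes the original layout
  return 0;
}
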
diff --git a/include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh b/include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh
index 27733f90f7..8199933fe8 100644
--- a/include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh
+++ b/include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh
@@ -233,7 +233,8 @@ __global__ void __launch_bounds__(Ktraits::NUM_WARPS* cutlass::NumThreadsPerWarp
                                    q_tile_idx, qo_len, kv_len);
         }
-        mma_fp8(
             mainloop_params, variant, pipeline_k, pipeline_vt, smem_pipe_read_k, smem_pipe_read_v,
             tOrO, attention_updater, num_kv_tiles, swa_begin_kv_tile_idx, swa_end_kv_tile_idx,
@@ -259,7 +260,8 @@ cudaError_t SingleFP8PrefillWithKVCacheKernelTraitsDispatched(Params& params, cu
   using CollectiveMainloop = FP8CollectiveMainloop;
   using CollectiveEpilogue = FP8CollectiveEpilogue;
-  using Scheduler = SingleTileScheduler;
+  // Use LPT scheduling for causal attention for better load balancing
+  using Scheduler = SingleTileScheduler;
   typename CollectiveMainloop::Params mainloop_params =
       CollectiveMainloop::to_underlying_arguments(
           {params.q_ptr, get_gmem_layout(params.qo_len, params.num_qo_heads, KernelTraits::HEAD_DIM,
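
A host-side sketch of why the LPT (reversed q-tile) order used above helps causal attention: with qo_len == kv_len assumed and CTA_Q = 128 (illustrative values), later query tiles attend to more keys, so issuing them first keeps the tail of the wave short. The effective-KV estimate below is the simple causal upper bound, not the kernel's exact accounting.

#include <algorithm>
#include <cstdio>

int main() {
  const int qo_len = 1024, kv_len = 1024, cta_q = 128;  // illustrative
  const int num_qo_tiles = (qo_len + cta_q - 1) / cta_q;
  for (int block_idx_x = 0; block_idx_x < num_qo_tiles; ++block_idx_x) {
    int lpt_tile = num_qo_tiles - 1 - block_idx_x;  // the reversal done in get_initial_work
    int effective_kv = std::min(kv_len, (lpt_tile + 1) * cta_q);  // causal: later tiles see more keys
    printf("CTA %d runs q_tile %d (~%d KV tokens)\n", block_idx_x, lpt_tile, effective_kv);
  }
  return 0;
}
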
diff --git a/include/flashinfer/attention/hopper/sm90_pipeline_no_cluster.cuh b/include/flashinfer/attention/hopper/sm90_pipeline_no_cluster.cuh
new file mode 100644
index 0000000000..8c22dab109
--- /dev/null
+++ b/include/flashinfer/attention/hopper/sm90_pipeline_no_cluster.cuh
@@ -0,0 +1,107 @@
+/*
+ * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
+ * Dao. Licensed under the BSD 3-Clause.
+ *
+ * Modified by the FlashInfer team.
+ */
+#ifndef FLASHINFER_ATTENTION_HOPPER_SM90_PIPELINE_NO_CLUSTER_CUH_
+#define FLASHINFER_ATTENTION_HOPPER_SM90_PIPELINE_NO_CLUSTER_CUH_
+
+#include
+
+namespace flashinfer {
+
+using namespace cute;
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// As of Cutlass v3.6.0, if size(ClusterShape) == 1, PipelineTmaAsync has all threads
+// signaling the barrier during consumer_release. This causes a perf regression in FA3
+// forward pass (especially hdim 128 causal). We instead reimplement the version of
+// PipelineTmaAsync before v3.6.0 where only 1 out of 128 threads signals the barrier.
+//
+// Assumption: params.num_consumers % NumThreadsPerWarpGroup == 0
+template >
+class PipelineTmaAsyncNoCluster : public Base {
+ public:
+  using FullBarrier = typename Base::FullBarrier;
+  using EmptyBarrier = typename Base::EmptyBarrier;
+  static constexpr uint32_t Stages = Stages_;
+  using PipelineState = typename Base::PipelineState;
+
+  using SharedStorage = typename Base::SharedStorage;
+  using ThreadCategory = typename Base::ThreadCategory;
+  using Params = typename Base::Params;
+
+  static CUTLASS_DEVICE void init_barriers(SharedStorage& storage, Params params) {
+    int warp_idx = cutlass::canonical_warp_idx_sync();
+    bool is_initializing_warp = (warp_idx == 0);
+    if (is_initializing_warp) {
+      // Barrier FULL and EMPTY init
+      constexpr int producer_arv_cnt = 1;
+      uint32_t const num_consumer_warpgroups_per_cluster =
+          (params.num_consumers + cutlass::NumThreadsPerWarpGroup - 1) /
+          cutlass::NumThreadsPerWarpGroup;
+      uint32_t const multicast_consumer_arrival_count = num_consumer_warpgroups_per_cluster;
+
+      cutlass::arch::detail::initialize_barrier_array_pair_aligned<
+          decltype(storage.full_barrier_), decltype(storage.empty_barrier_), Stages>(
+          storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt,
+          multicast_consumer_arrival_count);
+    }
+    cutlass::arch::fence_barrier_init();
+  }
+
+  template
+  CUTLASS_DEVICE PipelineTmaAsyncNoCluster(SharedStorage& storage, Params params,
+                                           ClusterShape cluster_shape, InitBarriers = {},
+                                           InitMasks = {})
+      : Base(storage, params, make_shape(_1{}, _1{}, _1{}) /*cluster_shape*/,
+             cute::false_type{} /*init_barriers*/, cute::false_type{} /*init_masks*/),
+        empty_barrier_ptr_(&storage.empty_barrier_[0]) {
+    int warp_idx = cutlass::canonical_warp_idx_sync();
+    int lane_predicate = cute::elect_one_sync();
+
+    static_assert(cute::is_same_v ||
+                  cute::is_same_v);
+    static_assert(cute::is_same_v ||
+                  cute::is_same_v);
+    if constexpr (cute::is_same_v) {
+      init_barriers(storage, params);
+    }
+  }
+
+  // Constructor
+  template
+  CUTLASS_DEVICE PipelineTmaAsyncNoCluster(SharedStorage& storage, Params params,
+                                           ClusterShape cluster_shape)
+      : PipelineTmaAsyncNoCluster(storage, params, cluster_shape, cute::true_type{},
+                                  cute::true_type{}) {}
+
+  template
+  CUTLASS_DEVICE PipelineTmaAsyncNoCluster(SharedStorage& storage, Params params,
+                                           ClusterShape cluster_shape, InitBarriers = {})
+      : PipelineTmaAsyncNoCluster(storage, params, cluster_shape, InitBarriers{},
+                                  cute::true_type{}) {}
+
+  CUTLASS_DEVICE
+  void consumer_release(PipelineState state) { consumer_release(state.index()); }
+
+ private:
+  EmptyBarrier* const empty_barrier_ptr_ = nullptr;
+
+  // Consumer signalling Producer of completion
+  // Ensures all blocks in the Same Row and Column get notified.
+  CUTLASS_DEVICE
+  void consumer_release(uint32_t stage, uint32_t skip = false) {
+    empty_barrier_ptr_[stage].arrive(0 /*dst_blockid_*/,
+                                     uint32_t(threadIdx.x % cutlass::NumThreadsPerWarpGroup == 0) &
+                                         (!skip) /*is_signaling_thread*/);
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+}  // namespace flashinfer
+
+#endif  // FLASHINFER_ATTENTION_HOPPER_SM90_PIPELINE_NO_CLUSTER_CUH_
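
A host-side sketch of the arrival-count arithmetic used by init_barriers() and consumer_release() above, assuming cutlass::NumThreadsPerWarpGroup == 128 and an illustrative consumer count: the full barrier expects one producer arrival (the elected TMA thread), the empty barrier expects one arrival per 128-thread consumer warpgroup, and only one thread per warpgroup signals on release.

#include <cstdio>

int main() {
  const int kThreadsPerWarpGroup = 128;  // assumed value of cutlass::NumThreadsPerWarpGroup
  const int num_consumers = 256;         // illustrative: two consumer warpgroups
  const int producer_arv_cnt = 1;
  const int consumer_arv_cnt =
      (num_consumers + kThreadsPerWarpGroup - 1) / kThreadsPerWarpGroup;  // one per warpgroup
  printf("full barrier expects %d arrival(s), empty barrier expects %d arrival(s)\n",
         producer_arv_cnt, consumer_arv_cnt);
  // Only one thread per warpgroup signals the empty barrier on consumer_release:
  for (int tid = 0; tid < num_consumers; ++tid) {
    bool is_signaling_thread = (tid % kThreadsPerWarpGroup == 0);
    if (is_signaling_thread) printf("thread %d signals consumer_release\n", tid);
  }
  return 0;
}
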
diff --git a/include/flashinfer/attention/hopper/tile_scheduler.cuh b/include/flashinfer/attention/hopper/tile_scheduler.cuh
index 51a322346f..3f718a39d9 100644
--- a/include/flashinfer/attention/hopper/tile_scheduler.cuh
+++ b/include/flashinfer/attention/hopper/tile_scheduler.cuh
@@ -13,6 +13,10 @@
 namespace flashinfer {

+// LPT: Longest-Processing-Time-First scheduling for causal attention
+// When LPT=true, block indices are reversed so that tiles with more KV tokens
+// (higher indices in causal case) are processed first for better load balancing
+template
 struct SingleTileScheduler {
  public:
   // Host side kernel arguments
   struct Arguments {
@@ -23,12 +27,13 @@ struct SingleTileScheduler {
   // Device side kernel params
   struct Params {
+    int const num_qo_tiles;  // needed for LPT reversal
     int const qo_len, kv_len;
     cutlass::FastDivmod group_size_fastdiv;
   };

   static Params to_underlying_arguments(Arguments const& args) {
-    return {args.qo_len, args.kv_len, args.group_size_fastdiv};
+    return {args.num_qo_tiles, args.qo_len, args.kv_len, args.group_size_fastdiv};
   }

   static dim3 get_grid_dim(Arguments const& args, int num_sm) {
@@ -58,7 +63,12 @@ struct SingleTileScheduler {
   WorkTileInfo get_initial_work(Params const& params) const {
     int qo_head_idx = blockIdx.y;
     int kv_head_idx = params.group_size_fastdiv.divide(qo_head_idx);
-    return {/*q_tile_idx=*/int(blockIdx.x), qo_head_idx, kv_head_idx, /*is_valid_tile*/ true};
+    int q_tile_idx = int(blockIdx.x);
+    // LPT: reverse block index for better load balancing in causal attention
+    if constexpr (LPT) {
+      q_tile_idx = params.num_qo_tiles - 1 - q_tile_idx;
+    }
+    return {q_tile_idx, qo_head_idx, kv_head_idx, /*is_valid_tile*/ true};
   }

   CUTLASS_DEVICE
diff --git a/include/flashinfer/attention/hopper/utils.cuh b/include/flashinfer/attention/hopper/utils.cuh
index 8aeb5b1944..e1b85a9260 100644
--- a/include/flashinfer/attention/hopper/utils.cuh
+++ b/include/flashinfer/attention/hopper/utils.cuh
@@ -141,6 +141,30 @@ CUTLASS_DEVICE void permute_regs_A_to_C(Fragment& accum) {
   }
 }

+// Permute output registers for FP8 kernel before writing to smem
+// This undoes the register permutation from the FP8 MMA
+// out has shape ((2, 2, N / 8), MMA_M, MMA_N), each element is 32 bits (float)
+template
+CUTLASS_DEVICE void permute_output_fp8(Fragment& out) {
+  static_assert(decltype(size<0, 0>(out))::value == 2);
+  static_assert(decltype(size<0, 1>(out))::value == 2);
+  static_assert(decltype(size<0, 2>(out))::value % 2 == 0);
+  static_assert(decltype(stride<0, 0>(out))::value == 1);
+  static_assert(sizeof(typename Fragment::value_type) == 4);
+  Tensor frag = group_modes<1, 3>(out);  // ((2, 2, N / 8), (MMA_M, MMA_N))
+#pragma unroll
+  for (int mi = 0; mi < size<1>(frag); ++mi) {
+#pragma unroll
+    for (int j = 0; j < size<0, 1>(frag); ++j) {
+#pragma unroll
+      for (int i = 0; i < size<0, 2>(frag) / 2; ++i) {
+        cutlass::swap(frag(make_coord(_1{}, j, 2 * i), mi),
+                      frag(make_coord(_0{}, j, 2 * i + 1), mi));
+      }
+    }
+  }
+}
+
 template
 __forceinline__ __device__ auto convert_type(Tensor const& tensor) {
   using From_type = typename Engine::value_type;
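
A plain-array sketch of the swap pattern inside permute_output_fp8() above, for one MMA tile of a stand-in fragment with logical shape (2, 2, N/8) and N/8 = 4 (illustrative values, not a CuTe tensor): elements at (1, j, 2i) are exchanged with elements at (0, j, 2i+1).

#include <cstdio>
#include <utility>

int main() {
  float frag[2][2][4];
  for (int a = 0; a < 2; ++a)
    for (int j = 0; j < 2; ++j)
      for (int k = 0; k < 4; ++k) frag[a][j][k] = 100.0f * a + 10.0f * j + k;

  for (int j = 0; j < 2; ++j)
    for (int i = 0; i < 4 / 2; ++i)
      std::swap(frag[1][j][2 * i], frag[0][j][2 * i + 1]);  // (1, j, 2i) <-> (0, j, 2i+1)

  for (int a = 0; a < 2; ++a)
    for (int j = 0; j < 2; ++j)
      for (int k = 0; k < 4; ++k) printf("frag[%d][%d][%d] = %g\n", a, j, k, frag[a][j][k]);
  return 0;
}
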
diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh
index 286023e204..de1b53cf77 100644
--- a/include/flashinfer/attention/scheduler.cuh
+++ b/include/flashinfer/attention/scheduler.cuh
@@ -796,7 +796,7 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i
   return cudaSuccess;
 }

-inline float cost_function(int qo_len, int kv_len) { return 2 * float(qo_len) + kv_len; }
+inline float cost_function(int qo_len, int kv_len) { return 0.05 * float(qo_len) + kv_len; }

 template
 std::vector flatten(const std::vector>& vec, int size_after_flatten) {
@@ -902,8 +902,13 @@ inline cudaError_t PrefillSM90Plan(
   std::sort(idx_qo_kv_len_vec.begin(), idx_qo_kv_len_vec.end(),
             [](const auto& a, const auto& b) { return std::get<2>(a) > std::get<2>(b); });

   int cta_tile_q = 128;
+  int cta_tile_kv = 128;
   if (head_dim_vo == 64) {
     cta_tile_q = 192;
-  }
+  } else if (head_dim_qk == 128 && head_dim_vo == 128 && !causal) {
+    cta_tile_kv = 192;
+  } else if (head_dim_qk > 128) {
+    cta_tile_kv = 64;
+  }

   int device = 0;
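
The planner spreads (head, q-tile) work items over CTAs by always handing the next item to the least-loaded CTA, weighting each item with cost_function() above; the loop further below in this file does this with a cost heap. A host-side sketch of that greedy assignment using std::priority_queue, an illustrative work list, and hypothetical variable names:

#include <cstdio>
#include <functional>
#include <queue>
#include <utility>
#include <vector>

int main() {
  const int num_ctas = 4;  // illustrative CTA count
  // (qo_tile_len, effective_kv_len) work items, e.g. causal tiles of one long request.
  std::vector<std::pair<int, int>> work = {{128, 8192}, {128, 6144}, {128, 4096}, {128, 2048},
                                           {128, 8192}, {128, 6144}, {128, 4096}, {128, 2048}};
  auto cost = [](int qo, int kv) { return 0.05f * float(qo) + float(kv); };  // mirrors cost_function

  using Entry = std::pair<float, int>;  // (accumulated cost, cta_idx)
  std::priority_queue<Entry, std::vector<Entry>, std::greater<Entry>> heap;  // min-heap by load
  for (int c = 0; c < num_ctas; ++c) heap.push({0.0f, c});

  for (auto [qo, kv] : work) {
    auto [accum, cta] = heap.top();  // least-loaded CTA so far
    heap.pop();
    printf("assign tile (qo=%d, kv=%d) to CTA %d (load %.1f)\n", qo, kv, cta, accum);
    heap.push({accum + cost(qo, kv), cta});
  }
  return 0;
}
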
@@ -924,25 +929,69 @@ inline cudaError_t PrefillSM90Plan(
   int max_num_works_per_head = ceil_div(total_num_rows, cta_tile_q) + batch_size - 1;
   plan_info.same_schedule_for_all_heads = max_num_works_per_head > 4096;

-  for (int qo_head_idx = 0;
-       qo_head_idx < (plan_info.same_schedule_for_all_heads ? 1 : num_qo_heads); ++qo_head_idx) {
+  // L2-aware scheduling: compute swizzle size based on L2 cache capacity
+  // Group adjacent heads together so K/V can be reused in L2 cache
+  // Use conservative L2 size estimate (8MB like FA3, not full 50MB)
+  constexpr int64_t size_l2 = 8 * 1024 * 1024;  // 8 MB (conservative, FA3 uses this)
+
+  // Compute max KV blocks across all batches
+  int64_t max_kv_blocks = 1;
+  for (uint32_t i = 0; i < batch_size; ++i) {
+    int64_t kv_blocks = ceil_div(int64_t(kv_len_arr_h[i]), int64_t(cta_tile_kv));
+    max_kv_blocks = std::max(max_kv_blocks, kv_blocks);
+  }
+
+  // Size of one KV block: cta_tile_kv * (head_dim_qk + head_dim_vo) * sizeof(half)
+  int64_t size_one_kv_block = cta_tile_kv * (head_dim_qk + head_dim_vo) * 2;  // 2 bytes for FP16
+  int64_t max_kv_blocks_in_l2 = size_l2 / size_one_kv_block;
+
+  // FA3-style: use stepped values (16, 8, 4, 2, 1) based on how many KV heads fit
+  int nheads_in_l2 = max_kv_blocks * 16 <= max_kv_blocks_in_l2   ? 16
+                     : max_kv_blocks * 8 <= max_kv_blocks_in_l2  ? 8
+                     : max_kv_blocks * 4 <= max_kv_blocks_in_l2  ? 4
+                     : max_kv_blocks * 2 <= max_kv_blocks_in_l2  ? 2
+                                                                 : 1;
+
+  // Scale by GQA group size (num_qo_heads / num_kv_heads)
+  int group_size = num_qo_heads / num_kv_heads;
+  int swizzle = nheads_in_l2 * group_size;
+  // Clamp swizzle to valid range
+  swizzle = std::max(1, std::min(swizzle, int(num_qo_heads)));
+
+  // Schedule tiles in L2-aware order: process heads in groups of 'swizzle'
+  // Within each section, iterate over q_tiles first (LPT order), then heads
+  // This matches FA3's traversal order for better L2 cache utilization
+  int num_sections = ceil_div(int(num_qo_heads), swizzle);
+
+  for (int section = 0; section < (plan_info.same_schedule_for_all_heads ? 1 : num_sections);
+       ++section) {
+    int head_start = section * swizzle;
+    int head_end = plan_info.same_schedule_for_all_heads
+                       ? 1
+                       : std::min(head_start + swizzle, int(num_qo_heads));
+    int nheads_in_section = head_end - head_start;
+
+    // Within each section, iterate over q_tiles first (LPT: from last to first), then heads
+    // This allows adjacent heads to reuse K/V in L2 cache
     for (auto& [i, qo_len, kv_len] : idx_qo_kv_len_vec) {
       int num_qo_tiles = ceil_div(qo_len, cta_tile_q);
       for (int qo_tile_idx = num_qo_tiles - 1; qo_tile_idx >= 0; --qo_tile_idx) {
-        auto [cta_idx, accum_cost] = cta_cost_heap.pop();
-        // NOTE(Zihao): our current FA3 implementation do not fuse query and group heads
-        // so the group_size in cost_function is always 1
-        int effective_kv_len =
-            causal ? packed_causal_kv_end(qo_len, kv_len, qo_tile_idx, cta_tile_q, num_qo_tiles, 1)
-                   : kv_len;
-        cta_cost_heap.insert({cta_idx, accum_cost + cost_function(cta_tile_q, effective_kv_len)});
-        cta_qo_tile_indices[cta_idx].push_back(qo_tile_idx);
-        cta_qo_indptr[cta_idx].push_back(qo_indptr_h[i]);
-        cta_qo_len[cta_idx].push_back(qo_len);
-        cta_kv_indptr[cta_idx].push_back(kv_indptr_h[i]);
-        cta_kv_len[cta_idx].push_back(kv_len);
-        cta_head_indices[cta_idx].push_back(qo_head_idx);
-        cta_batch_indices[cta_idx].push_back(i);
+        for (int qo_head_idx = head_start; qo_head_idx < head_end; ++qo_head_idx) {
+          auto [cta_idx, accum_cost] = cta_cost_heap.pop();
+          // NOTE(Zihao): our current FA3 implementation do not fuse query and group heads
+          // so the group_size in cost_function is always 1
+          int effective_kv_len = causal ? packed_causal_kv_end(qo_len, kv_len, qo_tile_idx,
+                                                               cta_tile_q, num_qo_tiles, 1)
+                                        : kv_len;
+          cta_cost_heap.insert({cta_idx, accum_cost + cost_function(cta_tile_q, effective_kv_len)});
+          cta_qo_tile_indices[cta_idx].push_back(qo_tile_idx);
+          cta_qo_indptr[cta_idx].push_back(qo_indptr_h[i]);
+          cta_qo_len[cta_idx].push_back(qo_len);
+          cta_kv_indptr[cta_idx].push_back(kv_indptr_h[i]);
+          cta_kv_len[cta_idx].push_back(kv_len);
+          cta_head_indices[cta_idx].push_back(qo_head_idx);
+          cta_batch_indices[cta_idx].push_back(i);
+        }
       }
     }
   }
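
A host-side sketch of the L2-aware swizzle computation above for one illustrative configuration (FP16 KV, head_dim_qk = head_dim_vo = 128, cta_tile_kv = 128, longest kv_len = 8192, 32 query heads over 8 KV heads). The numbers are examples, not a fixed policy; the stepped 16/8/4/2/1 choice mirrors the planner code.

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t size_l2 = 8 * 1024 * 1024;  // conservative 8 MB, as in the planner
  const int cta_tile_kv = 128, head_dim_qk = 128, head_dim_vo = 128;
  const int num_qo_heads = 32, num_kv_heads = 8;
  const int64_t max_kv_len = 8192;  // illustrative longest request

  int64_t max_kv_blocks = std::max<int64_t>(1, (max_kv_len + cta_tile_kv - 1) / cta_tile_kv);
  int64_t size_one_kv_block = int64_t(cta_tile_kv) * (head_dim_qk + head_dim_vo) * 2;  // FP16 K + V
  int64_t max_kv_blocks_in_l2 = size_l2 / size_one_kv_block;

  int nheads_in_l2 = max_kv_blocks * 16 <= max_kv_blocks_in_l2   ? 16
                     : max_kv_blocks * 8 <= max_kv_blocks_in_l2  ? 8
                     : max_kv_blocks * 4 <= max_kv_blocks_in_l2  ? 4
                     : max_kv_blocks * 2 <= max_kv_blocks_in_l2  ? 2
                                                                 : 1;
  int swizzle = std::max(1, std::min(nheads_in_l2 * (num_qo_heads / num_kv_heads), num_qo_heads));
  printf("kv blocks that fit in L2: %lld, nheads_in_l2: %d, swizzle: %d\n",
         (long long)max_kv_blocks_in_l2, nheads_in_l2, swizzle);
  return 0;
}
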