From 7e24a046ce93ce8266c005d2ee3315de7d532272 Mon Sep 17 00:00:00 2001 From: Pavani Majety Date: Tue, 4 Nov 2025 11:36:21 -0800 Subject: [PATCH 1/4] flashinfer-fp8-cutlass-fmha --- benchmarks/bench_blackwell_attention.py | 51 +++++++++++++++---- csrc/fmha_cutlass_sm100.cu | 12 ++++- csrc/fmha_cutlass_sm100_binding.cu | 3 +- flashinfer/prefill.py | 10 ++-- ..._fmha_fwd_mainloop_tma_warpspecialized.hpp | 7 +-- ...00_fmha_fwd_kernel_tma_warpspecialized.hpp | 4 +- 6 files changed, 63 insertions(+), 24 deletions(-) diff --git a/benchmarks/bench_blackwell_attention.py b/benchmarks/bench_blackwell_attention.py index 73b0cd0b3c..ff0e549b87 100644 --- a/benchmarks/bench_blackwell_attention.py +++ b/benchmarks/bench_blackwell_attention.py @@ -32,15 +32,27 @@ def bench_fmha_blackwell( causal, dtype, ): - q = torch.randn( - batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda" - ) - k = torch.randn( - batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda" - ) - v = torch.randn( - batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda" - ) + # if sizeof(dtype) == 1, create randn from half and then convert to dtype + if dtype.itemsize == 1: + q = torch.randn( + batch_size * qkv_len, num_heads, head_dim, dtype=torch.half, device="cuda" + ).to(dtype) + k = torch.randn( + batch_size * qkv_len, num_heads, head_dim, dtype=torch.half, device="cuda" + ).to(dtype) + v = torch.randn( + batch_size * qkv_len, num_heads, head_dim, dtype=torch.half, device="cuda" + ).to(dtype) + else: + q = torch.randn( + batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda" + ) + k = torch.randn( + batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda" + ) + v = torch.randn( + batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda" + ) qo_segment_offsets = ( torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * qkv_len @@ -87,7 +99,7 @@ def bench_fmha_blackwell( if __name__ == "__main__": - print("\n === head_dim=128 ===") + print("\n === head_dim=128, bfloat16 ===") bench_fmha_blackwell(128, 512, 32, 128, False, torch.bfloat16) bench_fmha_blackwell(64, 1024, 32, 128, False, torch.bfloat16) bench_fmha_blackwell(32, 2048, 32, 128, False, torch.bfloat16) @@ -106,6 +118,25 @@ def bench_fmha_blackwell( bench_fmha_blackwell(2, 32768, 32, 128, True, torch.bfloat16) bench_fmha_blackwell(1, 65536, 32, 128, True, torch.bfloat16) + print("\n === head_dim=128, float8_e4m3fn ===") + bench_fmha_blackwell(128, 512, 32, 128, False, torch.float8_e4m3fn) + bench_fmha_blackwell(64, 1024, 32, 128, False, torch.float8_e4m3fn) + bench_fmha_blackwell(32, 2048, 32, 128, False, torch.float8_e4m3fn) + bench_fmha_blackwell(16, 4096, 32, 128, False, torch.float8_e4m3fn) + bench_fmha_blackwell(8, 8192, 32, 128, False, torch.float8_e4m3fn) + bench_fmha_blackwell(4, 16384, 32, 128, False, torch.float8_e4m3fn) + bench_fmha_blackwell(2, 32768, 32, 128, False, torch.float8_e4m3fn) + bench_fmha_blackwell(1, 65536, 32, 128, False, torch.float8_e4m3fn) + + bench_fmha_blackwell(128, 512, 32, 128, True, torch.float8_e4m3fn) + bench_fmha_blackwell(64, 1024, 32, 128, True, torch.float8_e4m3fn) + bench_fmha_blackwell(32, 2048, 32, 128, True, torch.float8_e4m3fn) + bench_fmha_blackwell(16, 4096, 32, 128, True, torch.float8_e4m3fn) + bench_fmha_blackwell(8, 8192, 32, 128, True, torch.float8_e4m3fn) + bench_fmha_blackwell(4, 16384, 32, 128, True, torch.float8_e4m3fn) + bench_fmha_blackwell(2, 32768, 32, 128, True, torch.float8_e4m3fn) + bench_fmha_blackwell(1, 65536, 32, 
128, True, torch.float8_e4m3fn) + print("\n === head_dim=64 ===") bench_fmha_blackwell(128, 512, 32, 64, False, torch.bfloat16) bench_fmha_blackwell(64, 1024, 32, 64, False, torch.bfloat16) diff --git a/csrc/fmha_cutlass_sm100.cu b/csrc/fmha_cutlass_sm100.cu index 08f1235adf..47ca0b4057 100644 --- a/csrc/fmha_cutlass_sm100.cu +++ b/csrc/fmha_cutlass_sm100.cu @@ -58,6 +58,11 @@ using tvm::ffi::Optional; using c_type_out = c_type_in; \ return __VA_ARGS__(); \ }); \ + } else { \ + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(in_dtype, c_type_in, [&] { \ + using c_type_out = nv_bfloat16; \ + return __VA_ARGS__(); \ + }); \ } \ return false; \ }() @@ -80,14 +85,17 @@ void FMHACutlassSM100Run(ffi::TensorView workspace_buffer, ffi::TensorView q, ff ffi::TensorView qo_tile_indices, ffi::TensorView qo_head_indices, ffi::TensorView batch_indices, ffi::TensorView o, Optional maybe_lse, int64_t mask_mode_code, - double sm_scale, int64_t num_qo_heads, int64_t num_kv_heads, - int64_t head_dim_qk, int64_t head_dim_vo, int64_t max_qo_len) { + double sm_scale, int64_t max_qo_len) { TVM_FFI_ICHECK_EQ(q.dtype(), k.dtype()); auto scalar_type_in = q.dtype(); auto scalar_type_out = o.dtype(); MaskMode mask_mode = static_cast(mask_mode_code); int total_qo_len = q.size(0); int total_kv_len = k.size(0); + int num_qo_heads = q.size(1); + int num_kv_heads = k.size(1); + int head_dim_qk = q.size(2); + int head_dim_vo = v.size(2); int batch_size = qo_segment_offsets.size(0) - 1; int q_stride_n = q.stride(0); int q_stride_h = q.stride(1); diff --git a/csrc/fmha_cutlass_sm100_binding.cu b/csrc/fmha_cutlass_sm100_binding.cu index ddb3b8d9cd..69fa341b72 100644 --- a/csrc/fmha_cutlass_sm100_binding.cu +++ b/csrc/fmha_cutlass_sm100_binding.cu @@ -22,8 +22,7 @@ void FMHACutlassSM100Run(TensorView workspace_buffer, TensorView q, TensorView k TensorView work_indptr, TensorView qo_tile_indices, TensorView qo_head_indices, TensorView batch_indices, TensorView o, Optional maybe_lse, int64_t mask_mode_code, double sm_scale, - int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim_qk, - int64_t head_dim_vo, int64_t max_qo_len); + int64_t max_qo_len); void blackwell_fmha_plan(TensorView qo_segment_offsets, TensorView kv_segment_offsets, TensorView work_indptr, TensorView qo_tile_indices, diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 41fac0e4e9..8d631ab80c 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -2973,6 +2973,8 @@ def run( lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse" ) if out is None: + # when input dtype is fp8, we need to use bf16 output + out_dtype = torch.bfloat16 if q.dtype.itemsize == 1 else q.dtype out = torch.empty( q.shape[:-1] + v.shape[-1:], dtype=self._cached_o_data_type, @@ -3223,12 +3225,14 @@ def fmha_varlen( ) = plan_info if out is None: + # when input dtype is fp8, we need to use bf16 output + out_dtype = torch.bfloat16 if q.dtype.itemsize == 1 else q.dtype out = torch.empty( qo_total_len + max(max_qo_len, 128), num_qo_heads, head_dim_vo, device=q.device, - dtype=q.dtype, + dtype=out_dtype, )[max(max_qo_len, 128) :] if lse is None and return_lse: @@ -3251,10 +3255,6 @@ def fmha_varlen( lse, mask_mode_code, sm_scale, - num_qo_heads, - num_kv_heads, - head_dim_qk, - head_dim_vo, max_qo_len, ) diff --git a/include/flashinfer/attention/blackwell/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp b/include/flashinfer/attention/blackwell/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp index 640b876b49..a973117548 100644 --- 
a/include/flashinfer/attention/blackwell/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp
+++ b/include/flashinfer/attention/blackwell/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp
@@ -64,9 +64,10 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
   using Mask = Mask_;
 
   static constexpr int StageCountQ = 2;
-  static constexpr int StageCountKV = (get<2>(TileShapeQK{}) == 128 || get<2>(TileShapeQK{}) == 64)
-                                          ? 2
-                                          : 1;  // sizeof(Element_) == 1 ? 2 : 2;
+  static constexpr int StageCountKV =
+      (sizeof(Element_) == 1)
+          ? (get<2>(TileShapeQK{}) == 128 ? 4 : 2)
+          : (get<2>(TileShapeQK{}) == 128 || get<2>(TileShapeQK{}) == 64 ? 2 : 1);
 
   using StagesQ = cutlass::gemm::collective::StageCount<StageCountQ>;
   using StagesKV = cutlass::gemm::collective::StageCount<StageCountKV>;
diff --git a/include/flashinfer/attention/blackwell/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp b/include/flashinfer/attention/blackwell/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp
index d6e913a319..ebf38d7347 100644
--- a/include/flashinfer/attention/blackwell/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp
+++ b/include/flashinfer/attention/blackwell/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp
@@ -66,8 +66,8 @@ struct Sm100FmhaCtxKernelWarpspecializedSchedule {
 
   static const bool kDebugUsingPrintf = false;
   static const int NumRegsSoftmax = 192;
-  static const int NumRegsCorrection = 96 - (kDebugUsingPrintf ? 16 : 0);
-  static const int NumRegsOther = 32 + (kDebugUsingPrintf ? 16 : 0);
+  static const int NumRegsCorrection = 64;  // 96 - (kDebugUsingPrintf ? 16 : 0);
+  static const int NumRegsOther = 64;       // 32 + (kDebugUsingPrintf ? 16 : 0);
   static const int NumRegsEmpty = 24;
 
   static const int NumWarps = 16;

From 076732f08f4ffcc7d090af1c4159efa951665ce5 Mon Sep 17 00:00:00 2001
From: Pavani Majety
Date: Wed, 5 Nov 2025 17:03:50 -0800
Subject: [PATCH 2/4] Update FP8 benchmark and dispatch, add FP8 FMHA test

---
 benchmarks/bench_blackwell_attention.py | 141 ++++++++++++------------
 csrc/fmha_cutlass_sm100.cu              |   2 +-
 tests/attention/test_blackwell_fmha.py  | 125 +++++++++++++++++++++
 3 files changed, 197 insertions(+), 71 deletions(-)

diff --git a/benchmarks/bench_blackwell_attention.py b/benchmarks/bench_blackwell_attention.py
index ff0e549b87..3b48370394 100644
--- a/benchmarks/bench_blackwell_attention.py
+++ b/benchmarks/bench_blackwell_attention.py
@@ -14,6 +14,7 @@
  limitations under the License.
""" +import csv import numpy as np import torch @@ -27,31 +28,33 @@ def bench_fmha_blackwell( batch_size, qkv_len, - num_heads, - head_dim, + num_qo_heads, + num_kv_heads, + head_dim_qk, + head_dim_vo, causal, dtype, ): # if sizeof(dtype) == 1, create randn from half and then convert to dtype if dtype.itemsize == 1: q = torch.randn( - batch_size * qkv_len, num_heads, head_dim, dtype=torch.half, device="cuda" + batch_size * qkv_len, num_qo_heads, head_dim_qk, dtype=torch.half, device="cuda" ).to(dtype) k = torch.randn( - batch_size * qkv_len, num_heads, head_dim, dtype=torch.half, device="cuda" + batch_size * qkv_len, num_kv_heads, head_dim_qk, dtype=torch.half, device="cuda" ).to(dtype) v = torch.randn( - batch_size * qkv_len, num_heads, head_dim, dtype=torch.half, device="cuda" + batch_size * qkv_len, num_kv_heads, head_dim_vo, dtype=torch.half, device="cuda" ).to(dtype) else: q = torch.randn( - batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda" + batch_size * qkv_len, num_qo_heads, head_dim_qk, dtype=dtype, device="cuda" ) k = torch.randn( - batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda" + batch_size * qkv_len, num_kv_heads, head_dim_qk, dtype=dtype, device="cuda" ) v = torch.randn( - batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda" + batch_size * qkv_len, num_kv_heads, head_dim_vo, dtype=dtype, device="cuda" ) qo_segment_offsets = ( @@ -68,10 +71,10 @@ def bench_fmha_blackwell( wrapper.plan( qo_segment_offsets, kv_segment_offsets, - num_heads, - num_heads, - head_dim, - head_dim_vo=head_dim, + num_qo_heads, + num_kv_heads, + head_dim_qk, + head_dim_vo=head_dim_vo, causal=causal, q_data_type=dtype, kv_data_type=dtype, @@ -95,63 +98,61 @@ def bench_fmha_blackwell( ) print( f"bench_fmha_blackwell (batch_size={batch_size}, qkv_len={qkv_len}, num_heads={num_heads}, head_dim={head_dim}, causal={causal}), flops: {TFLOPS:.3f} TFLOPs/s" - ) - if __name__ == "__main__": - print("\n === head_dim=128, bfloat16 ===") - bench_fmha_blackwell(128, 512, 32, 128, False, torch.bfloat16) - bench_fmha_blackwell(64, 1024, 32, 128, False, torch.bfloat16) - bench_fmha_blackwell(32, 2048, 32, 128, False, torch.bfloat16) - bench_fmha_blackwell(16, 4096, 32, 128, False, torch.bfloat16) - bench_fmha_blackwell(8, 8192, 32, 128, False, torch.bfloat16) - bench_fmha_blackwell(4, 16384, 32, 128, False, torch.bfloat16) - bench_fmha_blackwell(2, 32768, 32, 128, False, torch.bfloat16) - bench_fmha_blackwell(1, 65536, 32, 128, False, torch.bfloat16) - - bench_fmha_blackwell(128, 512, 32, 128, True, torch.bfloat16) - bench_fmha_blackwell(64, 1024, 32, 128, True, torch.bfloat16) - bench_fmha_blackwell(32, 2048, 32, 128, True, torch.bfloat16) - bench_fmha_blackwell(16, 4096, 32, 128, True, torch.bfloat16) - bench_fmha_blackwell(8, 8192, 32, 128, True, torch.bfloat16) - bench_fmha_blackwell(4, 16384, 32, 128, True, torch.bfloat16) - bench_fmha_blackwell(2, 32768, 32, 128, True, torch.bfloat16) - bench_fmha_blackwell(1, 65536, 32, 128, True, torch.bfloat16) - - print("\n === head_dim=128, float8_e4m3fn ===") - bench_fmha_blackwell(128, 512, 32, 128, False, torch.float8_e4m3fn) - bench_fmha_blackwell(64, 1024, 32, 128, False, torch.float8_e4m3fn) - bench_fmha_blackwell(32, 2048, 32, 128, False, torch.float8_e4m3fn) - bench_fmha_blackwell(16, 4096, 32, 128, False, torch.float8_e4m3fn) - bench_fmha_blackwell(8, 8192, 32, 128, False, torch.float8_e4m3fn) - bench_fmha_blackwell(4, 16384, 32, 128, False, torch.float8_e4m3fn) - bench_fmha_blackwell(2, 32768, 32, 128, 
False, torch.float8_e4m3fn) - bench_fmha_blackwell(1, 65536, 32, 128, False, torch.float8_e4m3fn) - - bench_fmha_blackwell(128, 512, 32, 128, True, torch.float8_e4m3fn) - bench_fmha_blackwell(64, 1024, 32, 128, True, torch.float8_e4m3fn) - bench_fmha_blackwell(32, 2048, 32, 128, True, torch.float8_e4m3fn) - bench_fmha_blackwell(16, 4096, 32, 128, True, torch.float8_e4m3fn) - bench_fmha_blackwell(8, 8192, 32, 128, True, torch.float8_e4m3fn) - bench_fmha_blackwell(4, 16384, 32, 128, True, torch.float8_e4m3fn) - bench_fmha_blackwell(2, 32768, 32, 128, True, torch.float8_e4m3fn) - bench_fmha_blackwell(1, 65536, 32, 128, True, torch.float8_e4m3fn) - - print("\n === head_dim=64 ===") - bench_fmha_blackwell(128, 512, 32, 64, False, torch.bfloat16) - bench_fmha_blackwell(64, 1024, 32, 64, False, torch.bfloat16) - bench_fmha_blackwell(32, 2048, 32, 64, False, torch.bfloat16) - bench_fmha_blackwell(16, 4096, 32, 64, False, torch.bfloat16) - bench_fmha_blackwell(8, 8192, 32, 64, False, torch.bfloat16) - bench_fmha_blackwell(4, 16384, 32, 64, False, torch.bfloat16) - bench_fmha_blackwell(2, 32768, 32, 64, False, torch.bfloat16) - bench_fmha_blackwell(1, 65536, 32, 64, False, torch.bfloat16) - - bench_fmha_blackwell(128, 512, 32, 64, True, torch.bfloat16) - bench_fmha_blackwell(64, 1024, 32, 64, True, torch.bfloat16) - bench_fmha_blackwell(32, 2048, 32, 64, True, torch.bfloat16) - bench_fmha_blackwell(16, 4096, 32, 64, True, torch.bfloat16) - bench_fmha_blackwell(8, 8192, 32, 64, True, torch.bfloat16) - bench_fmha_blackwell(4, 16384, 32, 64, True, torch.bfloat16) - bench_fmha_blackwell(2, 32768, 32, 64, True, torch.bfloat16) - bench_fmha_blackwell(1, 65536, 32, 64, True, torch.bfloat16) + results = [] + + # Define configurations: (batch_size, qkv_len, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo, config_name) + # DeepSeek-R1 uses MLA (Multi-head Latent Attention) with 128 heads + # head_dim_qk=192 (128 nope + 64 rope), head_dim_vo=128 + configs = [ + (16, 512, 128, 128, 192, 128, "DeepSeek-R1"), + (8, 1024, 128, 128, 192, 128, "DeepSeek-R1"), + (4, 2048, 128, 128, 192, 128, "DeepSeek-R1"), + (2, 4096, 128, 128, 192, 128, "DeepSeek-R1"), + (1, 8192, 128, 128, 192, 128, "DeepSeek-R1"), + ] + + # Run benchmarks: Causal first, then non-causal + # For each config: bfloat16 then fp8 + for causal in [True, False]: + print(f"\n{'='*80}") + print(f"Running {'CAUSAL' if causal else 'NON-CAUSAL'} benchmarks") + print(f"{'='*80}") + + for batch_size, qkv_len, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo, config_name in configs: + # Run bfloat16 + print(f"\n[{config_name}] BS={batch_size}, SeqLen={qkv_len}, Causal={causal}, BF16") + result_bf16 = bench_fmha_blackwell( + batch_size, qkv_len, num_qo_heads, num_kv_heads, + head_dim_qk, head_dim_vo, causal, torch.bfloat16 + ) + result_bf16["config_name"] = config_name + results.append(result_bf16) + print(f" → {result_bf16['tflops']:.2f} TFLOPs/s, {result_bf16['time_ms']:.3f} ms") + + # Run fp8 + print(f"[{config_name}] BS={batch_size}, SeqLen={qkv_len}, Causal={causal}, FP8") + result_fp8 = bench_fmha_blackwell( + batch_size, qkv_len, num_qo_heads, num_kv_heads, + head_dim_qk, head_dim_vo, causal, torch.float8_e4m3fn + ) + result_fp8["config_name"] = config_name + results.append(result_fp8) + speedup = result_fp8['tflops'] / result_bf16['tflops'] + print(f" → {result_fp8['tflops']:.2f} TFLOPs/s, {result_fp8['time_ms']:.3f} ms (speedup: {speedup:.2f}x)") + + # Write results to CSV + csv_filename = "/workspace/logs/fp8_attention_deepseek_benchmark.csv" + 
fieldnames = ["config_name", "batch_size", "qkv_len", "num_qo_heads", "num_kv_heads", + "head_dim_qk", "head_dim_vo", "causal", "dtype", "time_ms", "tflops"] + + with open(csv_filename, 'w', newline='') as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + writer.writeheader() + for result in results: + writer.writerow(result) + + print(f"\n{'='*80}") + print(f"Results saved to: {csv_filename}") + print(f"{'='*80}") diff --git a/csrc/fmha_cutlass_sm100.cu b/csrc/fmha_cutlass_sm100.cu index 47ca0b4057..bad8ece35a 100644 --- a/csrc/fmha_cutlass_sm100.cu +++ b/csrc/fmha_cutlass_sm100.cu @@ -59,7 +59,7 @@ using tvm::ffi::Optional; return __VA_ARGS__(); \ }); \ } else { \ - return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(in_dtype, c_type_in, [&] { \ + return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(in_dtype, c_type_in, [&] { \ using c_type_out = nv_bfloat16; \ return __VA_ARGS__(); \ }); \ diff --git a/tests/attention/test_blackwell_fmha.py b/tests/attention/test_blackwell_fmha.py index 298bfa5db4..5b885a8589 100644 --- a/tests/attention/test_blackwell_fmha.py +++ b/tests/attention/test_blackwell_fmha.py @@ -347,6 +347,131 @@ def test_blackwell_cutlass_qo_kv_varlen( torch.testing.assert_close(lse, lse_ref, rtol=1e-3, atol=1e-3) +@pytest.mark.parametrize("batch_size", [1, 2, 9, 12]) +@pytest.mark.parametrize("qo_len", [177, 377]) +@pytest.mark.parametrize("kv_len", [544, 977]) +@pytest.mark.parametrize( + "num_qo_heads,num_kv_heads", + [ + (128, 128), # DeepSeek-R1 MHA (Multi-head Attention for Prefill) + ], +) +@pytest.mark.parametrize( + "head_dim_qk,head_dim_vo,sm_scale", + [ + (192, 128, 1.0 / math.sqrt(192)), # DeepSeek-R1: qk_nope(128) + qk_rope(64) = 192, v=128 + ], +) +@pytest.mark.parametrize("causal", [False, True]) +def test_blackwell_cutlass_fmha_fp8( + batch_size, + qo_len, + kv_len, + num_qo_heads, + num_kv_heads, + head_dim_qk, + head_dim_vo, + sm_scale, + causal, +): + if qo_len > kv_len and causal: + pytest.skip("qo_len > kv_len and causal is not supported") + + if not is_sm100a_supported(torch.device("cuda")) and not is_sm110a_supported( + torch.device("cuda") + ): + pytest.skip("only SM100A and SM110A are supported on this device") + + torch.manual_seed(42) + dtype_in = torch.float8_e4m3fn + dtype_out = torch.bfloat16 + + # Create FP8 tensors by generating half precision then converting + q = torch.randn( + batch_size * qo_len, num_qo_heads, head_dim_qk, dtype=torch.half, device="cuda" + ).to(dtype_in) + qo_indptr = ( + torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * qo_len + ) + k = torch.randn( + batch_size * kv_len, num_kv_heads, head_dim_qk, dtype=torch.half, device="cuda" + ).to(dtype_in) + v = torch.randn( + batch_size * kv_len, num_kv_heads, head_dim_vo, dtype=torch.half, device="cuda" + ).to(dtype_in) + kv_indptr = ( + torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * kv_len + ) + + wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper( + torch.empty(128 * 1024 * 1024, device="cuda", dtype=torch.uint8), + kv_layout="NHD", + backend="cutlass", + ) + wrapper.plan( + qo_indptr, + kv_indptr, + num_qo_heads, + num_kv_heads, + head_dim_qk, + head_dim_vo=head_dim_vo, + causal=causal, + sm_scale=sm_scale, + q_data_type=dtype_in, + kv_data_type=dtype_in, + ) + o, lse = wrapper.run(q, k, v, return_lse=True) + + # Verify output is bfloat16 + assert o.dtype == dtype_out, f"Expected output dtype {dtype_out}, got {o.dtype}" + + gqa_group_ratio = num_qo_heads // num_kv_heads + k_repeated = torch.repeat_interleave(k, 
gqa_group_ratio, dim=1) + v_repeated = torch.repeat_interleave(v, gqa_group_ratio, dim=1) + + # Reference implementation with FP8 inputs, upcast to float32, output as bfloat16 + qo_len_ref = q.shape[0] // batch_size + kv_len_ref = k_repeated.shape[0] // batch_size + num_qo_heads_ref = q.shape[1] + head_dim_qk_ref = q.shape[2] + head_dim_vo_ref = v_repeated.shape[2] + + logits = ( + torch.einsum( + "bmhd,bnhd->bhmn", + q.view(batch_size, qo_len_ref, num_qo_heads_ref, head_dim_qk_ref).float(), + k_repeated.view(batch_size, kv_len_ref, num_qo_heads_ref, head_dim_qk_ref).float(), + ) + * sm_scale + ) + + if causal: + mask = torch.arange(kv_len_ref - qo_len_ref, kv_len_ref, device=q.device).unsqueeze( + 1 + ) >= torch.arange(0, kv_len_ref, device=q.device).unsqueeze(0) + else: + mask = torch.ones(qo_len_ref, kv_len_ref, device=q.device) + + logits = logits.masked_fill(mask.unsqueeze(0).unsqueeze(0) == 0, float("-inf")) + lse_ref = torch.logsumexp(logits, -1).transpose(-1, -2) + p = torch.softmax(logits, dim=-1) + o_ref = ( + torch.einsum( + "bhmn,bnhd->bmhd", + p, + v_repeated.view(batch_size, kv_len_ref, num_qo_heads_ref, head_dim_vo_ref).float(), + ) + .contiguous() + .view(batch_size * qo_len_ref, num_qo_heads_ref, head_dim_vo_ref) + .to(dtype_out) # Convert to bfloat16 for FP8 output + ) + lse_ref = (lse_ref * math.log2(math.e)).flatten(0, 1) + + # FP8 has lower precision, use relaxed tolerances + torch.testing.assert_close(o, o_ref, rtol=5e-2, atol=5e-2) + torch.testing.assert_close(lse, lse_ref, rtol=1e-3, atol=1e-3) + + if __name__ == "__main__": test_blackwell_cutlass_fmha( 9, From 3bf708aeaa24323add4bd6afb698fb3de3e15460 Mon Sep 17 00:00:00 2001 From: Pavani Majety Date: Mon, 8 Dec 2025 16:30:37 -0800 Subject: [PATCH 3/4] Add q/k/v/o_scale Signed-off-by: Pavani Majety --- benchmarks/bench_blackwell_attention.py | 103 +++++++++++------- csrc/fmha_cutlass_sm100.cu | 5 +- csrc/fmha_cutlass_sm100_binding.cu | 1 + flashinfer/prefill.py | 32 ++++++ .../blackwell/fmha_cutlass_sm100.cuh | 10 +- tests/attention/test_blackwell_fmha.py | 78 +++++++------ 6 files changed, 149 insertions(+), 80 deletions(-) diff --git a/benchmarks/bench_blackwell_attention.py b/benchmarks/bench_blackwell_attention.py index 3b48370394..780c994e2f 100644 --- a/benchmarks/bench_blackwell_attention.py +++ b/benchmarks/bench_blackwell_attention.py @@ -14,6 +14,7 @@ limitations under the License. 
""" +import argparse import csv import numpy as np import torch @@ -34,28 +35,20 @@ def bench_fmha_blackwell( head_dim_vo, causal, dtype, + o_data_type, ): - # if sizeof(dtype) == 1, create randn from half and then convert to dtype - if dtype.itemsize == 1: - q = torch.randn( - batch_size * qkv_len, num_qo_heads, head_dim_qk, dtype=torch.half, device="cuda" - ).to(dtype) - k = torch.randn( - batch_size * qkv_len, num_kv_heads, head_dim_qk, dtype=torch.half, device="cuda" - ).to(dtype) - v = torch.randn( - batch_size * qkv_len, num_kv_heads, head_dim_vo, dtype=torch.half, device="cuda" - ).to(dtype) - else: - q = torch.randn( - batch_size * qkv_len, num_qo_heads, head_dim_qk, dtype=dtype, device="cuda" - ) - k = torch.randn( - batch_size * qkv_len, num_kv_heads, head_dim_qk, dtype=dtype, device="cuda" - ) - v = torch.randn( - batch_size * qkv_len, num_kv_heads, head_dim_vo, dtype=dtype, device="cuda" - ) + # if sizeof(dtype) == 1 like with torch.float8_e4m3fn, + # create randn from half and then convert to dtype + init_dtype = torch.half if dtype.itemsize == 1 else dtype + q = torch.randn( + batch_size * qkv_len, num_qo_heads, head_dim_qk, dtype=init_dtype, device="cuda" + ).to(dtype) + k = torch.randn( + batch_size * qkv_len, num_kv_heads, head_dim_qk, dtype=init_dtype, device="cuda" + ).to(dtype) + v = torch.randn( + batch_size * qkv_len, num_kv_heads, head_dim_vo, dtype=init_dtype, device="cuda" + ).to(dtype) qo_segment_offsets = ( torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * qkv_len @@ -68,6 +61,8 @@ def bench_fmha_blackwell( kv_layout="NHD", backend="cutlass", ) + # For FP8 input, output must be bfloat16 + o_data_type = torch.bfloat16 if dtype.itemsize == 1 else dtype wrapper.plan( qo_segment_offsets, kv_segment_offsets, @@ -78,6 +73,7 @@ def bench_fmha_blackwell( causal=causal, q_data_type=dtype, kv_data_type=dtype, + o_data_type=o_data_type, ) _o = wrapper.run(q, k, v) measurements = bench_gpu_time( @@ -90,16 +86,39 @@ def bench_fmha_blackwell( TFLOPS = attention_tflops_per_sec_with_actual_seq_lens( torch.full((batch_size,), qkv_len), torch.full((batch_size,), qkv_len), - head_dim, - head_dim, - num_heads, + head_dim_qk, + head_dim_vo, + num_qo_heads, causal, ms, ) print( - f"bench_fmha_blackwell (batch_size={batch_size}, qkv_len={qkv_len}, num_heads={num_heads}, head_dim={head_dim}, causal={causal}), flops: {TFLOPS:.3f} TFLOPs/s" + f"bench_fmha_blackwell (batch_size={batch_size}, qkv_len={qkv_len}, num_qo_heads={num_qo_heads}, num_kv_heads={num_kv_heads}, head_dim_qk={head_dim_qk}, head_dim_vo={head_dim_vo}, causal={causal}), flops: {TFLOPS:.3f} TFLOPs/s" + ) + return { + "config_name": f"Blackwell-{config_name}", + "batch_size": batch_size, + "qkv_len": qkv_len, + "num_qo_heads": num_qo_heads, + "num_kv_heads": num_kv_heads, + "head_dim_qk": head_dim_qk, + "head_dim_vo": head_dim_vo, + "causal": causal, + "dtype": dtype, + "time_ms": ms, + "tflops": TFLOPS, + } if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Benchmark FP8 attention for DeepSeek-R1") + parser.add_argument( + "--save-results-to", + type=str, + default=None, + help="Path to save benchmark results as CSV (optional)" + ) + args = parser.parse_args() + results = [] # Define configurations: (batch_size, qkv_len, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo, config_name) @@ -125,7 +144,8 @@ def bench_fmha_blackwell( print(f"\n[{config_name}] BS={batch_size}, SeqLen={qkv_len}, Causal={causal}, BF16") result_bf16 = bench_fmha_blackwell( batch_size, qkv_len, 
num_qo_heads, num_kv_heads, - head_dim_qk, head_dim_vo, causal, torch.bfloat16 + head_dim_qk, head_dim_vo, causal, torch.bfloat16, + o_data_type=torch.bfloat16, ) result_bf16["config_name"] = config_name results.append(result_bf16) @@ -135,24 +155,25 @@ def bench_fmha_blackwell( print(f"[{config_name}] BS={batch_size}, SeqLen={qkv_len}, Causal={causal}, FP8") result_fp8 = bench_fmha_blackwell( batch_size, qkv_len, num_qo_heads, num_kv_heads, - head_dim_qk, head_dim_vo, causal, torch.float8_e4m3fn + head_dim_qk, head_dim_vo, causal, torch.float8_e4m3fn, + o_data_type=torch.bfloat16, ) result_fp8["config_name"] = config_name results.append(result_fp8) speedup = result_fp8['tflops'] / result_bf16['tflops'] print(f" → {result_fp8['tflops']:.2f} TFLOPs/s, {result_fp8['time_ms']:.3f} ms (speedup: {speedup:.2f}x)") - # Write results to CSV - csv_filename = "/workspace/logs/fp8_attention_deepseek_benchmark.csv" - fieldnames = ["config_name", "batch_size", "qkv_len", "num_qo_heads", "num_kv_heads", - "head_dim_qk", "head_dim_vo", "causal", "dtype", "time_ms", "tflops"] - - with open(csv_filename, 'w', newline='') as csvfile: - writer = csv.DictWriter(csvfile, fieldnames=fieldnames) - writer.writeheader() - for result in results: - writer.writerow(result) - - print(f"\n{'='*80}") - print(f"Results saved to: {csv_filename}") - print(f"{'='*80}") + # Write results to CSV if requested + if args.save_results_to: + fieldnames = ["config_name", "batch_size", "qkv_len", "num_qo_heads", "num_kv_heads", + "head_dim_qk", "head_dim_vo", "causal", "dtype", "time_ms", "tflops"] + + with open(args.save_results_to, 'w', newline='') as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + writer.writeheader() + for result in results: + writer.writerow(result) + + print(f"\n{'='*80}") + print(f"Results saved to: {args.save_results_to}") + print(f"{'='*80}") diff --git a/csrc/fmha_cutlass_sm100.cu b/csrc/fmha_cutlass_sm100.cu index bad8ece35a..b5afc985a6 100644 --- a/csrc/fmha_cutlass_sm100.cu +++ b/csrc/fmha_cutlass_sm100.cu @@ -85,7 +85,8 @@ void FMHACutlassSM100Run(ffi::TensorView workspace_buffer, ffi::TensorView q, ff ffi::TensorView qo_tile_indices, ffi::TensorView qo_head_indices, ffi::TensorView batch_indices, ffi::TensorView o, Optional maybe_lse, int64_t mask_mode_code, - double sm_scale, int64_t max_qo_len) { + double sm_scale, double scale_q, double scale_k, double scale_v, + double o_scale, int64_t max_qo_len) { TVM_FFI_ICHECK_EQ(q.dtype(), k.dtype()); auto scalar_type_in = q.dtype(); auto scalar_type_out = o.dtype(); @@ -128,7 +129,7 @@ void FMHACutlassSM100Run(ffi::TensorView workspace_buffer, ffi::TensorView q, ff static_cast(qo_head_indices.data_ptr()), static_cast(batch_indices.data_ptr()), static_cast(o.data_ptr()), maybe_lse.has_value() ? 
static_cast(maybe_lse.value().data_ptr()) : nullptr, - mask_mode_code, sm_scale, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo, q_stride_n, + mask_mode_code, sm_scale, scale_q, scale_k, scale_v, o_scale, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo, q_stride_n, q_stride_h, k_stride_n, k_stride_h, v_stride_n, v_stride_h, batch_size, total_qo_len, total_kv_len, max_qo_len, stream); TVM_FFI_ICHECK_EQ(status, cudaSuccess) diff --git a/csrc/fmha_cutlass_sm100_binding.cu b/csrc/fmha_cutlass_sm100_binding.cu index 69fa341b72..2668d97c53 100644 --- a/csrc/fmha_cutlass_sm100_binding.cu +++ b/csrc/fmha_cutlass_sm100_binding.cu @@ -22,6 +22,7 @@ void FMHACutlassSM100Run(TensorView workspace_buffer, TensorView q, TensorView k TensorView work_indptr, TensorView qo_tile_indices, TensorView qo_head_indices, TensorView batch_indices, TensorView o, Optional maybe_lse, int64_t mask_mode_code, double sm_scale, + double scale_q, double scale_k, double scale_v, double o_scale, int64_t max_qo_len); void blackwell_fmha_plan(TensorView qo_segment_offsets, TensorView kv_segment_offsets, diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 8d631ab80c..278954b404 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -2908,6 +2908,10 @@ def run( k: torch.Tensor, v: torch.Tensor, *args, + q_scale: Optional[float] = None, + k_scale: Optional[float] = None, + v_scale: Optional[float] = None, + o_scale: Optional[float] = None, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, return_lse: bool = False, @@ -2926,6 +2930,14 @@ def run( The value tensor, shape: ``[kv_indptr[-1], num_kv_heads, head_dim_vo]`` *args Additional arguments for the custom kernel. + q_scale: Optional[float] + The calibration scale of fp8 query, if not provided, will be set to ``1.0``. + k_scale: Optional[float] + The calibration scale of fp8 key, if not provided, will be set to ``1.0``. + v_scale: Optional[float] + The calibration scale of fp8 value, if not provided, will be set to ``1.0``. + o_scale: Optional[float] + The calibration scale of output, if not provided, will be set to ``1.0``. out : Optional[torch.Tensor] The output tensor, if not provided, will be allocated internally. 
lse : Optional[torch.Tensor] @@ -2998,6 +3010,10 @@ def run( plan_info=self._plan_info, causal=self._causal, sm_scale=sm_scale, + q_scale=q_scale, + k_scale=k_scale, + v_scale=v_scale, + o_scale=o_scale, max_qo_len=self._max_qo_len, out=out, lse=lse, @@ -3183,6 +3199,10 @@ def fmha_varlen( lse: Optional[torch.Tensor] = None, causal: bool = False, sm_scale: Optional[float] = None, + q_scale: Optional[float] = None, + k_scale: Optional[float] = None, + v_scale: Optional[float] = None, + o_scale: Optional[float] = None, return_lse: bool = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: workspace_buffer = _get_cache_buf( @@ -3207,6 +3227,14 @@ def fmha_varlen( mask_mode_code = 1 if causal else 0 if sm_scale is None: sm_scale = 1.0 / math.sqrt(head_dim_qk) + if q_scale is None: + q_scale = 1.0 + if k_scale is None: + k_scale = 1.0 + if v_scale is None: + v_scale = 1.0 + if o_scale is None: + o_scale = 1.0 qo_total_len = nnz_qo if max_qo_len is None: @@ -3255,6 +3283,10 @@ def fmha_varlen( lse, mask_mode_code, sm_scale, + q_scale, + k_scale, + v_scale, + o_scale, max_qo_len, ) diff --git a/include/flashinfer/attention/blackwell/fmha_cutlass_sm100.cuh b/include/flashinfer/attention/blackwell/fmha_cutlass_sm100.cuh index 66d8b4fb91..877e9ac7d0 100644 --- a/include/flashinfer/attention/blackwell/fmha_cutlass_sm100.cuh +++ b/include/flashinfer/attention/blackwell/fmha_cutlass_sm100.cuh @@ -76,7 +76,8 @@ struct FwdRunner { IdType* qo_segment_offsets, IdType* kv_segment_offsets, IdType* work_indptr, IdType* qo_tile_indices, IdType* qo_head_indices, IdType* batch_indices, DTypeOut* o, float* maybe_lse, int mask_mode_code, - double sm_scale, int num_qo_heads, int num_kv_heads, int head_dim_qk, + double sm_scale, double q_scale, double k_scale, double v_scale, double o_scale, + int num_qo_heads, int num_kv_heads, int head_dim_qk, int head_dim_vo, int q_stride_n, int q_stride_h, int k_stride_n, int k_stride_h, int v_stride_n, int v_stride_h, int batch_size, int total_qo_len, int total_kv_len, int max_qo_len, cudaStream_t stream) { @@ -120,7 +121,7 @@ struct FwdRunner { typename Operation::Arguments arguments{ problem_shape, - {q, layout_Q, k, layout_K, v, layout_V, sm_scale}, + {q, layout_Q, k, layout_K, v, layout_V, sm_scale, q_scale, k_scale, v_scale, o_scale}, {o - max_qo_len * get<0>(stride_O), layout_O, maybe_lse, layout_LSE, max_qo_len}, {work_indptr, qo_tile_indices, qo_head_indices, batch_indices}, hw_info}; @@ -163,14 +164,15 @@ cudaError_t run_fmha_fwd(void* workspace_buffer, DTypeIn* q, DTypeIn* k, DTypeIn IdType* qo_segment_offsets, IdType* kv_segment_offsets, IdType* work_indptr, IdType* qo_tile_indices, IdType* qo_head_indices, IdType* batch_indices, DTypeOut* o, float* maybe_lse, int mask_mode_code, - double sm_scale, int num_qo_heads, int num_kv_heads, int head_dim_qk, + double sm_scale, double q_scale, double k_scale, double v_scale, double o_scale, + int num_qo_heads, int num_kv_heads, int head_dim_qk, int head_dim_vo, int q_stride_n, int q_stride_h, int k_stride_n, int k_stride_h, int v_stride_n, int v_stride_h, int batch_size, int total_qo_len, int total_kv_len, int max_qo_len, cudaStream_t stream) { return FwdRunner::run( workspace_buffer, q, k, v, qo_segment_offsets, kv_segment_offsets, work_indptr, qo_tile_indices, qo_head_indices, batch_indices, o, maybe_lse, mask_mode_code, sm_scale, - num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo, q_stride_n, q_stride_h, k_stride_n, + q_scale, k_scale, v_scale, o_scale, num_qo_heads, num_kv_heads, head_dim_qk, 
head_dim_vo, q_stride_n, q_stride_h, k_stride_n, k_stride_h, v_stride_n, v_stride_h, batch_size, total_qo_len, total_kv_len, max_qo_len, stream); } diff --git a/tests/attention/test_blackwell_fmha.py b/tests/attention/test_blackwell_fmha.py index 5b885a8589..130d77b699 100644 --- a/tests/attention/test_blackwell_fmha.py +++ b/tests/attention/test_blackwell_fmha.py @@ -419,6 +419,7 @@ def test_blackwell_cutlass_fmha_fp8( sm_scale=sm_scale, q_data_type=dtype_in, kv_data_type=dtype_in, + o_data_type=dtype_out, ) o, lse = wrapper.run(q, k, v, return_lse=True) @@ -473,37 +474,48 @@ def test_blackwell_cutlass_fmha_fp8( if __name__ == "__main__": - test_blackwell_cutlass_fmha( - 9, - 377, - 977, - 1, - 1, - 192, - 128, - 1, - False, - torch.bfloat16, - ) - - test_blackwell_cutlass_varlen( - [0, 1274, 2568, 3915, 5194, 6498, 7839, 8192], - 32, - 4, - 128, - 128, - 1, - True, - torch.bfloat16, - ) - - test_blackwell_cutlass_qo_kv_varlen( - [0, 10, 20, 30, 40, 50, 60, 100], - [0, 50, 50, 50, 50, 50, 50, 50], - 32, - 8, - 128, - 128, - 1, - torch.bfloat16, + test_blackwell_cutlass_fmha_fp8( + batch_size=9, + qo_len=377, + kv_len=977, + num_qo_heads=1, + num_kv_heads=1, + head_dim_qk=192, + head_dim_vo=128, + sm_scale=1, + causal=False, ) + # test_blackwell_cutlass_fmha( + # 9, + # 377, + # 977, + # 1, + # 1, + # 192, + # 128, + # 1, + # False, + # torch.bfloat16, + # ) + + # test_blackwell_cutlass_varlen( + # [0, 1274, 2568, 3915, 5194, 6498, 7839, 8192], + # 32, + # 4, + # 128, + # 128, + # 1, + # True, + # torch.bfloat16, + # ) + + # test_blackwell_cutlass_qo_kv_varlen( + # [0, 10, 20, 30, 40, 50, 60, 100], + # [0, 50, 50, 50, 50, 50, 50, 50], + # 32, + # 8, + # 128, + # 128, + # 1, + # torch.bfloat16, + # ) From 6697a973db2f1bc14cf25aa678154a11ad7a7d7d Mon Sep 17 00:00:00 2001 From: Pavani Majety Date: Mon, 8 Dec 2025 16:42:11 -0800 Subject: [PATCH 4/4] Add scales support and fix pre-commit Signed-off-by: Pavani Majety --- benchmarks/bench_blackwell_attention.py | 96 +++++++++++++------ csrc/fmha_cutlass_sm100.cu | 8 +- flashinfer/prefill.py | 10 +- .../blackwell/fmha_cutlass_sm100.cuh | 14 +-- tests/attention/test_blackwell_fmha.py | 28 ++++-- 5 files changed, 107 insertions(+), 49 deletions(-) diff --git a/benchmarks/bench_blackwell_attention.py b/benchmarks/bench_blackwell_attention.py index 780c994e2f..664cae9340 100644 --- a/benchmarks/bench_blackwell_attention.py +++ b/benchmarks/bench_blackwell_attention.py @@ -109,18 +109,21 @@ def bench_fmha_blackwell( "tflops": TFLOPS, } + if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Benchmark FP8 attention for DeepSeek-R1") + parser = argparse.ArgumentParser( + description="Benchmark FP8 attention for DeepSeek-R1" + ) parser.add_argument( "--save-results-to", type=str, default=None, - help="Path to save benchmark results as CSV (optional)" + help="Path to save benchmark results as CSV (optional)", ) args = parser.parse_args() - + results = [] - + # Define configurations: (batch_size, qkv_len, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo, config_name) # DeepSeek-R1 uses MLA (Multi-head Latent Attention) with 128 heads # head_dim_qk=192 (128 nope + 64 rope), head_dim_vo=128 @@ -131,49 +134,88 @@ def bench_fmha_blackwell( (2, 4096, 128, 128, 192, 128, "DeepSeek-R1"), (1, 8192, 128, 128, 192, 128, "DeepSeek-R1"), ] - + # Run benchmarks: Causal first, then non-causal # For each config: bfloat16 then fp8 for causal in [True, False]: - print(f"\n{'='*80}") + print(f"\n{'=' * 80}") print(f"Running {'CAUSAL' if causal else 
'NON-CAUSAL'} benchmarks") - print(f"{'='*80}") - - for batch_size, qkv_len, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo, config_name in configs: + print(f"{'=' * 80}") + + for ( + batch_size, + qkv_len, + num_qo_heads, + num_kv_heads, + head_dim_qk, + head_dim_vo, + config_name, + ) in configs: # Run bfloat16 - print(f"\n[{config_name}] BS={batch_size}, SeqLen={qkv_len}, Causal={causal}, BF16") + print( + f"\n[{config_name}] BS={batch_size}, SeqLen={qkv_len}, Causal={causal}, BF16" + ) result_bf16 = bench_fmha_blackwell( - batch_size, qkv_len, num_qo_heads, num_kv_heads, - head_dim_qk, head_dim_vo, causal, torch.bfloat16, + batch_size, + qkv_len, + num_qo_heads, + num_kv_heads, + head_dim_qk, + head_dim_vo, + causal, + torch.bfloat16, o_data_type=torch.bfloat16, ) result_bf16["config_name"] = config_name results.append(result_bf16) - print(f" → {result_bf16['tflops']:.2f} TFLOPs/s, {result_bf16['time_ms']:.3f} ms") - + print( + f" → {result_bf16['tflops']:.2f} TFLOPs/s, {result_bf16['time_ms']:.3f} ms" + ) + # Run fp8 - print(f"[{config_name}] BS={batch_size}, SeqLen={qkv_len}, Causal={causal}, FP8") + print( + f"[{config_name}] BS={batch_size}, SeqLen={qkv_len}, Causal={causal}, FP8" + ) result_fp8 = bench_fmha_blackwell( - batch_size, qkv_len, num_qo_heads, num_kv_heads, - head_dim_qk, head_dim_vo, causal, torch.float8_e4m3fn, + batch_size, + qkv_len, + num_qo_heads, + num_kv_heads, + head_dim_qk, + head_dim_vo, + causal, + torch.float8_e4m3fn, o_data_type=torch.bfloat16, ) result_fp8["config_name"] = config_name results.append(result_fp8) - speedup = result_fp8['tflops'] / result_bf16['tflops'] - print(f" → {result_fp8['tflops']:.2f} TFLOPs/s, {result_fp8['time_ms']:.3f} ms (speedup: {speedup:.2f}x)") - + speedup = result_fp8["tflops"] / result_bf16["tflops"] + print( + f" → {result_fp8['tflops']:.2f} TFLOPs/s, {result_fp8['time_ms']:.3f} ms (speedup: {speedup:.2f}x)" + ) + # Write results to CSV if requested if args.save_results_to: - fieldnames = ["config_name", "batch_size", "qkv_len", "num_qo_heads", "num_kv_heads", - "head_dim_qk", "head_dim_vo", "causal", "dtype", "time_ms", "tflops"] - - with open(args.save_results_to, 'w', newline='') as csvfile: + fieldnames = [ + "config_name", + "batch_size", + "qkv_len", + "num_qo_heads", + "num_kv_heads", + "head_dim_qk", + "head_dim_vo", + "causal", + "dtype", + "time_ms", + "tflops", + ] + + with open(args.save_results_to, "w", newline="") as csvfile: writer = csv.DictWriter(csvfile, fieldnames=fieldnames) writer.writeheader() for result in results: writer.writerow(result) - - print(f"\n{'='*80}") + + print(f"\n{'=' * 80}") print(f"Results saved to: {args.save_results_to}") - print(f"{'='*80}") + print(f"{'=' * 80}") diff --git a/csrc/fmha_cutlass_sm100.cu b/csrc/fmha_cutlass_sm100.cu index b5afc985a6..4598ef739c 100644 --- a/csrc/fmha_cutlass_sm100.cu +++ b/csrc/fmha_cutlass_sm100.cu @@ -59,7 +59,7 @@ using tvm::ffi::Optional; return __VA_ARGS__(); \ }); \ } else { \ - return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(in_dtype, c_type_in, [&] { \ + return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(in_dtype, c_type_in, [&] { \ using c_type_out = nv_bfloat16; \ return __VA_ARGS__(); \ }); \ @@ -129,9 +129,9 @@ void FMHACutlassSM100Run(ffi::TensorView workspace_buffer, ffi::TensorView q, ff static_cast(qo_head_indices.data_ptr()), static_cast(batch_indices.data_ptr()), static_cast(o.data_ptr()), maybe_lse.has_value() ? 
static_cast(maybe_lse.value().data_ptr()) : nullptr, - mask_mode_code, sm_scale, scale_q, scale_k, scale_v, o_scale, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo, q_stride_n, - q_stride_h, k_stride_n, k_stride_h, v_stride_n, v_stride_h, batch_size, total_qo_len, - total_kv_len, max_qo_len, stream); + mask_mode_code, sm_scale, scale_q, scale_k, scale_v, o_scale, num_qo_heads, num_kv_heads, + head_dim_qk, head_dim_vo, q_stride_n, q_stride_h, k_stride_n, k_stride_h, v_stride_n, + v_stride_h, batch_size, total_qo_len, total_kv_len, max_qo_len, stream); TVM_FFI_ICHECK_EQ(status, cudaSuccess) << "Cutlass FMHA forward pass failed" << cudaGetErrorString(status); diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 278954b404..65b3ed33c1 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -2989,7 +2989,7 @@ def run( out_dtype = torch.bfloat16 if q.dtype.itemsize == 1 else q.dtype out = torch.empty( q.shape[:-1] + v.shape[-1:], - dtype=self._cached_o_data_type, + dtype=out_dtype, device=q.device, ) else: @@ -3166,6 +3166,10 @@ def fmha_varlen( lse: Optional[torch.Tensor] = None, causal: bool = False, sm_scale: Optional[float] = None, + q_scale: Optional[float] = None, + k_scale: Optional[float] = None, + v_scale: Optional[float] = None, + o_scale: Optional[float] = None, return_lse: Literal[False] = False, ) -> torch.Tensor: ... @@ -3183,6 +3187,10 @@ def fmha_varlen( lse: Optional[torch.Tensor] = None, causal: bool = False, sm_scale: Optional[float] = None, + q_scale: Optional[float] = None, + k_scale: Optional[float] = None, + v_scale: Optional[float] = None, + o_scale: Optional[float] = None, return_lse: Literal[True] = True, ) -> Tuple[torch.Tensor, torch.Tensor]: ... diff --git a/include/flashinfer/attention/blackwell/fmha_cutlass_sm100.cuh b/include/flashinfer/attention/blackwell/fmha_cutlass_sm100.cuh index 877e9ac7d0..b92de63407 100644 --- a/include/flashinfer/attention/blackwell/fmha_cutlass_sm100.cuh +++ b/include/flashinfer/attention/blackwell/fmha_cutlass_sm100.cuh @@ -76,8 +76,8 @@ struct FwdRunner { IdType* qo_segment_offsets, IdType* kv_segment_offsets, IdType* work_indptr, IdType* qo_tile_indices, IdType* qo_head_indices, IdType* batch_indices, DTypeOut* o, float* maybe_lse, int mask_mode_code, - double sm_scale, double q_scale, double k_scale, double v_scale, double o_scale, - int num_qo_heads, int num_kv_heads, int head_dim_qk, + double sm_scale, double q_scale, double k_scale, double v_scale, + double o_scale, int num_qo_heads, int num_kv_heads, int head_dim_qk, int head_dim_vo, int q_stride_n, int q_stride_h, int k_stride_n, int k_stride_h, int v_stride_n, int v_stride_h, int batch_size, int total_qo_len, int total_kv_len, int max_qo_len, cudaStream_t stream) { @@ -164,17 +164,17 @@ cudaError_t run_fmha_fwd(void* workspace_buffer, DTypeIn* q, DTypeIn* k, DTypeIn IdType* qo_segment_offsets, IdType* kv_segment_offsets, IdType* work_indptr, IdType* qo_tile_indices, IdType* qo_head_indices, IdType* batch_indices, DTypeOut* o, float* maybe_lse, int mask_mode_code, - double sm_scale, double q_scale, double k_scale, double v_scale, double o_scale, - int num_qo_heads, int num_kv_heads, int head_dim_qk, + double sm_scale, double q_scale, double k_scale, double v_scale, + double o_scale, int num_qo_heads, int num_kv_heads, int head_dim_qk, int head_dim_vo, int q_stride_n, int q_stride_h, int k_stride_n, int k_stride_h, int v_stride_n, int v_stride_h, int batch_size, int total_qo_len, int total_kv_len, int max_qo_len, cudaStream_t stream) { return 
FwdRunner::run( workspace_buffer, q, k, v, qo_segment_offsets, kv_segment_offsets, work_indptr, qo_tile_indices, qo_head_indices, batch_indices, o, maybe_lse, mask_mode_code, sm_scale, - q_scale, k_scale, v_scale, o_scale, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo, q_stride_n, q_stride_h, k_stride_n, - k_stride_h, v_stride_n, v_stride_h, batch_size, total_qo_len, total_kv_len, max_qo_len, - stream); + q_scale, k_scale, v_scale, o_scale, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo, + q_stride_n, q_stride_h, k_stride_n, k_stride_h, v_stride_n, v_stride_h, batch_size, + total_qo_len, total_kv_len, max_qo_len, stream); } }; // namespace flashinfer diff --git a/tests/attention/test_blackwell_fmha.py b/tests/attention/test_blackwell_fmha.py index 130d77b699..39084d3687 100644 --- a/tests/attention/test_blackwell_fmha.py +++ b/tests/attention/test_blackwell_fmha.py @@ -359,7 +359,11 @@ def test_blackwell_cutlass_qo_kv_varlen( @pytest.mark.parametrize( "head_dim_qk,head_dim_vo,sm_scale", [ - (192, 128, 1.0 / math.sqrt(192)), # DeepSeek-R1: qk_nope(128) + qk_rope(64) = 192, v=128 + ( + 192, + 128, + 1.0 / math.sqrt(192), + ), # DeepSeek-R1: qk_nope(128) + qk_rope(64) = 192, v=128 ], ) @pytest.mark.parametrize("causal", [False, True]) @@ -381,11 +385,11 @@ def test_blackwell_cutlass_fmha_fp8( torch.device("cuda") ): pytest.skip("only SM100A and SM110A are supported on this device") - + torch.manual_seed(42) dtype_in = torch.float8_e4m3fn dtype_out = torch.bfloat16 - + # Create FP8 tensors by generating half precision then converting q = torch.randn( batch_size * qo_len, num_qo_heads, head_dim_qk, dtype=torch.half, device="cuda" @@ -429,27 +433,29 @@ def test_blackwell_cutlass_fmha_fp8( gqa_group_ratio = num_qo_heads // num_kv_heads k_repeated = torch.repeat_interleave(k, gqa_group_ratio, dim=1) v_repeated = torch.repeat_interleave(v, gqa_group_ratio, dim=1) - + # Reference implementation with FP8 inputs, upcast to float32, output as bfloat16 qo_len_ref = q.shape[0] // batch_size kv_len_ref = k_repeated.shape[0] // batch_size num_qo_heads_ref = q.shape[1] head_dim_qk_ref = q.shape[2] head_dim_vo_ref = v_repeated.shape[2] - + logits = ( torch.einsum( "bmhd,bnhd->bhmn", q.view(batch_size, qo_len_ref, num_qo_heads_ref, head_dim_qk_ref).float(), - k_repeated.view(batch_size, kv_len_ref, num_qo_heads_ref, head_dim_qk_ref).float(), + k_repeated.view( + batch_size, kv_len_ref, num_qo_heads_ref, head_dim_qk_ref + ).float(), ) * sm_scale ) if causal: - mask = torch.arange(kv_len_ref - qo_len_ref, kv_len_ref, device=q.device).unsqueeze( - 1 - ) >= torch.arange(0, kv_len_ref, device=q.device).unsqueeze(0) + mask = torch.arange( + kv_len_ref - qo_len_ref, kv_len_ref, device=q.device + ).unsqueeze(1) >= torch.arange(0, kv_len_ref, device=q.device).unsqueeze(0) else: mask = torch.ones(qo_len_ref, kv_len_ref, device=q.device) @@ -460,7 +466,9 @@ def test_blackwell_cutlass_fmha_fp8( torch.einsum( "bhmn,bnhd->bmhd", p, - v_repeated.view(batch_size, kv_len_ref, num_qo_heads_ref, head_dim_vo_ref).float(), + v_repeated.view( + batch_size, kv_len_ref, num_qo_heads_ref, head_dim_vo_ref + ).float(), ) .contiguous() .view(batch_size * qo_len_ref, num_qo_heads_ref, head_dim_vo_ref)