diff --git a/benchmarks/bench_blackwell_attention.py b/benchmarks/bench_blackwell_attention.py
index 73b0cd0b3c..664cae9340 100644
--- a/benchmarks/bench_blackwell_attention.py
+++ b/benchmarks/bench_blackwell_attention.py
@@ -14,6 +14,8 @@
 limitations under the License.
 """
 
+import argparse
+import csv
 import numpy as np
 import torch
 
@@ -27,20 +29,26 @@
 def bench_fmha_blackwell(
     batch_size,
     qkv_len,
-    num_heads,
-    head_dim,
+    num_qo_heads,
+    num_kv_heads,
+    head_dim_qk,
+    head_dim_vo,
     causal,
     dtype,
+    o_data_type,
 ):
+    # if sizeof(dtype) == 1 like with torch.float8_e4m3fn,
+    # create randn from half and then convert to dtype
+    init_dtype = torch.half if dtype.itemsize == 1 else dtype
     q = torch.randn(
-        batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda"
-    )
+        batch_size * qkv_len, num_qo_heads, head_dim_qk, dtype=init_dtype, device="cuda"
+    ).to(dtype)
     k = torch.randn(
-        batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda"
-    )
+        batch_size * qkv_len, num_kv_heads, head_dim_qk, dtype=init_dtype, device="cuda"
+    ).to(dtype)
     v = torch.randn(
-        batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda"
-    )
+        batch_size * qkv_len, num_kv_heads, head_dim_vo, dtype=init_dtype, device="cuda"
+    ).to(dtype)
     qo_segment_offsets = (
         torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * qkv_len
@@ -53,16 +61,19 @@ def bench_fmha_blackwell(
         kv_layout="NHD",
         backend="cutlass",
     )
+    # For FP8 input, output must be bfloat16
+    o_data_type = torch.bfloat16 if dtype.itemsize == 1 else dtype
     wrapper.plan(
         qo_segment_offsets,
         kv_segment_offsets,
-        num_heads,
-        num_heads,
-        head_dim,
-        head_dim_vo=head_dim,
+        num_qo_heads,
+        num_kv_heads,
+        head_dim_qk,
+        head_dim_vo=head_dim_vo,
         causal=causal,
         q_data_type=dtype,
         kv_data_type=dtype,
+        o_data_type=o_data_type,
     )
     _o = wrapper.run(q, k, v)
     measurements = bench_gpu_time(
@@ -75,52 +86,136 @@ def bench_fmha_blackwell(
     TFLOPS = attention_tflops_per_sec_with_actual_seq_lens(
         torch.full((batch_size,), qkv_len),
         torch.full((batch_size,), qkv_len),
-        head_dim,
-        head_dim,
-        num_heads,
+        head_dim_qk,
+        head_dim_vo,
+        num_qo_heads,
         causal,
         ms,
     )
     print(
-        f"bench_fmha_blackwell (batch_size={batch_size}, qkv_len={qkv_len}, num_heads={num_heads}, head_dim={head_dim}, causal={causal}), flops: {TFLOPS:.3f} TFLOPs/s"
+        f"bench_fmha_blackwell (batch_size={batch_size}, qkv_len={qkv_len}, num_qo_heads={num_qo_heads}, num_kv_heads={num_kv_heads}, head_dim_qk={head_dim_qk}, head_dim_vo={head_dim_vo}, causal={causal}), flops: {TFLOPS:.3f} TFLOPs/s"
     )
+    return {
+        "config_name": None,  # filled in by the caller
+        "batch_size": batch_size,
+        "qkv_len": qkv_len,
+        "num_qo_heads": num_qo_heads,
+        "num_kv_heads": num_kv_heads,
+        "head_dim_qk": head_dim_qk,
+        "head_dim_vo": head_dim_vo,
+        "causal": causal,
+        "dtype": dtype,
+        "time_ms": ms,
+        "tflops": TFLOPS,
+    }
 
 
 if __name__ == "__main__":
-    print("\n === head_dim=128 ===")
-    bench_fmha_blackwell(128, 512, 32, 128, False, torch.bfloat16)
-    bench_fmha_blackwell(64, 1024, 32, 128, False, torch.bfloat16)
-    bench_fmha_blackwell(32, 2048, 32, 128, False, torch.bfloat16)
-    bench_fmha_blackwell(16, 4096, 32, 128, False, torch.bfloat16)
-    bench_fmha_blackwell(8, 8192, 32, 128, False, torch.bfloat16)
-    bench_fmha_blackwell(4, 16384, 32, 128, False, torch.bfloat16)
-    bench_fmha_blackwell(2, 32768, 32, 128, False, torch.bfloat16)
-    bench_fmha_blackwell(1, 65536, 32, 128, False, torch.bfloat16)
-
-    bench_fmha_blackwell(128, 512, 32, 128, True, torch.bfloat16)
-    bench_fmha_blackwell(64, 1024, 32, 128, True, torch.bfloat16)
-    bench_fmha_blackwell(32, 2048, 32, 128, True, torch.bfloat16)
-    bench_fmha_blackwell(16, 4096, 32, 128, True, torch.bfloat16)
-    bench_fmha_blackwell(8, 8192, 32, 128, True, torch.bfloat16)
-    bench_fmha_blackwell(4, 16384, 32, 128, True, torch.bfloat16)
-    bench_fmha_blackwell(2, 32768, 32, 128, True, torch.bfloat16)
-    bench_fmha_blackwell(1, 65536, 32, 128, True, torch.bfloat16)
-
-    print("\n === head_dim=64 ===")
-    bench_fmha_blackwell(128, 512, 32, 64, False, torch.bfloat16)
-    bench_fmha_blackwell(64, 1024, 32, 64, False, torch.bfloat16)
-    bench_fmha_blackwell(32, 2048, 32, 64, False, torch.bfloat16)
-    bench_fmha_blackwell(16, 4096, 32, 64, False, torch.bfloat16)
-    bench_fmha_blackwell(8, 8192, 32, 64, False, torch.bfloat16)
-    bench_fmha_blackwell(4, 16384, 32, 64, False, torch.bfloat16)
-    bench_fmha_blackwell(2, 32768, 32, 64, False, torch.bfloat16)
-    bench_fmha_blackwell(1, 65536, 32, 64, False, torch.bfloat16)
-
-    bench_fmha_blackwell(128, 512, 32, 64, True, torch.bfloat16)
-    bench_fmha_blackwell(64, 1024, 32, 64, True, torch.bfloat16)
-    bench_fmha_blackwell(32, 2048, 32, 64, True, torch.bfloat16)
-    bench_fmha_blackwell(16, 4096, 32, 64, True, torch.bfloat16)
-    bench_fmha_blackwell(8, 8192, 32, 64, True, torch.bfloat16)
-    bench_fmha_blackwell(4, 16384, 32, 64, True, torch.bfloat16)
-    bench_fmha_blackwell(2, 32768, 32, 64, True, torch.bfloat16)
-    bench_fmha_blackwell(1, 65536, 32, 64, True, torch.bfloat16)
+    parser = argparse.ArgumentParser(
+        description="Benchmark Blackwell FMHA (BF16 vs. FP8) on DeepSeek-R1 prefill shapes"
+    )
+    parser.add_argument(
+        "--save-results-to",
+        type=str,
+        default=None,
+        help="Path to save benchmark results as CSV (optional)",
+    )
+    args = parser.parse_args()
+
+    results = []
+
+    # Define configurations: (batch_size, qkv_len, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo, config_name)
+    # DeepSeek-R1 uses MLA (Multi-head Latent Attention) with 128 heads
+    # head_dim_qk=192 (128 nope + 64 rope), head_dim_vo=128
+    configs = [
+        (16, 512, 128, 128, 192, 128, "DeepSeek-R1"),
+        (8, 1024, 128, 128, 192, 128, "DeepSeek-R1"),
+        (4, 2048, 128, 128, 192, 128, "DeepSeek-R1"),
+        (2, 4096, 128, 128, 192, 128, "DeepSeek-R1"),
+        (1, 8192, 128, 128, 192, 128, "DeepSeek-R1"),
+    ]
+
+    # Run benchmarks: causal first, then non-causal
+    # For each config: bfloat16 then fp8
+    for causal in [True, False]:
+        print(f"\n{'=' * 80}")
+        print(f"Running {'CAUSAL' if causal else 'NON-CAUSAL'} benchmarks")
+        print(f"{'=' * 80}")
+
+        for (
+            batch_size,
+            qkv_len,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim_qk,
+            head_dim_vo,
+            config_name,
+        ) in configs:
+            # Run bfloat16
+            print(
+                f"\n[{config_name}] BS={batch_size}, SeqLen={qkv_len}, Causal={causal}, BF16"
+            )
+            result_bf16 = bench_fmha_blackwell(
+                batch_size,
+                qkv_len,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim_qk,
+                head_dim_vo,
+                causal,
+                torch.bfloat16,
+                o_data_type=torch.bfloat16,
+            )
+            result_bf16["config_name"] = config_name
+            results.append(result_bf16)
+            print(
+                f"  → {result_bf16['tflops']:.2f} TFLOPs/s, {result_bf16['time_ms']:.3f} ms"
+            )
+
+            # Run fp8
+            print(
+                f"[{config_name}] BS={batch_size}, SeqLen={qkv_len}, Causal={causal}, FP8"
+            )
+            result_fp8 = bench_fmha_blackwell(
+                batch_size,
+                qkv_len,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim_qk,
+                head_dim_vo,
+                causal,
+                torch.float8_e4m3fn,
+                o_data_type=torch.bfloat16,
+            )
+            result_fp8["config_name"] = config_name
+            results.append(result_fp8)
+            speedup = result_fp8["tflops"] / result_bf16["tflops"]
+            print(
+                f"  → {result_fp8['tflops']:.2f} TFLOPs/s, {result_fp8['time_ms']:.3f} ms (speedup: {speedup:.2f}x)"
+            )
+
+    # Write results to CSV if requested
+    if args.save_results_to:
+        fieldnames = [
+            "config_name",
+            "batch_size",
+            "qkv_len",
+            "num_qo_heads",
+            "num_kv_heads",
+            "head_dim_qk",
+            "head_dim_vo",
+            "causal",
+            "dtype",
+            "time_ms",
+            "tflops",
+        ]
+
+        with open(args.save_results_to, "w", newline="") as csvfile:
+            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
+            writer.writeheader()
+            for result in results:
+                writer.writerow(result)
+
+        print(f"\n{'=' * 80}")
+        print(f"Results saved to: {args.save_results_to}")
+        print(f"{'=' * 80}")
diff --git a/csrc/fmha_cutlass_sm100.cu b/csrc/fmha_cutlass_sm100.cu
index 08f1235adf..4598ef739c 100644
--- a/csrc/fmha_cutlass_sm100.cu
+++ b/csrc/fmha_cutlass_sm100.cu
@@ -58,6 +58,11 @@ using tvm::ffi::Optional;
       using c_type_out = c_type_in;                                       \
       return __VA_ARGS__();                                               \
     });                                                                   \
+  } else {                                                                \
+    return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(in_dtype, c_type_in, [&] {  \
+      using c_type_out = nv_bfloat16;                                     \
+      return __VA_ARGS__();                                               \
+    });                                                                   \
   }                                                                       \
   return false;                                                           \
 }()
@@ -80,14 +85,18 @@ void FMHACutlassSM100Run(ffi::TensorView workspace_buffer, ffi::TensorView q, ff
                          ffi::TensorView qo_tile_indices, ffi::TensorView qo_head_indices,
                          ffi::TensorView batch_indices, ffi::TensorView o,
                          Optional<ffi::TensorView> maybe_lse, int64_t mask_mode_code,
-                         double sm_scale, int64_t num_qo_heads, int64_t num_kv_heads,
-                         int64_t head_dim_qk, int64_t head_dim_vo, int64_t max_qo_len) {
+                         double sm_scale, double scale_q, double scale_k, double scale_v,
+                         double o_scale, int64_t max_qo_len) {
   TVM_FFI_ICHECK_EQ(q.dtype(), k.dtype());
   auto scalar_type_in = q.dtype();
   auto scalar_type_out = o.dtype();
   MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);
   int total_qo_len = q.size(0);
   int total_kv_len = k.size(0);
+  int num_qo_heads = q.size(1);
+  int num_kv_heads = k.size(1);
+  int head_dim_qk = q.size(2);
+  int head_dim_vo = v.size(2);
   int batch_size = qo_segment_offsets.size(0) - 1;
   int q_stride_n = q.stride(0);
   int q_stride_h = q.stride(1);
@@ -120,9 +129,9 @@
         static_cast<IdType*>(qo_head_indices.data_ptr()),
         static_cast<IdType*>(batch_indices.data_ptr()), static_cast<c_type_out*>(o.data_ptr()),
         maybe_lse.has_value() ? static_cast<float*>(maybe_lse.value().data_ptr()) : nullptr,
-        mask_mode_code, sm_scale, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo, q_stride_n,
-        q_stride_h, k_stride_n, k_stride_h, v_stride_n, v_stride_h, batch_size, total_qo_len,
-        total_kv_len, max_qo_len, stream);
+        mask_mode_code, sm_scale, scale_q, scale_k, scale_v, o_scale, num_qo_heads, num_kv_heads,
+        head_dim_qk, head_dim_vo, q_stride_n, q_stride_h, k_stride_n, k_stride_h, v_stride_n,
+        v_stride_h, batch_size, total_qo_len, total_kv_len, max_qo_len, stream);
 
     TVM_FFI_ICHECK_EQ(status, cudaSuccess)
         << "Cutlass FMHA forward pass failed" << cudaGetErrorString(status);
diff --git a/csrc/fmha_cutlass_sm100_binding.cu b/csrc/fmha_cutlass_sm100_binding.cu
index ddb3b8d9cd..2668d97c53 100644
--- a/csrc/fmha_cutlass_sm100_binding.cu
+++ b/csrc/fmha_cutlass_sm100_binding.cu
@@ -22,8 +22,8 @@ void FMHACutlassSM100Run(TensorView workspace_buffer, TensorView q, TensorView k
                          TensorView work_indptr, TensorView qo_tile_indices,
                          TensorView qo_head_indices, TensorView batch_indices, TensorView o,
                          Optional<TensorView> maybe_lse, int64_t mask_mode_code, double sm_scale,
-                         int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim_qk,
-                         int64_t head_dim_vo, int64_t max_qo_len);
+                         double scale_q, double scale_k, double scale_v, double o_scale,
+                         int64_t max_qo_len);
 
 void blackwell_fmha_plan(TensorView qo_segment_offsets, TensorView kv_segment_offsets,
                          TensorView work_indptr, TensorView qo_tile_indices,
diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py
index 41fac0e4e9..65b3ed33c1 100755
--- a/flashinfer/prefill.py
+++ b/flashinfer/prefill.py
@@ -2908,6 +2908,10 @@ def run(
         k: torch.Tensor,
         v: torch.Tensor,
         *args,
+        q_scale: Optional[float] = None,
+        k_scale: Optional[float] = None,
+        v_scale: Optional[float] = None,
+        o_scale: Optional[float] = None,
         out: Optional[torch.Tensor] = None,
         lse: Optional[torch.Tensor] = None,
         return_lse: bool = False,
@@ -2926,6 +2930,14 @@ def run(
             The value tensor, shape: ``[kv_indptr[-1], num_kv_heads, head_dim_vo]``
         *args
             Additional arguments for the custom kernel.
+        q_scale: Optional[float]
+            The calibration scale of the fp8 query; if not provided, it will be set to ``1.0``.
+        k_scale: Optional[float]
+            The calibration scale of the fp8 key; if not provided, it will be set to ``1.0``.
+        v_scale: Optional[float]
+            The calibration scale of the fp8 value; if not provided, it will be set to ``1.0``.
+        o_scale: Optional[float]
+            The calibration scale of the output; if not provided, it will be set to ``1.0``.
         out : Optional[torch.Tensor]
             The output tensor, if not provided, will be allocated internally.
         lse : Optional[torch.Tensor]
@@ -2973,9 +2985,11 @@ def run(
                 lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse"
             )
         if out is None:
+            # when input dtype is fp8, we need to use bf16 output
+            out_dtype = torch.bfloat16 if q.dtype.itemsize == 1 else q.dtype
             out = torch.empty(
                 q.shape[:-1] + v.shape[-1:],
-                dtype=self._cached_o_data_type,
+                dtype=out_dtype,
                 device=q.device,
             )
         else:
@@ -2996,6 +3010,10 @@ def run(
             plan_info=self._plan_info,
             causal=self._causal,
             sm_scale=sm_scale,
+            q_scale=q_scale,
+            k_scale=k_scale,
+            v_scale=v_scale,
+            o_scale=o_scale,
             max_qo_len=self._max_qo_len,
             out=out,
             lse=lse,
@@ -3148,6 +3166,10 @@ def fmha_varlen(
     lse: Optional[torch.Tensor] = None,
     causal: bool = False,
     sm_scale: Optional[float] = None,
+    q_scale: Optional[float] = None,
+    k_scale: Optional[float] = None,
+    v_scale: Optional[float] = None,
+    o_scale: Optional[float] = None,
     return_lse: Literal[False] = False,
 ) -> torch.Tensor: ...
 
@@ -3165,6 +3187,10 @@ def fmha_varlen(
     lse: Optional[torch.Tensor] = None,
     causal: bool = False,
     sm_scale: Optional[float] = None,
+    q_scale: Optional[float] = None,
+    k_scale: Optional[float] = None,
+    v_scale: Optional[float] = None,
+    o_scale: Optional[float] = None,
     return_lse: Literal[True] = True,
 ) -> Tuple[torch.Tensor, torch.Tensor]: ...
 
@@ -3181,6 +3207,10 @@ def fmha_varlen(
     lse: Optional[torch.Tensor] = None,
     causal: bool = False,
     sm_scale: Optional[float] = None,
+    q_scale: Optional[float] = None,
+    k_scale: Optional[float] = None,
+    v_scale: Optional[float] = None,
+    o_scale: Optional[float] = None,
     return_lse: bool = False,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     workspace_buffer = _get_cache_buf(
@@ -3205,6 +3235,14 @@ def fmha_varlen(
     mask_mode_code = 1 if causal else 0
     if sm_scale is None:
         sm_scale = 1.0 / math.sqrt(head_dim_qk)
+    if q_scale is None:
+        q_scale = 1.0
+    if k_scale is None:
+        k_scale = 1.0
+    if v_scale is None:
+        v_scale = 1.0
+    if o_scale is None:
+        o_scale = 1.0
 
     qo_total_len = nnz_qo
     if max_qo_len is None:
@@ -3223,12 +3261,14 @@ def fmha_varlen(
     ) = plan_info
 
     if out is None:
+        # when input dtype is fp8, we need to use bf16 output
+        out_dtype = torch.bfloat16 if q.dtype.itemsize == 1 else q.dtype
         out = torch.empty(
             qo_total_len + max(max_qo_len, 128),
             num_qo_heads,
             head_dim_vo,
             device=q.device,
-            dtype=q.dtype,
+            dtype=out_dtype,
         )[max(max_qo_len, 128) :]
 
     if lse is None and return_lse:
@@ -3251,10 +3291,10 @@ def fmha_varlen(
         lse,
         mask_mode_code,
         sm_scale,
-        num_qo_heads,
-        num_kv_heads,
-        head_dim_qk,
-        head_dim_vo,
+        q_scale,
+        k_scale,
+        v_scale,
+        o_scale,
         max_qo_len,
     )
 
diff --git a/include/flashinfer/attention/blackwell/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp b/include/flashinfer/attention/blackwell/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp
index 640b876b49..a973117548 100644
--- a/include/flashinfer/attention/blackwell/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp
+++ b/include/flashinfer/attention/blackwell/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp
@@ -64,9 +64,10 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
   using Mask = Mask_;
 
   static constexpr int StageCountQ = 2;
-  static constexpr int StageCountKV = (get<2>(TileShapeQK{}) == 128 || get<2>(TileShapeQK{}) == 64)
-                                          ? 2
-                                          : 1;  // sizeof(Element_) == 1 ? 2 : 2;
+  static constexpr int StageCountKV =
+      (sizeof(Element_) == 1)
+          ? (get<2>(TileShapeQK{}) == 128 ? 4 : 2)
+          : (get<2>(TileShapeQK{}) == 128 || get<2>(TileShapeQK{}) == 64 ? 2 : 1);
 
   using StagesQ = cutlass::gemm::collective::StageCount<StageCountQ>;
   using StagesKV = cutlass::gemm::collective::StageCount<StageCountKV>;
diff --git a/include/flashinfer/attention/blackwell/fmha_cutlass_sm100.cuh b/include/flashinfer/attention/blackwell/fmha_cutlass_sm100.cuh
index 66d8b4fb91..b92de63407 100644
--- a/include/flashinfer/attention/blackwell/fmha_cutlass_sm100.cuh
+++ b/include/flashinfer/attention/blackwell/fmha_cutlass_sm100.cuh
@@ -76,7 +76,8 @@ struct FwdRunner {
                          IdType* qo_segment_offsets, IdType* kv_segment_offsets,
                          IdType* work_indptr, IdType* qo_tile_indices, IdType* qo_head_indices,
                          IdType* batch_indices, DTypeOut* o, float* maybe_lse, int mask_mode_code,
-                         double sm_scale, int num_qo_heads, int num_kv_heads, int head_dim_qk,
+                         double sm_scale, double q_scale, double k_scale, double v_scale,
+                         double o_scale, int num_qo_heads, int num_kv_heads, int head_dim_qk,
                          int head_dim_vo, int q_stride_n, int q_stride_h, int k_stride_n,
                          int k_stride_h, int v_stride_n, int v_stride_h, int batch_size,
                          int total_qo_len, int total_kv_len, int max_qo_len, cudaStream_t stream) {
@@ -120,7 +121,7 @@
 
     typename Operation::Arguments arguments{
         problem_shape,
-        {q, layout_Q, k, layout_K, v, layout_V, sm_scale},
+        {q, layout_Q, k, layout_K, v, layout_V, sm_scale, q_scale, k_scale, v_scale, o_scale},
         {o - max_qo_len * get<0>(stride_O), layout_O, maybe_lse, layout_LSE, max_qo_len},
         {work_indptr, qo_tile_indices, qo_head_indices, batch_indices},
         hw_info};
@@ -163,16 +164,17 @@ cudaError_t run_fmha_fwd(void* workspace_buffer, DTypeIn* q, DTypeIn* k, DTypeIn
                          IdType* qo_segment_offsets, IdType* kv_segment_offsets,
                          IdType* work_indptr, IdType* qo_tile_indices, IdType* qo_head_indices,
                          IdType* batch_indices, DTypeOut* o, float* maybe_lse, int mask_mode_code,
-                         double sm_scale, int num_qo_heads, int num_kv_heads, int head_dim_qk,
+                         double sm_scale, double q_scale, double k_scale, double v_scale,
+                         double o_scale, int num_qo_heads, int num_kv_heads, int head_dim_qk,
                          int head_dim_vo, int q_stride_n, int q_stride_h, int k_stride_n,
                          int k_stride_h, int v_stride_n, int v_stride_h, int batch_size,
                          int total_qo_len, int total_kv_len, int max_qo_len, cudaStream_t stream) {
   return FwdRunner::run(
       workspace_buffer, q, k, v, qo_segment_offsets, kv_segment_offsets, work_indptr,
       qo_tile_indices, qo_head_indices, batch_indices, o, maybe_lse, mask_mode_code, sm_scale,
-      num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo, q_stride_n, q_stride_h, k_stride_n,
-      k_stride_h, v_stride_n, v_stride_h, batch_size, total_qo_len, total_kv_len, max_qo_len,
-      stream);
+      q_scale, k_scale, v_scale, o_scale, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo,
+      q_stride_n, q_stride_h, k_stride_n, k_stride_h, v_stride_n, v_stride_h, batch_size,
+      total_qo_len, total_kv_len, max_qo_len, stream);
 }
 
 };  // namespace flashinfer
diff --git a/include/flashinfer/attention/blackwell/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp b/include/flashinfer/attention/blackwell/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp
index d6e913a319..ebf38d7347 100644
--- a/include/flashinfer/attention/blackwell/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp
+++ b/include/flashinfer/attention/blackwell/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp
@@ -66,8 +66,8 @@ struct Sm100FmhaCtxKernelWarpspecializedSchedule {
   static const bool kDebugUsingPrintf = false;
 
   static const int NumRegsSoftmax = 192;
-  static const int NumRegsCorrection = 96 - (kDebugUsingPrintf ? 16 : 0);
-  static const int NumRegsOther = 32 + (kDebugUsingPrintf ? 16 : 0);
+  static const int NumRegsCorrection = 64;  // 96 - (kDebugUsingPrintf ? 16 : 0);
+  static const int NumRegsOther = 64;       // 32 + (kDebugUsingPrintf ? 16 : 0);
   static const int NumRegsEmpty = 24;
 
   static const int NumWarps = 16;
diff --git a/tests/attention/test_blackwell_fmha.py b/tests/attention/test_blackwell_fmha.py
index 298bfa5db4..39084d3687 100644
--- a/tests/attention/test_blackwell_fmha.py
+++ b/tests/attention/test_blackwell_fmha.py
@@ -347,38 +347,183 @@ def test_blackwell_cutlass_qo_kv_varlen(
     torch.testing.assert_close(lse, lse_ref, rtol=1e-3, atol=1e-3)
 
 
-if __name__ == "__main__":
-    test_blackwell_cutlass_fmha(
-        9,
-        377,
-        977,
-        1,
-        1,
-        192,
-        128,
-        1,
-        False,
-        torch.bfloat16,
+@pytest.mark.parametrize("batch_size", [1, 2, 9, 12])
+@pytest.mark.parametrize("qo_len", [177, 377])
+@pytest.mark.parametrize("kv_len", [544, 977])
+@pytest.mark.parametrize(
+    "num_qo_heads,num_kv_heads",
+    [
+        (128, 128),  # DeepSeek-R1 MHA (Multi-head Attention for Prefill)
+    ],
+)
+@pytest.mark.parametrize(
+    "head_dim_qk,head_dim_vo,sm_scale",
+    [
+        (
+            192,
+            128,
+            1.0 / math.sqrt(192),
+        ),  # DeepSeek-R1: qk_nope(128) + qk_rope(64) = 192, v=128
+    ],
+)
+@pytest.mark.parametrize("causal", [False, True])
+def test_blackwell_cutlass_fmha_fp8(
+    batch_size,
+    qo_len,
+    kv_len,
+    num_qo_heads,
+    num_kv_heads,
+    head_dim_qk,
+    head_dim_vo,
+    sm_scale,
+    causal,
+):
+    if qo_len > kv_len and causal:
+        pytest.skip("qo_len > kv_len and causal is not supported")
+
+    if not is_sm100a_supported(torch.device("cuda")) and not is_sm110a_supported(
+        torch.device("cuda")
+    ):
+        pytest.skip("only SM100A and SM110A are supported on this device")
+
+    torch.manual_seed(42)
+    dtype_in = torch.float8_e4m3fn
+    dtype_out = torch.bfloat16
+
+    # Create FP8 tensors by generating half precision then converting
+    q = torch.randn(
+        batch_size * qo_len, num_qo_heads, head_dim_qk, dtype=torch.half, device="cuda"
+    ).to(dtype_in)
+    qo_indptr = (
+        torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * qo_len
+    )
+    k = torch.randn(
+        batch_size * kv_len, num_kv_heads, head_dim_qk, dtype=torch.half, device="cuda"
+    ).to(dtype_in)
+    v = torch.randn(
+        batch_size * kv_len, num_kv_heads, head_dim_vo, dtype=torch.half, device="cuda"
+    ).to(dtype_in)
+    kv_indptr = (
+        torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * kv_len
+    )
+
+    wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
+        torch.empty(128 * 1024 * 1024, device="cuda", dtype=torch.uint8),
+        kv_layout="NHD",
+        backend="cutlass",
+    )
+    wrapper.plan(
+        qo_indptr,
+        kv_indptr,
+        num_qo_heads,
+        num_kv_heads,
+        head_dim_qk,
+        head_dim_vo=head_dim_vo,
+        causal=causal,
+        sm_scale=sm_scale,
+        q_data_type=dtype_in,
+        kv_data_type=dtype_in,
+        o_data_type=dtype_out,
     )
+    o, lse = wrapper.run(q, k, v, return_lse=True)
 
-    test_blackwell_cutlass_varlen(
-        [0, 1274, 2568, 3915, 5194, 6498, 7839, 8192],
-        32,
-        4,
-        128,
-        128,
-        1,
-        True,
-        torch.bfloat16,
+    # Verify output is bfloat16
+    assert o.dtype == dtype_out, f"Expected output dtype {dtype_out}, got {o.dtype}"
+
+    gqa_group_ratio = num_qo_heads // num_kv_heads
+    k_repeated = torch.repeat_interleave(k, gqa_group_ratio, dim=1)
+    v_repeated = torch.repeat_interleave(v, gqa_group_ratio, dim=1)
+
+    # Reference implementation with FP8 inputs, upcast to float32, output as bfloat16
+    qo_len_ref = q.shape[0] // batch_size
+    kv_len_ref = k_repeated.shape[0] // batch_size
+    num_qo_heads_ref = q.shape[1]
+    head_dim_qk_ref = q.shape[2]
+    head_dim_vo_ref = v_repeated.shape[2]
+
+    logits = (
+        torch.einsum(
+            "bmhd,bnhd->bhmn",
+            q.view(batch_size, qo_len_ref, num_qo_heads_ref, head_dim_qk_ref).float(),
+            k_repeated.view(
+                batch_size, kv_len_ref, num_qo_heads_ref, head_dim_qk_ref
+            ).float(),
+        )
+        * sm_scale
     )
-
-    test_blackwell_cutlass_qo_kv_varlen(
-        [0, 10, 20, 30, 40, 50, 60, 100],
-        [0, 50, 50, 50, 50, 50, 50, 50],
-        32,
-        8,
-        128,
-        128,
-        1,
-        torch.bfloat16,
+    if causal:
+        mask = torch.arange(
+            kv_len_ref - qo_len_ref, kv_len_ref, device=q.device
+        ).unsqueeze(1) >= torch.arange(0, kv_len_ref, device=q.device).unsqueeze(0)
+    else:
+        mask = torch.ones(qo_len_ref, kv_len_ref, device=q.device)
+
+    logits = logits.masked_fill(mask.unsqueeze(0).unsqueeze(0) == 0, float("-inf"))
+    lse_ref = torch.logsumexp(logits, -1).transpose(-1, -2)
+    p = torch.softmax(logits, dim=-1)
+    o_ref = (
+        torch.einsum(
+            "bhmn,bnhd->bmhd",
+            p,
+            v_repeated.view(
+                batch_size, kv_len_ref, num_qo_heads_ref, head_dim_vo_ref
+            ).float(),
+        )
+        .contiguous()
+        .view(batch_size * qo_len_ref, num_qo_heads_ref, head_dim_vo_ref)
+        .to(dtype_out)  # Convert to bfloat16 for FP8 output
+    )
+    lse_ref = (lse_ref * math.log2(math.e)).flatten(0, 1)
+
+    # FP8 has lower precision, use relaxed tolerances
+    torch.testing.assert_close(o, o_ref, rtol=5e-2, atol=5e-2)
+    torch.testing.assert_close(lse, lse_ref, rtol=1e-3, atol=1e-3)
+
+
+if __name__ == "__main__":
+    test_blackwell_cutlass_fmha_fp8(
+        batch_size=9,
+        qo_len=377,
+        kv_len=977,
+        num_qo_heads=1,
+        num_kv_heads=1,
+        head_dim_qk=192,
+        head_dim_vo=128,
+        sm_scale=1,
+        causal=False,
     )
+    # test_blackwell_cutlass_fmha(
+    #     9,
+    #     377,
+    #     977,
+    #     1,
+    #     1,
+    #     192,
+    #     128,
+    #     1,
+    #     False,
+    #     torch.bfloat16,
+    # )
+
+    # test_blackwell_cutlass_varlen(
+    #     [0, 1274, 2568, 3915, 5194, 6498, 7839, 8192],
+    #     32,
+    #     4,
+    #     128,
+    #     128,
+    #     1,
+    #     True,
+    #     torch.bfloat16,
+    # )
+
+    # test_blackwell_cutlass_qo_kv_varlen(
+    #     [0, 10, 20, 30, 40, 50, 60, 100],
+    #     [0, 50, 50, 50, 50, 50, 50, 50],
+    #     32,
+    #     8,
+    #     128,
+    #     128,
+    #     1,
+    #     torch.bfloat16,
+    # )
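
For reference, a minimal usage sketch of the FP8 path this patch enables. It is not part of the diff: the shapes, the `randn_fp8` helper, and the variable names are illustrative assumptions, while the wrapper class, the `o_data_type` plan argument, and the new `q_scale`/`k_scale`/`v_scale`/`o_scale` keywords of `run()` are taken from the changes above.

import torch
import flashinfer

# Assumed example shapes (DeepSeek-R1-style prefill, equal qo/kv lengths).
batch_size, seq_len = 2, 1024
num_qo_heads = num_kv_heads = 128
head_dim_qk, head_dim_vo = 192, 128


def randn_fp8(n, h, d):
    # FP8 tensors are created from half precision, as in the benchmark above.
    return torch.randn(n, h, d, dtype=torch.half, device="cuda").to(torch.float8_e4m3fn)


q = randn_fp8(batch_size * seq_len, num_qo_heads, head_dim_qk)
k = randn_fp8(batch_size * seq_len, num_kv_heads, head_dim_qk)
v = randn_fp8(batch_size * seq_len, num_kv_heads, head_dim_vo)
indptr = torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * seq_len

wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
    torch.empty(128 * 1024 * 1024, device="cuda", dtype=torch.uint8),
    kv_layout="NHD",
    backend="cutlass",
)
wrapper.plan(
    indptr,
    indptr,
    num_qo_heads,
    num_kv_heads,
    head_dim_qk,
    head_dim_vo=head_dim_vo,
    causal=True,
    q_data_type=torch.float8_e4m3fn,
    kv_data_type=torch.float8_e4m3fn,
    o_data_type=torch.bfloat16,  # FP8 inputs produce bfloat16 output
)
# The calibration scales default to 1.0 when omitted (see prefill.py above).
o = wrapper.run(q, k, v, q_scale=1.0, k_scale=1.0, v_scale=1.0, o_scale=1.0)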