diff --git a/benchmarks/README.md b/benchmarks/README.md index d81e9c3642..162c166cee 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -117,7 +117,7 @@ The output CSV will contain detailed metrics including: | `--verbose`, `-v` | Print additional information (can be used multiple times for more verbosity, e.g. `-vv`) | | `--case_tag` | Optional tag for the test case, useful for annotating or filtering results in the output CSV. | | `--generate_repro_command`| If set, prints a reproducer command for the test case and stores it in the output CSV. | -| `--backends` | Space-separated list of backends to test, e.g. fa2, fa2_tc, fa3, cudnn, cutlass, trtllm, trtllm-gen, trtllm-native, cublas| +| `--backends` | Space-separated list of backends to test, e.g. fa2, fa2_tc, fa3, cudnn, cudnn-native, cutlass, trtllm, trtllm-gen, trtllm-native, cublas| ### Attention Flags | Flag | Description | diff --git a/benchmarks/routines/attention.py b/benchmarks/routines/attention.py index 320cfbe020..5838581a05 100644 --- a/benchmarks/routines/attention.py +++ b/benchmarks/routines/attention.py @@ -4,6 +4,22 @@ import torch import flashinfer + +# Try to import cudnn for version checking +CUDNN_AVAILABLE = False +CUDNN_BACKEND_VERSION = 0 +try: + import cudnn + + CUDNN_AVAILABLE = True + CUDNN_BACKEND_VERSION = cudnn.backend_version() +except ImportError: + pass +except OSError as e: + error_msg = str(e).lower() + is_lib_missing = any(ext in error_msg for ext in [".so", ".dll"]) + if not is_lib_missing: + raise from flashinfer.testing.utils import ( attention_tb_per_sec_with_actual_seq_lens, attention_tflops_per_sec_with_actual_seq_lens, @@ -88,6 +104,7 @@ def parse_attention_args(line, parser): "fa2_tc", "fa3", "cudnn", + "cudnn-native", "cutlass", "trtllm-gen", "trtllm-native", @@ -680,6 +697,14 @@ def testBatchPrefillWithPagedKVCacheWrapper(args): print(f"[ERROR] Unsupported kv_dtype: {args.kv_dtype}") return res + # Increase tolerances for FP8 due to lower precision + if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [ + torch.float8_e4m3fn, + torch.float8_e5m2, + ]: + rtol = 5e-1 # Relaxed relative tolerance for FP8 + atol = 1e-1 # Relaxed absolute tolerance for FP8 + # Parse and validate backend configurations backends = args.backends page_size = args.page_size @@ -706,15 +731,36 @@ def testBatchPrefillWithPagedKVCacheWrapper(args): backends.remove("fa2") if "cudnn" in backends: remove_cudnn = False + # cuDNN FP8 prefill requires cuDNN >= 9.18.0 (backend version 91800) if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [ torch.float8_e4m3fn, torch.float8_e5m2, ]: - print("[INFO] cuDNN backend does not support FP8. Skipping.") - remove_cudnn = True + if not CUDNN_AVAILABLE or CUDNN_BACKEND_VERSION < 91800: + print( + f"[INFO] cuDNN FP8 prefill requires cuDNN >= 9.18.0. " + f"Current version: {CUDNN_BACKEND_VERSION}. Skipping cudnn backend." + ) + remove_cudnn = True if remove_cudnn: backends.remove("cudnn") + if "cudnn-native" in backends: + remove_cudnn_native = False + # cuDNN-native does not yet support FP8 prefill + if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [ + torch.float8_e4m3fn, + torch.float8_e5m2, + ]: + if not CUDNN_AVAILABLE or CUDNN_BACKEND_VERSION < 91800: + print( + f"[INFO] cuDNN FP8 prefill requires cuDNN >= 9.18.0. " + f"Current version: {CUDNN_BACKEND_VERSION}. Skipping cudnn-native backend." 
+ ) + remove_cudnn_native = True + if remove_cudnn_native: + backends.remove("cudnn-native") + if "trtllm-gen" in backends: remove_trtllm = False if not causal: @@ -908,7 +954,44 @@ def testBatchPrefillWithPagedKVCacheWrapper(args): print(f"[VVERBOSE] {kv_last_page_len.shape = }") print(f"[VVERBOSE] {scale = }") - # Prepare wrappers + # Helper function to convert to FP8 (matches test_trtllm_gen_attention.py approach) + def to_float8(x, dtype=torch.float8_e4m3fn): + finfo = torch.finfo(dtype) + min_val, max_val = x.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) + scale = finfo.max / amax * 0.1 + x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) + return x_scl_sat.to(dtype), scale.float().reciprocal() + + # Compute scales and convert to FP8 if needed (before creating wrappers) + q_scale, k_scale, v_scale = None, None, None + q_scale_tensor, k_scale_tensor, v_scale_tensor = None, None, None + o_data_type = q_dtype # Default output dtype + # Separate K/V caches for cuDNN (which requires separate tensors, not combined kv_cache) + k_cache_cudnn, v_cache_cudnn = k_cache, v_cache + + if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: + q, q_scale_t = to_float8(q, q_dtype) + q_scale = q_scale_t.item() + q_scale_tensor = q_scale_t.reshape(1, 1, 1, 1) + # o_data_type stays as q_dtype (FP8 output) + + if kv_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: + # Convert k_cache and v_cache to quantized dtype for cuDNN + k_cache_cudnn, k_scale_t = to_float8(k_cache, kv_dtype) + v_cache_cudnn, v_scale_t = to_float8(v_cache, kv_dtype) + k_scale = k_scale_t.item() + v_scale = v_scale_t.item() + k_scale_tensor = k_scale_t.reshape(1, 1, 1, 1) + v_scale_tensor = v_scale_t.reshape(1, 1, 1, 1) + + # Also convert the full kv_cache for non-cuDNN backends + k_data, v_data = torch.chunk(kv_cache, 2, dim=1) + k_quantized, _ = to_float8(k_data, kv_dtype) + v_quantized, _ = to_float8(v_data, kv_dtype) + kv_cache = torch.cat([k_quantized, v_quantized], dim=1) + + # Prepare wrappers (after FP8 conversion so we have correct dtypes) backend_wrappers = {} for backend in backends: if backend in ["fa2", "fa3", "trtllm-gen"]: @@ -941,28 +1024,78 @@ def testBatchPrefillWithPagedKVCacheWrapper(args): kv_data_type=kv_dtype, block_tables=block_tables, ) - - k_scale, v_scale = None, None - if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: - q = q.to(q_dtype) - if kv_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: - k_data, v_data = torch.chunk(kv_cache, 2, dim=1) - k_scale = k_data.amax().item() / 256 - v_scale = v_data.amax().item() / 256 - k_fp8 = (k_data / k_scale).to(kv_dtype) - v_fp8 = (v_data / v_scale).to(kv_dtype) - kv_cache = torch.cat([k_fp8, v_fp8], dim=1) + elif backend == "cudnn": + # cuDNN uses NHD layout and the wrapper API + backend_wrappers[backend] = ( + flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, + "NHD", + backend="cudnn", + ) + ) + backend_wrappers["cudnn"].plan( + q_indptr, + kv_indptr, + kv_indices, + kv_last_page_len, + num_qo_heads, + num_kv_heads, + head_dim_qk, + page_size, + pos_encoding_mode="NONE", + causal=causal, + q_data_type=q_dtype, + o_data_type=o_data_type, + seq_lens=actual_seq_lens_kv_device, + seq_lens_q=actual_seq_lens_q_device, + sm_scale=scale, + max_token_per_sequence=s_qo, + max_sequence_kv=s_kv, + block_tables=block_tables, + ) def run_backend_wrapper(backend): if backend in ["fa2", "fa3", "trtllm-gen"]: return backend_wrappers[backend].run( - q, kv_cache, k_scale=k_scale, v_scale=v_scale + q, 
kv_cache, q_scale=q_scale, k_scale=k_scale, v_scale=v_scale ) elif backend == "cudnn": + # cuDNN uses wrapper API with tensor scales for FP8 + return backend_wrappers[backend].run( + q, + (k_cache_cudnn, v_cache_cudnn), + q_scale=q_scale_tensor, + k_scale=k_scale_tensor, + v_scale=v_scale_tensor, + ) + elif backend == "trtllm-native": + # Compute combined bmm1_scale: q_scale * k_scale * sm_scale + # For FP8: all scales are float values + _q_scale = q_scale if q_scale is not None else 1.0 + _k_scale = k_scale if k_scale is not None else 1.0 + _v_scale = v_scale if v_scale is not None else 1.0 + bmm1_scale = _q_scale * _k_scale * scale + bmm2_scale = _v_scale + return flashinfer.prefill.trtllm_batch_context_with_kv_cache( + query=q, + kv_cache=kv_cache, + workspace_buffer=workspace_buffer, + block_tables=block_tables, + seq_lens=actual_seq_lens_kv_device, + max_q_len=s_qo, + max_kv_len=s_kv, + bmm1_scale=bmm1_scale, + bmm2_scale=bmm2_scale, + batch_size=batch_size, + cum_seq_lens_q=qo_indptr, + cum_seq_lens_kv=kv_indptr, + ) + elif backend == "cudnn-native": + # Direct cudnn_batch_prefill_with_kv_cache call (similar to trtllm-native) return flashinfer.prefill.cudnn_batch_prefill_with_kv_cache( q, - k_cache, - v_cache, + k_cache_cudnn, + v_cache_cudnn, scale, workspace_buffer, max_token_per_sequence=s_qo, @@ -975,27 +1108,17 @@ def run_backend_wrapper(backend): is_cuda_graph_compatible=is_cuda_graph_compatible, batch_offsets_q=q_indptr, batch_offsets_o=q_indptr, + q_scale=q_scale_tensor, + k_scale=k_scale_tensor, + v_scale=v_scale_tensor, + o_data_type=o_data_type, )[0] - elif backend == "trtllm-native": - return flashinfer.prefill.trtllm_batch_context_with_kv_cache( - query=q, - kv_cache=kv_cache, - workspace_buffer=workspace_buffer, - block_tables=block_tables, - seq_lens=actual_seq_lens_kv_device, - max_q_len=s_qo, - max_kv_len=s_kv, - bmm1_scale=scale if k_scale is None else k_scale * scale, - bmm2_scale=1.0 if v_scale is None else v_scale, - batch_size=batch_size, - cum_seq_lens_q=qo_indptr, - cum_seq_lens_kv=kv_indptr, - ) else: print(f"[ERROR] Backend {backend} not supported") return res has_reference_output = False + reference_backend = None # Iterate over each backend: for cur_backend in backends: # Clear workspace buffer to prevent unexpected interactions between backends. @@ -1005,6 +1128,7 @@ def run_backend_wrapper(backend): if cur_backend == "fa2": has_reference_output = True reference_output = outputs[cur_backend] + reference_backend = "fa2" backend_times[cur_backend] = bench_gpu_time( fn=lambda: run_backend_wrapper(cur_backend), dry_run_iters=args.dry_run_iters, @@ -1020,6 +1144,22 @@ def run_backend_wrapper(backend): # Perform reference check tested_backends = list(outputs.keys()) tested_outputs = list(outputs.values()) + + # When cases where FA2 is not available, try to find an alternative reference + # Priority: cudnn > cudnn-native > trtllm-gen > trtllm-native + if run_refcheck and not has_reference_output and len(tested_backends) > 1: + reference_priority = ["cudnn", "cudnn-native", "trtllm-gen", "trtllm-native"] + for candidate in reference_priority: + if candidate in tested_backends: + has_reference_output = True + reference_backend = candidate + reference_output = outputs[candidate] + if args.verbose >= 1: + print( + f"[INFO] FA2 not available for reference. Using {candidate} as reference backend for cross-comparison." 
+ ) + break + if len(tested_backends) > 1: if run_refcheck and has_reference_output: if reference_output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: @@ -1037,7 +1177,7 @@ def run_backend_wrapper(backend): ) = is_close_stats(reference_output, tested_outputs[i], rtol, atol) if num_different_elements > 0: print( - f"[ERROR] Output tensor mismatch between backends fa2 and {tested_backends[i]}: " + f"[ERROR] Output tensor mismatch between backends {reference_backend} and {tested_backends[i]}: " f"{num_different_elements} / {num_elements} ({num_different_elements_percentage:.2f}%) elements are different" ) if not args.allow_output_mismatch: diff --git a/benchmarks/routines/flashinfer_benchmark_utils.py b/benchmarks/routines/flashinfer_benchmark_utils.py index d5f363839a..0857c9f9f8 100644 --- a/benchmarks/routines/flashinfer_benchmark_utils.py +++ b/benchmarks/routines/flashinfer_benchmark_utils.py @@ -173,14 +173,15 @@ def dtype_str_to_torch_dtype(dtype_str): }, "BatchPrefillWithPagedKVCacheWrapper": { # NOTE: trtllm-native calls trtllm_batch_context_with_kv_cache + # NOTE: cudnn-native calls cudnn_batch_prefill_with_kv_cache "7.5": [], - "8.0": ["fa2", "cudnn"], - "8.6": ["fa2", "cudnn"], - "8.9": ["fa2", "cudnn"], - "9.0": ["fa2", "fa3", "cudnn"], - "10.0": ["fa2", "cudnn", "trtllm-gen", "trtllm-native"], - "10.3": ["fa2", "cudnn", "trtllm-gen", "trtllm-native"], - "12.0": ["fa2", "cudnn"], + "8.0": ["fa2", "cudnn", "cudnn-native"], + "8.6": ["fa2", "cudnn", "cudnn-native"], + "8.9": ["fa2", "cudnn", "cudnn-native"], + "9.0": ["fa2", "fa3", "cudnn", "cudnn-native"], + "10.0": ["fa2", "cudnn", "cudnn-native", "trtllm-gen", "trtllm-native"], + "10.3": ["fa2", "cudnn", "cudnn-native", "trtllm-gen", "trtllm-native"], + "12.0": ["fa2", "cudnn", "cudnn-native"], }, "BatchPrefillWithRaggedKVCacheWrapper": { # NOTE: trtllm-native calls trtllm_ragged_attention_deepseek diff --git a/flashinfer/cudnn/prefill.py b/flashinfer/cudnn/prefill.py index b8c09a66ee..ad49a74cbd 100644 --- a/flashinfer/cudnn/prefill.py +++ b/flashinfer/cudnn/prefill.py @@ -17,9 +17,20 @@ # Global cudnn handle. 
need to make it per device in future _cudnn_handle = None +_dummy_scale_tensors: dict[torch.device, torch.Tensor] = {} + + +def _get_dummy_scale_tensor(device: torch.device): + t = _dummy_scale_tensors.get(device) + if t is None: + t = torch.tensor([1.0], device=device, dtype=torch.float32).reshape(1, 1, 1, 1) + _dummy_scale_tensors[device] = t + return t + def _create_cudnn_handle(stream: torch.cuda.Stream): global _cudnn_handle + if _cudnn_handle is None: _cudnn_handle = cudnn.create_handle() cudnn.set_stream(_cudnn_handle, stream.cuda_stream) @@ -50,6 +61,16 @@ class UIDs(Enum): O_UID = 1000 # Output tensor STATS_UID = 1001 # Stats tensor + Q_SCALE_UID = 150 # Query scale tensor + K_SCALE_UID = 151 # Key scale tensor + V_SCALE_UID = 152 # Value scale tensor + S_SCALE_UID = 153 # Scale tensor + S_DESCALE_UID = 154 # Descale tensor + O_SCALE_UID = 155 # Output scale tensor + + S_AMAX_UID = 160 # Scale amax tensor + O_AMAX_UID = 161 # Output amax tensor + def _sdpa_prefill_key_fn( q: torch.Tensor, @@ -72,6 +93,7 @@ def _sdpa_prefill_key_fn( batch_offsets_stats: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, + o_data_type: Optional[torch.dtype] = None, ): graph_b = actual_seq_lens_q.shape[0] @@ -91,6 +113,7 @@ def _sdpa_prefill_key_fn( key = ( graph_b, q.dim(), + q.dtype, k_cache.dim(), max_token_seq_q, max_sequence_kv, @@ -130,6 +153,7 @@ def _build_prefill_graph( batch_offsets_stats: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, + o_data_type: Optional[torch.dtype] = None, ): handle = _create_cudnn_handle(torch.cuda.current_stream(q.device)) @@ -137,6 +161,26 @@ def _build_prefill_graph( graph_s_qo = max_token_seq_q graph_s_kv = max_sequence_kv + if not cudnn.datatypes.is_torch_available(): + raise RuntimeError("torch is not available") + + cudnn_q_data_type = cudnn.datatypes._torch_to_cudnn_data_type(q.dtype) + cudnn_k_data_type = cudnn.datatypes._torch_to_cudnn_data_type(k_cache.dtype) + cudnn_v_data_type = cudnn.datatypes._torch_to_cudnn_data_type(v_cache.dtype) + + if o_data_type is None: + o_data_type = q.dtype + + cudnn_o_data_type = cudnn.datatypes._torch_to_cudnn_data_type(o_data_type) + + if ( + cudnn_q_data_type == cudnn.data_type.FP8_E4M3 + or cudnn_q_data_type == cudnn.data_type.FP8_E5M2 + ) and cudnn.backend_version() < 91800: + raise RuntimeError( + f"FP8 is not supported in cuDNN backend version < 9.18.0, current version is {cudnn.backend_version()}" + ) + with cudnn.graph(handle) as (g, _): # Create tensors from the input tensors if q.dim() == 3: @@ -150,9 +194,62 @@ def _build_prefill_graph( name="q", dim=(graph_b, h_qo, graph_s_qo, d_qk), stride=(h_qo * d_qk, d_qk, d_qk * h_qo, 1), - data_type=cudnn.data_type.BFLOAT16, + data_type=cudnn_q_data_type, ) + if ( + cudnn_q_data_type == cudnn.data_type.FP8_E4M3 + or cudnn_q_data_type == cudnn.data_type.FP8_E5M2 + ): + cudnn_q_scale = g.tensor( + name="q_scale", + dim=(1, 1, 1, 1), + stride=(1, 1, 1, 1), + data_type=cudnn.data_type.FLOAT, + ) + + cudnn_k_scale = g.tensor( + name="k_scale", + dim=(1, 1, 1, 1), + stride=(1, 1, 1, 1), + data_type=cudnn.data_type.FLOAT, + ) + + cudnn_v_scale = g.tensor( + name="v_scale", + dim=(1, 1, 1, 1), + stride=(1, 1, 1, 1), + data_type=cudnn.data_type.FLOAT, + ) + + cudnn_s_scale = g.tensor( + name="s_scale", + dim=(1, 1, 1, 1), + stride=(1, 1, 1, 1), + data_type=cudnn.data_type.FLOAT, + ) + + cudnn_s_descale = g.tensor( + name="s_descale", + dim=(1, 1, 1, 1), + stride=(1, 1, 1, 
1), + data_type=cudnn.data_type.FLOAT, + ) + + cudnn_o_scale = g.tensor( + name="o_scale", + dim=(1, 1, 1, 1), + stride=(1, 1, 1, 1), + data_type=cudnn.data_type.FLOAT, + ) + + cudnn_q_scale.set_uid(UIDs.Q_SCALE_UID.value) + cudnn_k_scale.set_uid(UIDs.K_SCALE_UID.value) + cudnn_v_scale.set_uid(UIDs.V_SCALE_UID.value) + cudnn_s_scale.set_uid(UIDs.S_SCALE_UID.value) + cudnn_s_descale.set_uid(UIDs.S_DESCALE_UID.value) + cudnn_o_scale.set_uid(UIDs.O_SCALE_UID.value) + if batch_offsets_q is not None: ragged_q = g.tensor_like(batch_offsets_q) ragged_q.set_uid(UIDs.RAGGED_Q_UID.value) @@ -176,7 +273,7 @@ def _build_prefill_graph( name="k_cache", dim=(graph_b, h_kv, graph_s_kv, d_qk), stride=(h_kv * d_qk * graph_s_kv, d_qk, d_qk * h_kv, 1), - data_type=cudnn.data_type.BFLOAT16, + data_type=cudnn_k_data_type, ) if batch_offsets_k is not None: @@ -188,7 +285,7 @@ def _build_prefill_graph( name="v_cache", dim=(graph_b, h_kv, graph_s_kv, d_vo), stride=(h_kv * d_vo * graph_s_kv, d_vo, d_vo * h_kv, 1), - data_type=cudnn.data_type.BFLOAT16, + data_type=cudnn_v_data_type, ) if batch_offsets_v is not None: @@ -201,14 +298,14 @@ def _build_prefill_graph( name="k_cache", dim=k_cache.shape, stride=k_cache.stride(), - data_type=cudnn.data_type.BFLOAT16, + data_type=cudnn_k_data_type, ) cudnn_v_cache = g.tensor( name="v_cache", dim=v_cache.shape, stride=v_cache.stride(), - data_type=cudnn.data_type.BFLOAT16, + data_type=cudnn_v_data_type, ) cudnn_q.set_uid(UIDs.Q_UID.value) @@ -239,32 +336,86 @@ def _build_prefill_graph( actual_seq_lens_q is not None and actual_seq_lens_kv is not None ) - O, Stats = g.sdpa( - name="sdpa", - q=cudnn_q, - k=cudnn_k_cache, - v=cudnn_v_cache, - seq_len_q=( - cudnn_actual_seq_lens_q if actual_seq_lens_q is not None else None - ), - seq_len_kv=( - cudnn_actual_seq_lens_kv if actual_seq_lens_kv is not None else None - ), - use_padding_mask=padding_mask, - attn_scale=scale, - generate_stats=return_lse, - use_causal_mask_bottom_right=bottom_right_causal_mask, - paged_attention_k_table=( - cudnn_k_block_tables if block_tables is not None else None - ), - paged_attention_v_table=( - cudnn_v_block_tables if block_tables is not None else None - ), - paged_attention_max_seq_len_kv=( - graph_s_kv if block_tables is not None else None - ), - compute_data_type=cudnn.data_type.FLOAT, - ) + if ( + cudnn_q_data_type == cudnn.data_type.BFLOAT16 + or cudnn_q_data_type == cudnn.data_type.HALF + ): + O, Stats = g.sdpa( + name="sdpa", + q=cudnn_q, + k=cudnn_k_cache, + v=cudnn_v_cache, + seq_len_q=( + cudnn_actual_seq_lens_q + if actual_seq_lens_q is not None + else None + ), + seq_len_kv=( + cudnn_actual_seq_lens_kv + if actual_seq_lens_kv is not None + else None + ), + use_padding_mask=padding_mask, + attn_scale=scale, + generate_stats=return_lse, + use_causal_mask_bottom_right=bottom_right_causal_mask, + paged_attention_k_table=( + cudnn_k_block_tables if block_tables is not None else None + ), + paged_attention_v_table=( + cudnn_v_block_tables if block_tables is not None else None + ), + paged_attention_max_seq_len_kv=( + graph_s_kv if block_tables is not None else None + ), + compute_data_type=cudnn.data_type.FLOAT, + ) + + elif ( + cudnn_q_data_type == cudnn.data_type.FP8_E4M3 + or cudnn_q_data_type == cudnn.data_type.FP8_E5M2 + ): + O, Stats, amax_s, amax_o = g.sdpa_fp8( + q=cudnn_q, + k=cudnn_k_cache, + v=cudnn_v_cache, + descale_q=cudnn_q_scale, + descale_k=cudnn_k_scale, + descale_v=cudnn_v_scale, + scale_s=cudnn_s_scale, + descale_s=cudnn_s_descale, + scale_o=cudnn_o_scale, + 
generate_stats=True, + attn_scale=scale, + use_causal_mask_bottom_right=bottom_right_causal_mask, + use_padding_mask=padding_mask, + seq_len_q=( + cudnn_actual_seq_lens_q + if actual_seq_lens_q is not None + else None + ), + seq_len_kv=( + cudnn_actual_seq_lens_kv + if actual_seq_lens_kv is not None + else None + ), + paged_attention_k_table=( + cudnn_k_block_tables if block_tables is not None else None + ), + paged_attention_v_table=( + cudnn_v_block_tables if block_tables is not None else None + ), + paged_attention_max_seq_len_kv=( + graph_s_kv if block_tables is not None else None + ), + ) + + amax_s.set_uid(UIDs.S_AMAX_UID.value).set_output(False).set_dim( + (1, 1, 1, 1) + ).set_stride((1, 1, 1, 1)).set_data_type(cudnn.data_type.FLOAT) + amax_o.set_uid(UIDs.O_AMAX_UID.value).set_output(False).set_dim( + (1, 1, 1, 1) + ).set_stride((1, 1, 1, 1)).set_data_type(cudnn.data_type.FLOAT) if batch_offsets_o is not None: ragged_o = g.tensor_like(batch_offsets_o) @@ -280,7 +431,7 @@ def _build_prefill_graph( [graph_b, h_qo, graph_s_qo, d_vo] ).set_stride( [graph_s_qo * d_vo * h_qo, d_vo, d_vo * h_qo, 1] - ).set_data_type(cudnn.data_type.BFLOAT16) + ).set_data_type(cudnn_o_data_type) if return_lse: Stats.set_uid(UIDs.STATS_UID.value).set_output( @@ -315,6 +466,9 @@ def _batch_prefill_with_kv_cache( block_tables: Optional[torch.Tensor] = None, causal: bool, return_lse: bool, + q_scale: Optional[torch.Tensor] = None, + k_scale: Optional[torch.Tensor] = None, + v_scale: Optional[torch.Tensor] = None, batch_offsets_q: Optional[torch.Tensor] = None, batch_offsets_o: Optional[torch.Tensor] = None, batch_offsets_k: Optional[torch.Tensor] = None, @@ -322,6 +476,7 @@ def _batch_prefill_with_kv_cache( batch_offsets_stats: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, + o_data_type: Optional[torch.dtype] = None, ) -> tuple[torch.Tensor, torch.Tensor]: graph, tensors = _build_prefill_graph( q=q, @@ -342,6 +497,7 @@ def _batch_prefill_with_kv_cache( batch_offsets_stats=batch_offsets_stats, out=out, lse=lse, + o_data_type=o_data_type, ) var_map = { @@ -375,6 +531,17 @@ def _batch_prefill_with_kv_cache( if batch_offsets_stats is not None: var_map[UIDs.RAGGED_STATS_UID.value] = batch_offsets_stats + if q_scale is not None: + dummy_scale_tensor = _get_dummy_scale_tensor(q.device) + var_map[UIDs.Q_SCALE_UID.value] = q_scale + var_map[UIDs.S_SCALE_UID.value] = dummy_scale_tensor + var_map[UIDs.S_DESCALE_UID.value] = dummy_scale_tensor + var_map[UIDs.O_SCALE_UID.value] = dummy_scale_tensor + if k_scale is not None: + var_map[UIDs.K_SCALE_UID.value] = k_scale + if v_scale is not None: + var_map[UIDs.V_SCALE_UID.value] = v_scale + handle = _create_cudnn_handle(torch.cuda.current_stream(q.device)) graph.execute(var_map, workspace=workspace_buffer, handle=handle) @@ -399,6 +566,9 @@ def cudnn_batch_prefill_with_kv_cache( block_tables: Optional[torch.Tensor] = None, causal: bool, return_lse: bool, + q_scale: Optional[torch.Tensor] = None, + k_scale: Optional[torch.Tensor] = None, + v_scale: Optional[torch.Tensor] = None, batch_offsets_q: Optional[torch.Tensor] = None, batch_offsets_o: Optional[torch.Tensor] = None, batch_offsets_k: Optional[torch.Tensor] = None, @@ -408,6 +578,7 @@ def cudnn_batch_prefill_with_kv_cache( lse: Optional[torch.Tensor] = None, is_cuda_graph_compatible: bool = False, backend: Optional[str] = None, + o_data_type: Optional[torch.dtype] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """Performs batched prefill attention with 
paged KV cache using cuDNN. @@ -427,11 +598,14 @@ def cudnn_batch_prefill_with_kv_cache( out: Optional pre-allocated output tensor lse: Optional pre-allocated tensor for log-sum-exp values if return_lse is True else returns None is_cuda_graph_compatible: Whether the prefill operation is compatible with CUDA graph + q_scale: Optional scale tensor for query tensor of shape (1, 1, 1, 1) on GPU + k_scale: Optional scale tensor for key tensor of shape (1, 1, 1, 1) on GPU + v_scale: Optional scale tensor for value tensor of shape (1, 1, 1, 1) on GPU batch_offsets_q: Optional batch offsets for query tensor of shape (batch_size,) on GPU batch_offsets_o: Optional batch offsets for output tensor of shape (batch_size,) on GPU batch_offsets_k: Optional batch offsets for key tensor of shape (batch_size,) on GPU batch_offsets_v: Optional batch offsets for value tensor of shape (batch_size,) on GPU - + o_data_type: Optional data type for output tensor Returns: Output tensor of shape (batch_size * seq_len_q, num_heads_qo, head_dim) If return_lse is True, also returns log-sum-exp tensor of shape (batch_size, seq_len_q, num_heads_qo) @@ -472,9 +646,12 @@ def cudnn_batch_prefill_with_kv_cache( "lse must have shape (num_sequences, max_token_per_sequence, h_qo)" ) + if o_data_type is None: + o_data_type = q.dtype + if out is None: out_shape = (num_tokens, h_qo, d_vo) - out = torch.empty(out_shape, device=q.device, dtype=q.dtype) + out = torch.empty(out_shape, device=q.device, dtype=o_data_type) if CUDNN_AVAILABLE and backend != "cubin": return _batch_prefill_with_kv_cache( @@ -490,6 +667,9 @@ def cudnn_batch_prefill_with_kv_cache( block_tables=block_tables, causal=causal, return_lse=return_lse, + q_scale=q_scale, + k_scale=k_scale, + v_scale=v_scale, batch_offsets_q=batch_offsets_q, batch_offsets_o=batch_offsets_o, batch_offsets_k=batch_offsets_k, @@ -497,6 +677,7 @@ def cudnn_batch_prefill_with_kv_cache( batch_offsets_stats=batch_offsets_stats, out=out, lse=lse, + o_data_type=o_data_type, ) else: assert return_lse, "Currently only supports return_lse = True" diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 9fd1a5c0fa..a6e32a671f 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -2051,9 +2051,9 @@ def run( q: torch.Tensor, paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], *args, - q_scale: Optional[float] = None, - k_scale: Optional[float] = None, - v_scale: Optional[float] = None, + q_scale: Optional[Union[float, torch.Tensor]] = None, + k_scale: Optional[Union[float, torch.Tensor]] = None, + v_scale: Optional[Union[float, torch.Tensor]] = None, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, return_lse: bool = False, @@ -2083,9 +2083,11 @@ def run( *args Additional arguments for custom kernels. - k_scale : Optional[float] + q_scale : Optional[Union[float, torch.Tensor]] + The calibration scale of query for fp8 input, if not provided, will be set to ``1.0``. + k_scale : Optional[Union[float, torch.Tensor]] The calibration scale of key for fp8 input, if not provided, will be set to ``1.0``. - v_scale : Optional[float] + v_scale : Optional[Union[float, torch.Tensor]] The calibration scale of value for fp8 input, if not provided, will be set to ``1.0``. out : Optional[torch.Tensor] The output tensor, if not provided, will be allocated internally. 
@@ -2129,10 +2131,11 @@ def run( logits_soft_cap = 0.0 if sm_scale is None: sm_scale = 1.0 / math.sqrt(q.size(-1)) - if q_scale is not None: - sm_scale *= q_scale - if k_scale is not None: - sm_scale *= k_scale + if self._backend != "cudnn": + if q_scale is not None: + sm_scale *= q_scale + if k_scale is not None: + sm_scale *= k_scale if rope_scale is None: rope_scale = 1.0 if rope_theta is None: @@ -2150,7 +2153,7 @@ def run( if out is None: # Use cached output data type if available (for FP8 attention with FP16 output) out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype - out = torch.empty( + out = torch.zeros( q.shape[:-1] + v_cache.shape[-1:], dtype=out_dtype, device=q.device ) else: @@ -2196,10 +2199,14 @@ def run( block_tables=self._block_tables, causal=self._causal, return_lse=return_lse, + q_scale=q_scale, + k_scale=k_scale, + v_scale=v_scale, batch_offsets_q=self._qo_indptr_buf, batch_offsets_o=self._qo_indptr_buf, out=out, lse=lse, + o_data_type=out_dtype, ) else: if self._backend != "trtllm-gen": diff --git a/tests/attention/test_cudnn_prefill.py b/tests/attention/test_cudnn_prefill.py index d264db8ae4..3fab96a41a 100644 --- a/tests/attention/test_cudnn_prefill.py +++ b/tests/attention/test_cudnn_prefill.py @@ -41,7 +41,7 @@ def test_cudnn_prefill( ) cumsum_s_qo = torch.sum(actual_seq_lens_q) - q = torch.randn( + q = torch.ones( cumsum_s_qo, num_qo_heads, head_dim, device=device, dtype=torch.bfloat16 ) @@ -57,7 +57,7 @@ def test_cudnn_prefill( total_num_pages = num_pages_per_seq * batch_size kv_cache_shape = (total_num_pages, 2, num_kv_heads, page_size, head_dim) - kv_cache = torch.randn(size=kv_cache_shape, dtype=torch.bfloat16).to(device) + kv_cache = torch.ones(size=kv_cache_shape, dtype=torch.bfloat16).to(device) kv_cache = kv_cache.as_strided( kv_cache.shape, ( @@ -178,5 +178,202 @@ def test_cudnn_prefill( ) output_ref = wrapper.run(q, kv_cache) - torch.testing.assert_close(output, output_ref, atol=3e-3, rtol=1e-2) + + +@pytest.mark.parametrize("batch_size", [1, 4]) +@pytest.mark.parametrize("s_qo", [8, 17, 700]) +@pytest.mark.parametrize("s_kv", [8, 32, 1066]) +@pytest.mark.parametrize("page_size", [8, 16, 64]) +@pytest.mark.parametrize("num_kv_heads", [1, 4]) +@pytest.mark.parametrize("num_qo_heads", [4]) +@pytest.mark.parametrize("causal", [True, False]) +@pytest.mark.parametrize("return_lse", [True, False]) +@pytest.mark.parametrize("is_cuda_graph_compatible", [True]) +def test_cudnn_prefill_fp8( + batch_size, + s_qo, + s_kv, + page_size, + num_kv_heads, + num_qo_heads, + causal, + return_lse, + is_cuda_graph_compatible, +): + head_dim = 128 + if s_qo > s_kv: + pytest.skip("s_qo > s_kv, skipping test") + + # test set up basics + seed = 1 + torch.manual_seed(seed) + device = "cuda:0" + + actual_seq_lens_q = torch.randint( + 1, s_qo + 1, (batch_size, 1, 1, 1), dtype=torch.int32, device=device + ) + actual_seq_lens_kv = torch.randint( + s_qo, s_kv + 1, (batch_size, 1, 1, 1), dtype=torch.int32, device=device + ) + + cumsum_s_qo = torch.sum(actual_seq_lens_q) + q = torch.randn( + cumsum_s_qo, num_qo_heads, head_dim, device=device, dtype=torch.bfloat16 + ) + + q_scale = q.amax().item() / 256 + + q_scale = torch.tensor(q_scale, device=device, dtype=torch.float32) + q_fp8 = (q / q_scale).to(torch.float8_e4m3fn) + + q_indptr = torch.cat( + [ + torch.tensor([0], device=device), + torch.cumsum(actual_seq_lens_q.view(-1), dim=0) * head_dim * num_qo_heads, + ] + ).int() + + # Initialize KV Cache + num_pages_per_seq = (s_kv + page_size - 1) // page_size + 
total_num_pages = num_pages_per_seq * batch_size + + kv_cache_shape = (total_num_pages, 2, num_kv_heads, page_size, head_dim) + kv_cache = torch.randn(size=kv_cache_shape, dtype=torch.bfloat16).to(device) * 0.05 + kv_cache = kv_cache.as_strided( + kv_cache.shape, + ( + 2 * page_size * num_kv_heads * head_dim, + page_size * num_kv_heads * head_dim, + head_dim, + num_kv_heads * head_dim, + 1, + ), + ) + k_cache_view = kv_cache[:, 0, :, :, :] + v_cache_view = kv_cache[:, 1, :, :, :] + + v_cache = v_cache_view.as_strided( + v_cache_view.shape, + (2 * page_size * num_kv_heads * head_dim, head_dim, num_kv_heads * head_dim, 1), + ) + k_cache = k_cache_view.as_strided( + k_cache_view.shape, + (2 * page_size * num_kv_heads * head_dim, head_dim, num_kv_heads * head_dim, 1), + ) + + k_scale = k_cache.amax().item() / 256 + v_scale = v_cache.amax().item() / 256 + k_cache_fp8 = (k_cache / k_scale).to(torch.float8_e4m3fn) + v_cache_fp8 = (v_cache / v_scale).to(torch.float8_e4m3fn) + + k_scale_tensor = torch.tensor(k_scale, device=device, dtype=torch.float32) + v_scale_tensor = torch.tensor(v_scale, device=device, dtype=torch.float32) + + kv_indptr = torch.cat( + [ + torch.tensor([0], device=device), + torch.cumsum( + (actual_seq_lens_kv.flatten() + page_size - 1) // page_size, + dim=0, + ), + ] + ).int() + + # kv_indices + kv_indices = torch.zeros(kv_indptr[-1], device=device, dtype=torch.int32) + for i in range(len(kv_indptr) - 1): + start_idx = kv_indptr[i] + end_idx = kv_indptr[i + 1] + kv_indices[start_idx:end_idx] = torch.arange( + i * num_pages_per_seq, + i * num_pages_per_seq + (end_idx - start_idx), + device=device, + ) + + # kv_last_page_len + kv_last_page_len = torch.where( + actual_seq_lens_kv.flatten() % page_size == 0, + torch.full((batch_size,), page_size, device=device), + actual_seq_lens_kv.flatten() % page_size, + ).int() + + # Now initialize the page tables + block_tables = torch.tensor( + [ + [k + i * num_pages_per_seq for k in range(num_pages_per_seq)] + for i in range(batch_size) + ], + dtype=torch.int, + device=device, + ) + + # Initialize scale + scale = float(1.0 / (head_dim**0.5)) + + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device) + + wrapper_cudnn = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, "NHD", backend="cudnn" + ) + wrapper_cudnn.plan( + q_indptr, + kv_indptr, + kv_indices, + kv_last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + pos_encoding_mode="NONE", + causal=causal, + q_data_type=torch.float8_e4m3fn, + o_data_type=torch.bfloat16, + seq_lens=actual_seq_lens_kv, + seq_lens_q=actual_seq_lens_q, + sm_scale=scale, + max_token_per_sequence=s_qo, + max_sequence_kv=s_kv, + block_tables=block_tables, + ) + + output = wrapper_cudnn.run( + q_fp8, + (k_cache_fp8, v_cache_fp8), + q_scale=q_scale, + k_scale=k_scale_tensor, + v_scale=v_scale_tensor, + ) + + qo_indptr = torch.cat( + [ + torch.tensor([0], device=device), + torch.cumsum(actual_seq_lens_q.view(-1), dim=0), + ] + ).int() + + # Workspace buffer + workspace_buffer_ref = torch.empty( + 128 * 1024 * 1024, dtype=torch.int8, device=device + ) + + wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer_ref, "HND", backend="fa2" + ) + wrapper.plan( + qo_indptr, + kv_indptr, + kv_indices, + kv_last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + pos_encoding_mode="NONE", + causal=causal, + q_data_type=torch.bfloat16, + ) + + output_ref = wrapper.run(q, kv_cache) + + torch.testing.assert_close(output, 
output_ref, atol=1e-2, rtol=1e-2)
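
A condensed usage sketch of the FP8 cuDNN prefill path exercised above, assuming the paged-cache layout and index tensors are built exactly as in `test_cudnn_prefill_fp8`; the quantization helper mirrors `to_float8` from `benchmarks/routines/attention.py`, and the `plan`/`run` arguments follow the wrapper calls shown in the diff (this is an illustrative sketch, not part of the patch itself):

```python
import torch
import flashinfer


def to_float8(x, dtype=torch.float8_e4m3fn):
    # Per-tensor FP8 quantization: returns the quantized tensor and the
    # dequantization (descale) factor, as in benchmarks/routines/attention.py.
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    scale = finfo.max / amax * 0.1
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype), scale.float().reciprocal()


def run_cudnn_fp8_prefill(
    q, k_cache, v_cache, workspace_buffer,
    q_indptr, kv_indptr, kv_indices, kv_last_page_len,
    actual_seq_lens_q, actual_seq_lens_kv, block_tables,
    num_qo_heads, num_kv_heads, head_dim, page_size, s_qo, s_kv,
):
    # Quantize Q and the separate K/V cache views, then run the cuDNN backend.
    q_fp8, q_descale = to_float8(q)
    k_fp8, k_descale = to_float8(k_cache)
    v_fp8, v_descale = to_float8(v_cache)

    wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
        workspace_buffer, "NHD", backend="cudnn"
    )
    wrapper.plan(
        q_indptr, kv_indptr, kv_indices, kv_last_page_len,
        num_qo_heads, num_kv_heads, head_dim, page_size,
        pos_encoding_mode="NONE",
        causal=True,
        q_data_type=torch.float8_e4m3fn,
        o_data_type=torch.bfloat16,  # FP8 inputs, BF16 output
        seq_lens=actual_seq_lens_kv,
        seq_lens_q=actual_seq_lens_q,
        sm_scale=1.0 / head_dim**0.5,
        max_token_per_sequence=s_qo,
        max_sequence_kv=s_kv,
        block_tables=block_tables,
    )
    # cuDNN consumes the dequantization scales as (1, 1, 1, 1) float32 GPU tensors.
    return wrapper.run(
        q_fp8,
        (k_fp8, v_fp8),
        q_scale=q_descale.reshape(1, 1, 1, 1),
        k_scale=k_descale.reshape(1, 1, 1, 1),
        v_scale=v_descale.reshape(1, 1, 1, 1),
    )
```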
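
The two FP8 paths in the benchmark consume the descales differently: `trtllm_batch_context_with_kv_cache` folds the scalar descales into its GEMM scales, while the cuDNN wrapper takes per-tensor descale tensors and receives the softmax scale separately through `plan`. A small sketch of the arithmetic used above (the numeric values are illustrative assumptions):

```python
# Scalar descales as returned by to_float8(...).item(); values are illustrative.
q_scale, k_scale, v_scale = 0.021, 0.017, 0.019
sm_scale = 1.0 / 128**0.5  # 1 / sqrt(head_dim)

# trtllm-native: descales are folded into the two BMM scales.
bmm1_scale = q_scale * k_scale * sm_scale  # applied to Q @ K^T before softmax
bmm2_scale = v_scale                       # applied to P @ V

# cuDNN: the same descales are passed to run() as (1, 1, 1, 1) float32 GPU
# tensors (q_scale/k_scale/v_scale), and sm_scale goes to plan(..., sm_scale=...).
```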