2 changes: 1 addition & 1 deletion benchmarks/README.md
@@ -117,7 +117,7 @@ The output CSV will contain detailed metrics including:
| `--verbose`, `-v` | Print additional information (can be used multiple times for more verbosity, e.g. `-vv`) |
| `--case_tag` | Optional tag for the test case, useful for annotating or filtering results in the output CSV. |
| `--generate_repro_command`| If set, prints a reproducer command for the test case and stores it in the output CSV. |
| `--backends` | Space-separated list of backends to test, e.g. fa2, fa2_tc, fa3, cudnn, cutlass, trtllm, trtllm-gen, trtllm-native, cublas|
| `--backends` | Space-separated list of backends to test, e.g. fa2, fa2_tc, fa3, cudnn, cudnn-native, cutlass, trtllm, trtllm-gen, trtllm-native, cublas|

### Attention Flags
| Flag | Description |
206 changes: 173 additions & 33 deletions benchmarks/routines/attention.py
@@ -4,6 +4,22 @@
import torch

import flashinfer

# Try to import cudnn for version checking
CUDNN_AVAILABLE = False
CUDNN_BACKEND_VERSION = 0
try:
import cudnn

CUDNN_AVAILABLE = True
CUDNN_BACKEND_VERSION = cudnn.backend_version()
except ImportError:
pass
except OSError as e:
error_msg = str(e).lower()
is_lib_missing = any(ext in error_msg for ext in [".so", ".dll"])
if not is_lib_missing:
raise
from flashinfer.testing.utils import (
attention_tb_per_sec_with_actual_seq_lens,
attention_tflops_per_sec_with_actual_seq_lens,
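
For reference, the 91800 threshold used by the version checks later in this file corresponds to cuDNN 9.18.0: cudnn.backend_version() returns the version packed as a single integer. A minimal sketch of that mapping, assuming the usual major*10000 + minor*100 + patch packing (the helper name is hypothetical):

def cudnn_version_string(v: int) -> str:
    # 91800 -> "9.18.0"
    return f"{v // 10000}.{(v // 100) % 100}.{v % 100}"

assert cudnn_version_string(91800) == "9.18.0"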
@@ -88,6 +104,7 @@ def parse_attention_args(line, parser):
"fa2_tc",
"fa3",
"cudnn",
"cudnn-native",
"cutlass",
"trtllm-gen",
"trtllm-native",
@@ -680,6 +697,14 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
print(f"[ERROR] Unsupported kv_dtype: {args.kv_dtype}")
return res

# Increase tolerances for FP8 due to lower precision
if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [
torch.float8_e4m3fn,
torch.float8_e5m2,
]:
rtol = 5e-1 # Relaxed relative tolerance for FP8
atol = 1e-1 # Relaxed absolute tolerance for FP8

# Parse and validate backend configurations
backends = args.backends
page_size = args.page_size
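
The relaxed tolerances above reflect how coarse FP8 is: E4M3 keeps only 3 mantissa bits, so even a single quantize/dequantize round trip loses a few percent of relative precision, and the error compounds through the two GEMMs and the softmax. A minimal sketch of the per-element round-trip error, assuming a PyTorch build with float8 casts:

import torch

finfo = torch.finfo(torch.float8_e4m3fn)
print(finfo.max, finfo.eps)  # 448.0, 0.125 -> roughly two decimal digits of precision
x = torch.randn(4096)
x_rt = x.to(torch.float8_e4m3fn).to(torch.float32)
mask = x.abs() > 0.1  # ignore near-zero values that land in the subnormal range
rel_err = ((x_rt - x).abs()[mask] / x.abs()[mask]).max()
print(f"max relative error after one FP8 round trip: {rel_err:.3f}")  # about finfo.eps / 2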
@@ -706,15 +731,36 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
backends.remove("fa2")
if "cudnn" in backends:
remove_cudnn = False
# cuDNN FP8 prefill requires cuDNN >= 9.18.0 (backend version 91800)
if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [
torch.float8_e4m3fn,
torch.float8_e5m2,
]:
print("[INFO] cuDNN backend does not support FP8. Skipping.")
remove_cudnn = True
if not CUDNN_AVAILABLE or CUDNN_BACKEND_VERSION < 91800:
print(
f"[INFO] cuDNN FP8 prefill requires cuDNN >= 9.18.0. "
f"Current version: {CUDNN_BACKEND_VERSION}. Skipping cudnn backend."
)
remove_cudnn = True
if remove_cudnn:
backends.remove("cudnn")

if "cudnn-native" in backends:
remove_cudnn_native = False
# cuDNN-native FP8 prefill also requires cuDNN >= 9.18.0 (backend version 91800)
if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [
torch.float8_e4m3fn,
torch.float8_e5m2,
]:
if not CUDNN_AVAILABLE or CUDNN_BACKEND_VERSION < 91800:
print(
f"[INFO] cuDNN FP8 prefill requires cuDNN >= 9.18.0. "
f"Current version: {CUDNN_BACKEND_VERSION}. Skipping cudnn-native backend."
)
remove_cudnn_native = True
if remove_cudnn_native:
backends.remove("cudnn-native")

if "trtllm-gen" in backends:
remove_trtllm = False
if not causal:
@@ -908,7 +954,44 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
print(f"[VVERBOSE] {kv_last_page_len.shape = }")
print(f"[VVERBOSE] {scale = }")

# Prepare wrappers
# Helper function to convert to FP8 (matches test_trtllm_gen_attention.py approach)
def to_float8(x, dtype=torch.float8_e4m3fn):
finfo = torch.finfo(dtype)
min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
scale = finfo.max / amax * 0.1
x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
return x_scl_sat.to(dtype), scale.float().reciprocal()

# Compute scales and convert to FP8 if needed (before creating wrappers)
q_scale, k_scale, v_scale = None, None, None
q_scale_tensor, k_scale_tensor, v_scale_tensor = None, None, None
o_data_type = q_dtype # Default output dtype
# Separate K/V caches for cuDNN (which requires separate tensors, not combined kv_cache)
k_cache_cudnn, v_cache_cudnn = k_cache, v_cache

if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
q, q_scale_t = to_float8(q, q_dtype)
q_scale = q_scale_t.item()
q_scale_tensor = q_scale_t.reshape(1, 1, 1, 1)
# o_data_type stays as q_dtype (FP8 output)

if kv_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
# Convert k_cache and v_cache to quantized dtype for cuDNN
k_cache_cudnn, k_scale_t = to_float8(k_cache, kv_dtype)
v_cache_cudnn, v_scale_t = to_float8(v_cache, kv_dtype)
k_scale = k_scale_t.item()
v_scale = v_scale_t.item()
k_scale_tensor = k_scale_t.reshape(1, 1, 1, 1)
v_scale_tensor = v_scale_t.reshape(1, 1, 1, 1)

# Also convert the full kv_cache for non-cuDNN backends
k_data, v_data = torch.chunk(kv_cache, 2, dim=1)
k_quantized, _ = to_float8(k_data, kv_dtype)
v_quantized, _ = to_float8(v_data, kv_dtype)
kv_cache = torch.cat([k_quantized, v_quantized], dim=1)

# Prepare wrappers (after FP8 conversion so we have correct dtypes)
backend_wrappers = {}
for backend in backends:
if backend in ["fa2", "fa3", "trtllm-gen"]:
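
The second value returned by to_float8 above is the dequantization scale (the reciprocal of the quantization scale); it is what later gets passed to the kernels as q_scale / k_scale / v_scale. A minimal round-trip sketch of that convention, assuming a CUDA device with float8 support:

import torch

x = torch.randn(8, 16, device="cuda")
x_fp8, dequant_scale = to_float8(x, torch.float8_e4m3fn)
# Multiplying back by the returned scale approximately recovers the input; the 0.1
# factor inside to_float8 keeps the quantized values well below the E4M3 maximum (448).
x_rec = x_fp8.to(torch.float32) * dequant_scale
assert torch.allclose(x_rec, x, rtol=0.13, atol=1e-2)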
@@ -941,28 +1024,78 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
kv_data_type=kv_dtype,
block_tables=block_tables,
)

k_scale, v_scale = None, None
if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
q = q.to(q_dtype)
if kv_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
k_data, v_data = torch.chunk(kv_cache, 2, dim=1)
k_scale = k_data.amax().item() / 256
v_scale = v_data.amax().item() / 256
k_fp8 = (k_data / k_scale).to(kv_dtype)
v_fp8 = (v_data / v_scale).to(kv_dtype)
kv_cache = torch.cat([k_fp8, v_fp8], dim=1)
elif backend == "cudnn":
# cuDNN uses NHD layout and the wrapper API
backend_wrappers[backend] = (
flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer,
"NHD",
backend="cudnn",
)
)
backend_wrappers["cudnn"].plan(
q_indptr,
kv_indptr,
kv_indices,
kv_last_page_len,
num_qo_heads,
num_kv_heads,
head_dim_qk,
page_size,
pos_encoding_mode="NONE",
causal=causal,
q_data_type=q_dtype,
o_data_type=o_data_type,
seq_lens=actual_seq_lens_kv_device,
seq_lens_q=actual_seq_lens_q_device,
sm_scale=scale,
max_token_per_sequence=s_qo,
max_sequence_kv=s_kv,
block_tables=block_tables,
)

def run_backend_wrapper(backend):
if backend in ["fa2", "fa3", "trtllm-gen"]:
return backend_wrappers[backend].run(
q, kv_cache, k_scale=k_scale, v_scale=v_scale
q, kv_cache, q_scale=q_scale, k_scale=k_scale, v_scale=v_scale
)
elif backend == "cudnn":
# cuDNN uses the wrapper API with tensor scales for FP8
return backend_wrappers[backend].run(
q,
(k_cache_cudnn, v_cache_cudnn),
q_scale=q_scale_tensor,
k_scale=k_scale_tensor,
v_scale=v_scale_tensor,
)
elif backend == "trtllm-native":
# Compute combined bmm1_scale: q_scale * k_scale * sm_scale
# For FP8: all scales are float values
_q_scale = q_scale if q_scale is not None else 1.0
_k_scale = k_scale if k_scale is not None else 1.0
_v_scale = v_scale if v_scale is not None else 1.0
bmm1_scale = _q_scale * _k_scale * scale
bmm2_scale = _v_scale
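# Illustrative derivation of the folding: with q ~ q_fp8 * q_scale and
# k ~ k_fp8 * k_scale, the pre-softmax logits satisfy
#   (q @ k^T) * sm_scale ~ (q_fp8 @ k_fp8^T) * (q_scale * k_scale * sm_scale),
# so bmm1_scale absorbs all three factors, while only v_scale (v ~ v_fp8 * v_scale)
# remains to be applied after the softmax via bmm2_scale.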
return flashinfer.prefill.trtllm_batch_context_with_kv_cache(
query=q,
kv_cache=kv_cache,
workspace_buffer=workspace_buffer,
block_tables=block_tables,
seq_lens=actual_seq_lens_kv_device,
max_q_len=s_qo,
max_kv_len=s_kv,
bmm1_scale=bmm1_scale,
bmm2_scale=bmm2_scale,
batch_size=batch_size,
cum_seq_lens_q=qo_indptr,
cum_seq_lens_kv=kv_indptr,
)
elif backend == "cudnn-native":
# Direct cudnn_batch_prefill_with_kv_cache call (similar to trtllm-native)
return flashinfer.prefill.cudnn_batch_prefill_with_kv_cache(
q,
k_cache,
v_cache,
k_cache_cudnn,
v_cache_cudnn,
scale,
workspace_buffer,
max_token_per_sequence=s_qo,
@@ -975,27 +1108,17 @@ def run_backend_wrapper(backend):
is_cuda_graph_compatible=is_cuda_graph_compatible,
batch_offsets_q=q_indptr,
batch_offsets_o=q_indptr,
q_scale=q_scale_tensor,
k_scale=k_scale_tensor,
v_scale=v_scale_tensor,
o_data_type=o_data_type,
)[0]
elif backend == "trtllm-native":
return flashinfer.prefill.trtllm_batch_context_with_kv_cache(
query=q,
kv_cache=kv_cache,
workspace_buffer=workspace_buffer,
block_tables=block_tables,
seq_lens=actual_seq_lens_kv_device,
max_q_len=s_qo,
max_kv_len=s_kv,
bmm1_scale=scale if k_scale is None else k_scale * scale,
bmm2_scale=1.0 if v_scale is None else v_scale,
batch_size=batch_size,
cum_seq_lens_q=qo_indptr,
cum_seq_lens_kv=kv_indptr,
)
else:
print(f"[ERROR] Backend {backend} not supported")
return res

has_reference_output = False
reference_backend = None
# Iterate over each backend:
for cur_backend in backends:
# Clear workspace buffer to prevent unexpected interactions between backends.
@@ -1005,6 +1128,7 @@ def run_backend_wrapper(backend):
if cur_backend == "fa2":
has_reference_output = True
reference_output = outputs[cur_backend]
reference_backend = "fa2"
backend_times[cur_backend] = bench_gpu_time(
fn=lambda: run_backend_wrapper(cur_backend),
dry_run_iters=args.dry_run_iters,
@@ -1020,6 +1144,22 @@
# Perform reference check
tested_backends = list(outputs.keys())
tested_outputs = list(outputs.values())

# When FA2 is not available, try to find an alternative reference
# Priority: cudnn > cudnn-native > trtllm-gen > trtllm-native
if run_refcheck and not has_reference_output and len(tested_backends) > 1:
reference_priority = ["cudnn", "cudnn-native", "trtllm-gen", "trtllm-native"]
for candidate in reference_priority:
if candidate in tested_backends:
has_reference_output = True
reference_backend = candidate
reference_output = outputs[candidate]
if args.verbose >= 1:
print(
f"[INFO] FA2 not available for reference. Using {candidate} as reference backend for cross-comparison."
)
break

if len(tested_backends) > 1:
if run_refcheck and has_reference_output:
if reference_output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
@@ -1037,7 +1177,7 @@ def run_backend_wrapper(backend):
) = is_close_stats(reference_output, tested_outputs[i], rtol, atol)
if num_different_elements > 0:
print(
f"[ERROR] Output tensor mismatch between backends fa2 and {tested_backends[i]}: "
f"[ERROR] Output tensor mismatch between backends {reference_backend} and {tested_backends[i]}: "
f"{num_different_elements} / {num_elements} ({num_different_elements_percentage:.2f}%) elements are different"
)
if not args.allow_output_mismatch:
15 changes: 8 additions & 7 deletions benchmarks/routines/flashinfer_benchmark_utils.py
@@ -173,14 +173,15 @@ def dtype_str_to_torch_dtype(dtype_str):
},
"BatchPrefillWithPagedKVCacheWrapper": {
# NOTE: trtllm-native calls trtllm_batch_context_with_kv_cache
# NOTE: cudnn-native calls cudnn_batch_prefill_with_kv_cache
"7.5": [],
"8.0": ["fa2", "cudnn"],
"8.6": ["fa2", "cudnn"],
"8.9": ["fa2", "cudnn"],
"9.0": ["fa2", "fa3", "cudnn"],
"10.0": ["fa2", "cudnn", "trtllm-gen", "trtllm-native"],
"10.3": ["fa2", "cudnn", "trtllm-gen", "trtllm-native"],
"12.0": ["fa2", "cudnn"],
"8.0": ["fa2", "cudnn", "cudnn-native"],
"8.6": ["fa2", "cudnn", "cudnn-native"],
"8.9": ["fa2", "cudnn", "cudnn-native"],
"9.0": ["fa2", "fa3", "cudnn", "cudnn-native"],
"10.0": ["fa2", "cudnn", "cudnn-native", "trtllm-gen", "trtllm-native"],
"10.3": ["fa2", "cudnn", "cudnn-native", "trtllm-gen", "trtllm-native"],
"12.0": ["fa2", "cudnn", "cudnn-native"],
},
"BatchPrefillWithRaggedKVCacheWrapper": {
# NOTE: trtllm-native calls trtllm_ragged_attention_deepseek