2 changes: 1 addition & 1 deletion benchmarks/README.md
@@ -117,7 +117,7 @@ The output CSV will contain detailed metrics including:
| `--verbose`, `-v` | Print additional information (can be used multiple times for more verbosity, e.g. `-vv`) |
| `--case_tag` | Optional tag for the test case, useful for annotating or filtering results in the output CSV. |
| `--generate_repro_command`| If set, prints a reproducer command for the test case and stores it in the output CSV. |
| `--backends` | Space-separated list of backends to test, e.g. fa2, fa2_tc, fa3, cudnn, cutlass, trtllm, trtllm-gen, trtllm-native, cublas|
| `--backends` | Space-separated list of backends to test, e.g. fa2, fa2_tc, fa3, cudnn, cudnn-native, cutlass, trtllm, trtllm-gen, trtllm-native, cublas|

### Attention Flags
| Flag | Description |
206 changes: 173 additions & 33 deletions benchmarks/routines/attention.py
@@ -4,6 +4,22 @@
import torch

import flashinfer

# Try to import cudnn for version checking
CUDNN_AVAILABLE = False
CUDNN_BACKEND_VERSION = 0
try:
import cudnn

CUDNN_AVAILABLE = True
CUDNN_BACKEND_VERSION = cudnn.backend_version()
except ImportError:
pass
except OSError as e:
error_msg = str(e).lower()
is_lib_missing = any(ext in error_msg for ext in [".so", ".dll"])
if not is_lib_missing:
raise
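# --- Editor's note (hedged sketch, not part of the PR) ---------------------
# cudnn.backend_version() packs the version into one integer; assuming the
# usual major*10000 + minor*100 + patch encoding (so 9.18.0 -> 91800, which
# matches the checks below), the FP8 gates could be expressed as a predicate:
#
#     def cudnn_version_at_least(major, minor, patch=0):
#         return CUDNN_AVAILABLE and (
#             CUDNN_BACKEND_VERSION >= major * 10000 + minor * 100 + patch
#         )
#
#     # cudnn_version_at_least(9, 18) would gate the FP8 prefill paths.
# ---------------------------------------------------------------------------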
from flashinfer.testing.utils import (
attention_tb_per_sec_with_actual_seq_lens,
attention_tflops_per_sec_with_actual_seq_lens,
@@ -88,6 +104,7 @@ def parse_attention_args(line, parser):
"fa2_tc",
"fa3",
"cudnn",
"cudnn-native",
"cutlass",
"trtllm-gen",
"trtllm-native",
@@ -680,6 +697,14 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
print(f"[ERROR] Unsupported kv_dtype: {args.kv_dtype}")
return res

# Increase tolerances for FP8 due to lower precision
if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [
torch.float8_e4m3fn,
torch.float8_e5m2,
]:
rtol = 5e-1 # Relaxed relative tolerance for FP8
atol = 1e-1 # Relaxed absolute tolerance for FP8
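# Hedged note: assuming torch.isclose-style comparison inside is_close_stats,
# an element is only flagged as mismatched when
#     |out - ref| > atol + rtol * |ref|
# i.e. roughly 0.1 absolute plus 50% relative slack for FP8 runs.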

# Parse and validate backend configurations
backends = args.backends
page_size = args.page_size
@@ -706,15 +731,36 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
backends.remove("fa2")
if "cudnn" in backends:
remove_cudnn = False
# cuDNN FP8 prefill requires cuDNN >= 9.18.0 (backend version 91800)
if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [
torch.float8_e4m3fn,
torch.float8_e5m2,
]:
print("[INFO] cuDNN backend does not support FP8. Skipping.")
remove_cudnn = True
if not CUDNN_AVAILABLE or CUDNN_BACKEND_VERSION < 91800:
print(
f"[INFO] cuDNN FP8 prefill requires cuDNN >= 9.18.0. "
f"Current version: {CUDNN_BACKEND_VERSION}. Skipping cudnn backend."
)
remove_cudnn = True
if remove_cudnn:
backends.remove("cudnn")

if "cudnn-native" in backends:
remove_cudnn_native = False
# cuDNN-native FP8 prefill also requires cuDNN >= 9.18.0 (backend version 91800)
if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [
torch.float8_e4m3fn,
torch.float8_e5m2,
]:
if not CUDNN_AVAILABLE or CUDNN_BACKEND_VERSION < 91800:
print(
f"[INFO] cuDNN FP8 prefill requires cuDNN >= 9.18.0. "
f"Current version: {CUDNN_BACKEND_VERSION}. Skipping cudnn-native backend."
)
remove_cudnn_native = True
if remove_cudnn_native:
backends.remove("cudnn-native")

if "trtllm-gen" in backends:
remove_trtllm = False
if not causal:
@@ -908,7 +954,44 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
print(f"[VVERBOSE] {kv_last_page_len.shape = }")
print(f"[VVERBOSE] {scale = }")

# Prepare wrappers
# Helper function to convert to FP8 (matches test_trtllm_gen_attention.py approach)
def to_float8(x, dtype=torch.float8_e4m3fn):
finfo = torch.finfo(dtype)
min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
scale = finfo.max / amax * 0.1
x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
return x_scl_sat.to(dtype), scale.float().reciprocal()
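# Hedged usage example for the helper above (illustrative values only):
#     x = torch.randn(4, 8, device="cuda")
#     x_fp8, inv_scale = to_float8(x)                    # quantize
#     x_approx = x_fp8.to(torch.float32) * inv_scale     # dequantize
# The 0.1 factor keeps quantized values at roughly 10% of the FP8 range,
# leaving headroom so attention products are less likely to saturate.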

# Compute scales and convert to FP8 if needed (before creating wrappers)
q_scale, k_scale, v_scale = None, None, None
q_scale_tensor, k_scale_tensor, v_scale_tensor = None, None, None
o_data_type = q_dtype # Default output dtype
# Separate K/V caches for cuDNN (which requires separate tensors, not combined kv_cache)
k_cache_cudnn, v_cache_cudnn = k_cache, v_cache

if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
q, q_scale_t = to_float8(q, q_dtype)
q_scale = q_scale_t.item()
q_scale_tensor = q_scale_t.reshape(1, 1, 1, 1)
# o_data_type stays as q_dtype (FP8 output)

if kv_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
# Convert k_cache and v_cache to quantized dtype for cuDNN
k_cache_cudnn, k_scale_t = to_float8(k_cache, kv_dtype)
v_cache_cudnn, v_scale_t = to_float8(v_cache, kv_dtype)
k_scale = k_scale_t.item()
v_scale = v_scale_t.item()
k_scale_tensor = k_scale_t.reshape(1, 1, 1, 1)
v_scale_tensor = v_scale_t.reshape(1, 1, 1, 1)

# Also convert the full kv_cache for non-cuDNN backends
k_data, v_data = torch.chunk(kv_cache, 2, dim=1)
k_quantized, _ = to_float8(k_data, kv_dtype)
v_quantized, _ = to_float8(v_data, kv_dtype)
kv_cache = torch.cat([k_quantized, v_quantized], dim=1)

# Prepare wrappers (after FP8 conversion so we have correct dtypes)
backend_wrappers = {}
for backend in backends:
if backend in ["fa2", "fa3", "trtllm-gen"]:
@@ -941,28 +1024,78 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
kv_data_type=kv_dtype,
block_tables=block_tables,
)

k_scale, v_scale = None, None
if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
q = q.to(q_dtype)
if kv_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
k_data, v_data = torch.chunk(kv_cache, 2, dim=1)
k_scale = k_data.amax().item() / 256
v_scale = v_data.amax().item() / 256
k_fp8 = (k_data / k_scale).to(kv_dtype)
v_fp8 = (v_data / v_scale).to(kv_dtype)
kv_cache = torch.cat([k_fp8, v_fp8], dim=1)
elif backend == "cudnn":
# cuDNN uses NHD layout and the wrapper API
backend_wrappers[backend] = (
flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer,
"NHD",
backend="cudnn",
)
)
backend_wrappers["cudnn"].plan(
q_indptr,
kv_indptr,
kv_indices,
kv_last_page_len,
num_qo_heads,
num_kv_heads,
head_dim_qk,
page_size,
pos_encoding_mode="NONE",
causal=causal,
q_data_type=q_dtype,
o_data_type=o_data_type,
seq_lens=actual_seq_lens_kv_device,
seq_lens_q=actual_seq_lens_q_device,
sm_scale=scale,
max_token_per_sequence=s_qo,
max_sequence_kv=s_kv,
block_tables=block_tables,
)

def run_backend_wrapper(backend):
if backend in ["fa2", "fa3", "trtllm-gen"]:
return backend_wrappers[backend].run(
q, kv_cache, k_scale=k_scale, v_scale=v_scale
q, kv_cache, q_scale=q_scale, k_scale=k_scale, v_scale=v_scale
)
elif backend == "cudnn":
# cuDNN uses wrapper API with tensor scales for FP8
return backend_wrappers[backend].run(
q,
(k_cache_cudnn, v_cache_cudnn),
q_scale=q_scale_tensor,
k_scale=k_scale_tensor,
v_scale=v_scale_tensor,
)
elif backend == "trtllm-native":
# Compute combined bmm1_scale: q_scale * k_scale * sm_scale
# For FP8: all scales are float values
_q_scale = q_scale if q_scale is not None else 1.0
_k_scale = k_scale if k_scale is not None else 1.0
_v_scale = v_scale if v_scale is not None else 1.0
bmm1_scale = _q_scale * _k_scale * scale
bmm2_scale = _v_scale
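# Hedged numeric example (hypothetical values): with q_scale=0.02,
# k_scale=0.03 and sm_scale=1/sqrt(128)~=0.0884,
# bmm1_scale ~= 0.02 * 0.03 * 0.0884 ~= 5.3e-5, while bmm2_scale = v_scale.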
return flashinfer.prefill.trtllm_batch_context_with_kv_cache(
query=q,
kv_cache=kv_cache,
workspace_buffer=workspace_buffer,
block_tables=block_tables,
seq_lens=actual_seq_lens_kv_device,
max_q_len=s_qo,
max_kv_len=s_kv,
bmm1_scale=bmm1_scale,
bmm2_scale=bmm2_scale,
batch_size=batch_size,
cum_seq_lens_q=qo_indptr,
cum_seq_lens_kv=kv_indptr,
)
elif backend == "cudnn-native":
# Direct cudnn_batch_prefill_with_kv_cache call (similar to trtllm-native)
return flashinfer.prefill.cudnn_batch_prefill_with_kv_cache(
q,
k_cache,
v_cache,
k_cache_cudnn,
v_cache_cudnn,
scale,
workspace_buffer,
max_token_per_sequence=s_qo,
@@ -975,27 +1108,17 @@ def run_backend_wrapper(backend):
is_cuda_graph_compatible=is_cuda_graph_compatible,
batch_offsets_q=q_indptr,
batch_offsets_o=q_indptr,
q_scale=q_scale_tensor,
k_scale=k_scale_tensor,
v_scale=v_scale_tensor,
o_data_type=o_data_type,
)[0]
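# Note (assumption): the [0] index suggests the call returns a tuple whose
# first element is the attention output; only that output is benchmarked
# and compared here.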
elif backend == "trtllm-native":
return flashinfer.prefill.trtllm_batch_context_with_kv_cache(
query=q,
kv_cache=kv_cache,
workspace_buffer=workspace_buffer,
block_tables=block_tables,
seq_lens=actual_seq_lens_kv_device,
max_q_len=s_qo,
max_kv_len=s_kv,
bmm1_scale=scale if k_scale is None else k_scale * scale,
bmm2_scale=1.0 if v_scale is None else v_scale,
batch_size=batch_size,
cum_seq_lens_q=qo_indptr,
cum_seq_lens_kv=kv_indptr,
)
else:
print(f"[ERROR] Backend {backend} not supported")
return res

has_reference_output = False
reference_backend = None
# Iterate over each backend:
for cur_backend in backends:
# Clear workspace buffer to prevent unexpected interactions between backends.
@@ -1005,6 +1128,7 @@ def run_backend_wrapper(backend):
if cur_backend == "fa2":
has_reference_output = True
reference_output = outputs[cur_backend]
reference_backend = "fa2"
backend_times[cur_backend] = bench_gpu_time(
fn=lambda: run_backend_wrapper(cur_backend),
dry_run_iters=args.dry_run_iters,
@@ -1020,6 +1144,22 @@ def run_backend_wrapper(backend):
# Perform reference check
tested_backends = list(outputs.keys())
tested_outputs = list(outputs.values())

# In cases where FA2 is not available, try to find an alternative reference
# Priority: cudnn > cudnn-native > trtllm-gen > trtllm-native
if run_refcheck and not has_reference_output and len(tested_backends) > 1:
reference_priority = ["cudnn", "cudnn-native", "trtllm-gen", "trtllm-native"]
for candidate in reference_priority:
if candidate in tested_backends:
has_reference_output = True
reference_backend = candidate
reference_output = outputs[candidate]
if args.verbose >= 1:
print(
f"[INFO] FA2 not available for reference. Using {candidate} as reference backend for cross-comparison."
)
break

if len(tested_backends) > 1:
if run_refcheck and has_reference_output:
if reference_output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
@@ -1037,7 +1177,7 @@ def run_backend_wrapper(backend):
) = is_close_stats(reference_output, tested_outputs[i], rtol, atol)
if num_different_elements > 0:
print(
f"[ERROR] Output tensor mismatch between backends fa2 and {tested_backends[i]}: "
f"[ERROR] Output tensor mismatch between backends {reference_backend} and {tested_backends[i]}: "
f"{num_different_elements} / {num_elements} ({num_different_elements_percentage:.2f}%) elements are different"
)
if not args.allow_output_mismatch:
15 changes: 8 additions & 7 deletions benchmarks/routines/flashinfer_benchmark_utils.py
@@ -173,14 +173,15 @@ def dtype_str_to_torch_dtype(dtype_str):
},
"BatchPrefillWithPagedKVCacheWrapper": {
# NOTE: trtllm-native calls trtllm_batch_context_with_kv_cache
# NOTE: cudnn-native calls cudnn_batch_prefill_with_kv_cache
"7.5": [],
"8.0": ["fa2", "cudnn"],
"8.6": ["fa2", "cudnn"],
"8.9": ["fa2", "cudnn"],
"9.0": ["fa2", "fa3", "cudnn"],
"10.0": ["fa2", "cudnn", "trtllm-gen", "trtllm-native"],
"10.3": ["fa2", "cudnn", "trtllm-gen", "trtllm-native"],
"12.0": ["fa2", "cudnn"],
"8.0": ["fa2", "cudnn", "cudnn-native"],
"8.6": ["fa2", "cudnn", "cudnn-native"],
"8.9": ["fa2", "cudnn", "cudnn-native"],
"9.0": ["fa2", "fa3", "cudnn", "cudnn-native"],
"10.0": ["fa2", "cudnn", "cudnn-native", "trtllm-gen", "trtllm-native"],
"10.3": ["fa2", "cudnn", "cudnn-native", "trtllm-gen", "trtllm-native"],
"12.0": ["fa2", "cudnn", "cudnn-native"],
},
"BatchPrefillWithRaggedKVCacheWrapper": {
# NOTE: trtllm-native calls trtllm_ragged_attention_deepseek
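The table above keys supported backends by GPU compute capability. A minimal sketch of how a benchmark routine might consult such a table, under the assumption that `backend_support_table` and `filter_supported_backends` are illustrative names, not identifiers from this PR:

```python
import torch

# Illustrative subset of the capability table shown in the diff above.
backend_support_table = {
    "9.0": ["fa2", "fa3", "cudnn", "cudnn-native"],
    "10.0": ["fa2", "cudnn", "cudnn-native", "trtllm-gen", "trtllm-native"],
}

def filter_supported_backends(requested):
    """Drop requested backends that the current GPU cannot run."""
    major, minor = torch.cuda.get_device_capability()
    supported = backend_support_table.get(f"{major}.{minor}", [])
    return [b for b in requested if b in supported]

# e.g. filter_supported_backends(["fa2", "fa3", "cudnn-native"]) on an
# SM 9.0 device would keep all three entries.
```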