Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 148 additions & 53 deletions benchmarks/bench_blackwell_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
limitations under the License.
"""

import argparse
import csv
import numpy as np
import torch

Expand All @@ -27,20 +29,26 @@
def bench_fmha_blackwell(
batch_size,
qkv_len,
num_heads,
head_dim,
num_qo_heads,
num_kv_heads,
head_dim_qk,
head_dim_vo,
causal,
dtype,
o_data_type,
):
# if sizeof(dtype) == 1 like with torch.float8_e4m3fn,
# create randn from half and then convert to dtype
init_dtype = torch.half if dtype.itemsize == 1 else dtype
q = torch.randn(
batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda"
)
batch_size * qkv_len, num_qo_heads, head_dim_qk, dtype=init_dtype, device="cuda"
).to(dtype)
k = torch.randn(
batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda"
)
batch_size * qkv_len, num_kv_heads, head_dim_qk, dtype=init_dtype, device="cuda"
).to(dtype)
v = torch.randn(
batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda"
)
batch_size * qkv_len, num_kv_heads, head_dim_vo, dtype=init_dtype, device="cuda"
).to(dtype)

qo_segment_offsets = (
torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * qkv_len
Expand All @@ -53,16 +61,19 @@ def bench_fmha_blackwell(
kv_layout="NHD",
backend="cutlass",
)
# For FP8 input, output must be bfloat16
o_data_type = torch.bfloat16 if dtype.itemsize == 1 else dtype
wrapper.plan(
qo_segment_offsets,
kv_segment_offsets,
num_heads,
num_heads,
head_dim,
head_dim_vo=head_dim,
num_qo_heads,
num_kv_heads,
head_dim_qk,
head_dim_vo=head_dim_vo,
causal=causal,
q_data_type=dtype,
kv_data_type=dtype,
o_data_type=o_data_type,
)
_o = wrapper.run(q, k, v)
measurements = bench_gpu_time(
Expand All @@ -75,52 +86,136 @@ def bench_fmha_blackwell(
TFLOPS = attention_tflops_per_sec_with_actual_seq_lens(
torch.full((batch_size,), qkv_len),
torch.full((batch_size,), qkv_len),
head_dim,
head_dim,
num_heads,
head_dim_qk,
head_dim_vo,
num_qo_heads,
causal,
ms,
)
print(
f"bench_fmha_blackwell (batch_size={batch_size}, qkv_len={qkv_len}, num_heads={num_heads}, head_dim={head_dim}, causal={causal}), flops: {TFLOPS:.3f} TFLOPs/s"
f"bench_fmha_blackwell (batch_size={batch_size}, qkv_len={qkv_len}, num_qo_heads={num_qo_heads}, num_kv_heads={num_kv_heads}, head_dim_qk={head_dim_qk}, head_dim_vo={head_dim_vo}, causal={causal}), flops: {TFLOPS:.3f} TFLOPs/s"
)
return {
"config_name": f"Blackwell-{config_name}",
"batch_size": batch_size,
"qkv_len": qkv_len,
"num_qo_heads": num_qo_heads,
"num_kv_heads": num_kv_heads,
"head_dim_qk": head_dim_qk,
"head_dim_vo": head_dim_vo,
"causal": causal,
"dtype": dtype,
"time_ms": ms,
"tflops": TFLOPS,
}
Comment on lines +98 to +110
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | πŸ”΄ Critical

Undefined config_name in return dict will crash benchmarks

bench_fmha_blackwell uses config_name in the returned dict, but config_name is neither a parameter nor a global; the first call will raise NameError at the return line. The caller already sets result["config_name"] = config_name after the call, so the cleanest fix is to drop this field from the function’s return value.

Apply this diff inside bench_fmha_blackwell:

-    return {
-        "config_name": f"Blackwell-{config_name}",
-        "batch_size": batch_size,
+    return {
+        "batch_size": batch_size,
         "qkv_len": qkv_len,
         "num_qo_heads": num_qo_heads,
         "num_kv_heads": num_kv_heads,
         "head_dim_qk": head_dim_qk,
         "head_dim_vo": head_dim_vo,
         "causal": causal,
         "dtype": dtype,
         "time_ms": ms,
         "tflops": TFLOPS,
     }
πŸ“ Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
return {
"config_name": f"Blackwell-{config_name}",
"batch_size": batch_size,
"qkv_len": qkv_len,
"num_qo_heads": num_qo_heads,
"num_kv_heads": num_kv_heads,
"head_dim_qk": head_dim_qk,
"head_dim_vo": head_dim_vo,
"causal": causal,
"dtype": dtype,
"time_ms": ms,
"tflops": TFLOPS,
}
return {
"batch_size": batch_size,
"qkv_len": qkv_len,
"num_qo_heads": num_qo_heads,
"num_kv_heads": num_kv_heads,
"head_dim_qk": head_dim_qk,
"head_dim_vo": head_dim_vo,
"causal": causal,
"dtype": dtype,
"time_ms": ms,
"tflops": TFLOPS,
}
πŸ€– Prompt for AI Agents
In benchmarks/bench_blackwell_attention.py around lines 98 to 110, the return
dict from bench_fmha_blackwell references an undefined variable config_name
which will raise a NameError; remove the "config_name" key from the returned
dict (since the caller sets result["config_name"] already) so the function only
returns the valid fields (batch_size, qkv_len, num_qo_heads, num_kv_heads,
head_dim_qk, head_dim_vo, causal, dtype, time_ms, tflops).



if __name__ == "__main__":
print("\n === head_dim=128 ===")
bench_fmha_blackwell(128, 512, 32, 128, False, torch.bfloat16)
bench_fmha_blackwell(64, 1024, 32, 128, False, torch.bfloat16)
bench_fmha_blackwell(32, 2048, 32, 128, False, torch.bfloat16)
bench_fmha_blackwell(16, 4096, 32, 128, False, torch.bfloat16)
bench_fmha_blackwell(8, 8192, 32, 128, False, torch.bfloat16)
bench_fmha_blackwell(4, 16384, 32, 128, False, torch.bfloat16)
bench_fmha_blackwell(2, 32768, 32, 128, False, torch.bfloat16)
bench_fmha_blackwell(1, 65536, 32, 128, False, torch.bfloat16)

bench_fmha_blackwell(128, 512, 32, 128, True, torch.bfloat16)
bench_fmha_blackwell(64, 1024, 32, 128, True, torch.bfloat16)
bench_fmha_blackwell(32, 2048, 32, 128, True, torch.bfloat16)
bench_fmha_blackwell(16, 4096, 32, 128, True, torch.bfloat16)
bench_fmha_blackwell(8, 8192, 32, 128, True, torch.bfloat16)
bench_fmha_blackwell(4, 16384, 32, 128, True, torch.bfloat16)
bench_fmha_blackwell(2, 32768, 32, 128, True, torch.bfloat16)
bench_fmha_blackwell(1, 65536, 32, 128, True, torch.bfloat16)

print("\n === head_dim=64 ===")
bench_fmha_blackwell(128, 512, 32, 64, False, torch.bfloat16)
bench_fmha_blackwell(64, 1024, 32, 64, False, torch.bfloat16)
bench_fmha_blackwell(32, 2048, 32, 64, False, torch.bfloat16)
bench_fmha_blackwell(16, 4096, 32, 64, False, torch.bfloat16)
bench_fmha_blackwell(8, 8192, 32, 64, False, torch.bfloat16)
bench_fmha_blackwell(4, 16384, 32, 64, False, torch.bfloat16)
bench_fmha_blackwell(2, 32768, 32, 64, False, torch.bfloat16)
bench_fmha_blackwell(1, 65536, 32, 64, False, torch.bfloat16)

bench_fmha_blackwell(128, 512, 32, 64, True, torch.bfloat16)
bench_fmha_blackwell(64, 1024, 32, 64, True, torch.bfloat16)
bench_fmha_blackwell(32, 2048, 32, 64, True, torch.bfloat16)
bench_fmha_blackwell(16, 4096, 32, 64, True, torch.bfloat16)
bench_fmha_blackwell(8, 8192, 32, 64, True, torch.bfloat16)
bench_fmha_blackwell(4, 16384, 32, 64, True, torch.bfloat16)
bench_fmha_blackwell(2, 32768, 32, 64, True, torch.bfloat16)
bench_fmha_blackwell(1, 65536, 32, 64, True, torch.bfloat16)
parser = argparse.ArgumentParser(
description="Benchmark FP8 attention for DeepSeek-R1"
)
parser.add_argument(
"--save-results-to",
type=str,
default=None,
help="Path to save benchmark results as CSV (optional)",
)
args = parser.parse_args()

results = []

# Define configurations: (batch_size, qkv_len, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo, config_name)
# DeepSeek-R1 uses MLA (Multi-head Latent Attention) with 128 heads
# head_dim_qk=192 (128 nope + 64 rope), head_dim_vo=128
configs = [
(16, 512, 128, 128, 192, 128, "DeepSeek-R1"),
(8, 1024, 128, 128, 192, 128, "DeepSeek-R1"),
(4, 2048, 128, 128, 192, 128, "DeepSeek-R1"),
(2, 4096, 128, 128, 192, 128, "DeepSeek-R1"),
(1, 8192, 128, 128, 192, 128, "DeepSeek-R1"),
]

# Run benchmarks: Causal first, then non-causal
# For each config: bfloat16 then fp8
for causal in [True, False]:
print(f"\n{'=' * 80}")
print(f"Running {'CAUSAL' if causal else 'NON-CAUSAL'} benchmarks")
print(f"{'=' * 80}")

for (
batch_size,
qkv_len,
num_qo_heads,
num_kv_heads,
head_dim_qk,
head_dim_vo,
config_name,
) in configs:
# Run bfloat16
print(
f"\n[{config_name}] BS={batch_size}, SeqLen={qkv_len}, Causal={causal}, BF16"
)
result_bf16 = bench_fmha_blackwell(
batch_size,
qkv_len,
num_qo_heads,
num_kv_heads,
head_dim_qk,
head_dim_vo,
causal,
torch.bfloat16,
o_data_type=torch.bfloat16,
)
result_bf16["config_name"] = config_name
results.append(result_bf16)
print(
f" β†’ {result_bf16['tflops']:.2f} TFLOPs/s, {result_bf16['time_ms']:.3f} ms"
)

# Run fp8
print(
f"[{config_name}] BS={batch_size}, SeqLen={qkv_len}, Causal={causal}, FP8"
)
result_fp8 = bench_fmha_blackwell(
batch_size,
qkv_len,
num_qo_heads,
num_kv_heads,
head_dim_qk,
head_dim_vo,
causal,
torch.float8_e4m3fn,
o_data_type=torch.bfloat16,
)
result_fp8["config_name"] = config_name
results.append(result_fp8)
speedup = result_fp8["tflops"] / result_bf16["tflops"]
print(
f" β†’ {result_fp8['tflops']:.2f} TFLOPs/s, {result_fp8['time_ms']:.3f} ms (speedup: {speedup:.2f}x)"
)

# Write results to CSV if requested
if args.save_results_to:
fieldnames = [
"config_name",
"batch_size",
"qkv_len",
"num_qo_heads",
"num_kv_heads",
"head_dim_qk",
"head_dim_vo",
"causal",
"dtype",
"time_ms",
"tflops",
]

with open(args.save_results_to, "w", newline="") as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
writer.writeheader()
for result in results:
writer.writerow(result)

print(f"\n{'=' * 80}")
print(f"Results saved to: {args.save_results_to}")
print(f"{'=' * 80}")
19 changes: 14 additions & 5 deletions csrc/fmha_cutlass_sm100.cu
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ using tvm::ffi::Optional;
using c_type_out = c_type_in; \
return __VA_ARGS__(); \
}); \
} else { \
return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(in_dtype, c_type_in, [&] { \
using c_type_out = nv_bfloat16; \
return __VA_ARGS__(); \
}); \
} \
return false; \
}()
Expand All @@ -80,14 +85,18 @@ void FMHACutlassSM100Run(ffi::TensorView workspace_buffer, ffi::TensorView q, ff
ffi::TensorView qo_tile_indices, ffi::TensorView qo_head_indices,
ffi::TensorView batch_indices, ffi::TensorView o,
Optional<ffi::TensorView> maybe_lse, int64_t mask_mode_code,
double sm_scale, int64_t num_qo_heads, int64_t num_kv_heads,
int64_t head_dim_qk, int64_t head_dim_vo, int64_t max_qo_len) {
double sm_scale, double scale_q, double scale_k, double scale_v,
double o_scale, int64_t max_qo_len) {
TVM_FFI_ICHECK_EQ(q.dtype(), k.dtype());
auto scalar_type_in = q.dtype();
auto scalar_type_out = o.dtype();
MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);
int total_qo_len = q.size(0);
int total_kv_len = k.size(0);
int num_qo_heads = q.size(1);
int num_kv_heads = k.size(1);
int head_dim_qk = q.size(2);
int head_dim_vo = v.size(2);
int batch_size = qo_segment_offsets.size(0) - 1;
int q_stride_n = q.stride(0);
int q_stride_h = q.stride(1);
Expand Down Expand Up @@ -120,9 +129,9 @@ void FMHACutlassSM100Run(ffi::TensorView workspace_buffer, ffi::TensorView q, ff
static_cast<int*>(qo_head_indices.data_ptr()), static_cast<int*>(batch_indices.data_ptr()),
static_cast<cutlass_type_out*>(o.data_ptr()),
maybe_lse.has_value() ? static_cast<float*>(maybe_lse.value().data_ptr()) : nullptr,
mask_mode_code, sm_scale, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo, q_stride_n,
q_stride_h, k_stride_n, k_stride_h, v_stride_n, v_stride_h, batch_size, total_qo_len,
total_kv_len, max_qo_len, stream);
mask_mode_code, sm_scale, scale_q, scale_k, scale_v, o_scale, num_qo_heads, num_kv_heads,
head_dim_qk, head_dim_vo, q_stride_n, q_stride_h, k_stride_n, k_stride_h, v_stride_n,
v_stride_h, batch_size, total_qo_len, total_kv_len, max_qo_len, stream);
TVM_FFI_ICHECK_EQ(status, cudaSuccess)
<< "Cutlass FMHA forward pass failed" << cudaGetErrorString(status);

Expand Down
4 changes: 2 additions & 2 deletions csrc/fmha_cutlass_sm100_binding.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ void FMHACutlassSM100Run(TensorView workspace_buffer, TensorView q, TensorView k
TensorView work_indptr, TensorView qo_tile_indices,
TensorView qo_head_indices, TensorView batch_indices, TensorView o,
Optional<TensorView> maybe_lse, int64_t mask_mode_code, double sm_scale,
int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim_qk,
int64_t head_dim_vo, int64_t max_qo_len);
double scale_q, double scale_k, double scale_v, double o_scale,
int64_t max_qo_len);

void blackwell_fmha_plan(TensorView qo_segment_offsets, TensorView kv_segment_offsets,
TensorView work_indptr, TensorView qo_tile_indices,
Expand Down
Loading