Rebase FP8 SM100 Cutlass FMHA Attention to main (original PR#1238) #2047
base: main
Conversation
Walkthrough
This PR adds FP8 (8-bit floating point) quantization support to Blackwell FMHA by introducing per-activation scaling parameters across C++ kernels and Python APIs, derives head/dimension values from tensor shapes instead of explicit parameters, refactors benchmarks to use CLI-driven configuration with CSV export, adjusts register allocations for kernel optimization, and adds comprehensive FP8 testing with relaxed numeric tolerances.

Changes
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Possibly related PRs
Suggested reviewers
Poem
Pre-merge checks and finishing touches
❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (1 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Hi @pavanimajety, would you mind also supporting a user-specified bmm2 scale?
Signed-off-by: Pavani Majety <[email protected]>
Signed-off-by: Pavani Majety <[email protected]>
97d7f73 to 6697a97
Actionable comments posted: 2
🧹 Nitpick comments (5)
include/flashinfer/attention/blackwell/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp (1)
67-70: LGTM! Stage count adjustment for FP8 is appropriate.
The logic correctly accounts for FP8's smaller element size (1 byte vs 2 bytes for FP16/BF16), allowing more pipeline stages to fit in shared memory and improving latency hiding. The fallback values (2 for FP8, 1 for non-FP8 with other tile dimensions) are reasonable defaults.
Consider adding a brief comment explaining the staging strategy, e.g.:
  static constexpr int StageCountQ = 2;
+ // FP8 (1-byte) elements allow more buffering stages in shared memory
  static constexpr int StageCountKV =
      (sizeof(Element_) == 1)
          ? (get<2>(TileShapeQK{}) == 128 ? 4 : 2)
          : (get<2>(TileShapeQK{}) == 128 || get<2>(TileShapeQK{}) == 64 ? 2 : 1);

csrc/fmha_cutlass_sm100.cu (1)
54-68: Unreachable code on line 67.
The return false; at line 67 is unreachable because both the if and else branches return. This is dead code.
Remove the unreachable return:

    } else {                                                              \
      return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(in_dtype, c_type_in, [&] { \
        using c_type_out = nv_bfloat16;                                   \
        return __VA_ARGS__();                                             \
      });                                                                 \
    }                                                                     \
-   return false;                                                         \
  }()

tests/attention/test_blackwell_fmha.py (1)
484-495: Main block test parameters don't match pytest parameterization.
The __main__ block uses num_qo_heads=1, num_kv_heads=1, sm_scale=1, but the pytest parameterization uses (128, 128) for heads and 1.0 / math.sqrt(192) for sm_scale. This inconsistency could cause confusion during manual debugging.
Consider aligning the main block parameters with the parameterized values:

 if __name__ == "__main__":
     test_blackwell_cutlass_fmha_fp8(
         batch_size=9,
         qo_len=377,
         kv_len=977,
-        num_qo_heads=1,
-        num_kv_heads=1,
+        num_qo_heads=128,
+        num_kv_heads=128,
         head_dim_qk=192,
         head_dim_vo=128,
-        sm_scale=1,
+        sm_scale=1.0 / math.sqrt(192),
         causal=False,
     )
29-39: o_data_type parameter is redundant and can be removed for clarity
Inside bench_fmha_blackwell, the o_data_type argument is immediately overwritten based solely on dtype, and the call sites always pass the same effective value. This makes the parameter misleading noise in the API.
Consider simplifying the function signature and calls like this:

-def bench_fmha_blackwell(
-    batch_size,
-    qkv_len,
-    num_qo_heads,
-    num_kv_heads,
-    head_dim_qk,
-    head_dim_vo,
-    causal,
-    dtype,
-    o_data_type,
-):
+def bench_fmha_blackwell(
+    batch_size,
+    qkv_len,
+    num_qo_heads,
+    num_kv_heads,
+    head_dim_qk,
+    head_dim_vo,
+    causal,
+    dtype,
+):
@@
-    # For FP8 input, output must be bfloat16
-    o_data_type = torch.bfloat16 if dtype.itemsize == 1 else dtype
+    # For FP8 input, output must be bfloat16
+    o_data_type = torch.bfloat16 if dtype.itemsize == 1 else dtype
@@
-    result_bf16 = bench_fmha_blackwell(
+    result_bf16 = bench_fmha_blackwell(
         batch_size,
         qkv_len,
         num_qo_heads,
         num_kv_heads,
         head_dim_qk,
         head_dim_vo,
         causal,
-        torch.bfloat16,
-        o_data_type=torch.bfloat16,
+        torch.bfloat16,
     )
@@
-    result_fp8 = bench_fmha_blackwell(
+    result_fp8 = bench_fmha_blackwell(
         batch_size,
         qkv_len,
         num_qo_heads,
         num_kv_heads,
         head_dim_qk,
         head_dim_vo,
         causal,
-        torch.float8_e4m3fn,
-        o_data_type=torch.bfloat16,
+        torch.float8_e4m3fn,
     )

This keeps the "FP8 → BF16 output" rule in one place and makes the public API less confusing.
Also applies to: 64-66, 145-153, 158-168, 179-189
197-217: Optional: Normalize dtype for easier CSV post-processing
Right now result["dtype"] is a dtype object, which csv.DictWriter will stringify (e.g., torch.bfloat16, torch.float8_e4m3fn). That's valid CSV, but downstream parsing may be simpler if you store a compact, stable string like "bf16"/"fp8_e4m3fn" instead.
For example, just before appending results you could normalize:

    for result in results:
        result["dtype"] = str(result["dtype"]).replace("torch.", "")
        writer.writerow(result)

This keeps the benchmark CSV a bit more tool-friendly if you analyze it later with pandas, spreadsheets, etc.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- benchmarks/bench_blackwell_attention.py (4 hunks)
- csrc/fmha_cutlass_sm100.cu (3 hunks)
- csrc/fmha_cutlass_sm100_binding.cu (1 hunks)
- flashinfer/prefill.py (10 hunks)
- include/flashinfer/attention/blackwell/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp (1 hunks)
- include/flashinfer/attention/blackwell/fmha_cutlass_sm100.cuh (3 hunks)
- include/flashinfer/attention/blackwell/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp (1 hunks)
- tests/attention/test_blackwell_fmha.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/attention/test_blackwell_fmha.py (2)
flashinfer/utils.py (2)
- is_sm100a_supported (531-533)
- is_sm110a_supported (541-543)

flashinfer/prefill.py (10)

- wrapper (135-151)
- BatchPrefillWithRaggedKVCacheWrapper (2316-3111)
- plan (1594-1982)
- plan (2552-2849)
- run (2013-2025)
- run (2028-2040)
- run (2043-2267)
- run (2879-2889)
- run (2892-2902)
- run (2905-3080)
csrc/fmha_cutlass_sm100.cu (2)
include/flashinfer/attention/blackwell/device/sm100_mla.hpp (4)
- stream (324-324)
- stream (324-324)
- stream (328-328)
- stream (328-328)

include/flashinfer/attention/blackwell/device/fmha.hpp (4)

- stream (240-240)
- stream (240-240)
- stream (244-244)
- stream (244-244)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (18)
include/flashinfer/attention/blackwell/fmha_cutlass_sm100.cuh (3)
79-80: LGTM! Per-activation scale parameters added consistently.
The new q_scale, k_scale, v_scale, and o_scale parameters are properly added to support FP8 quantization. The parameter ordering is consistent with the binding layer.
122-127: Operation::Arguments construction correctly propagates all scales.
The scales are passed to the mainloop arguments alongside the existing sm_scale, enabling FP8 calibration in the CUTLASS kernel.
167-177: Wrapper function signature and forwarding are consistent.
The run_fmha_fwd wrapper correctly forwards all scale parameters to FwdRunner::run, maintaining API consistency.

flashinfer/prefill.py (8)
2911-2914: FP8 scale parameters added to run() signature.
The new q_scale, k_scale, v_scale, o_scale parameters enable per-activation scaling for FP8 quantization. Good placement after existing positional args.
2933-2940: Clear documentation for FP8 scale parameters.
The docstrings accurately describe the purpose of each scale parameter and their default behavior.
2988-2994: Correct FP8-to-bf16 output handling.
Using q.dtype.itemsize == 1 to detect FP8 inputs and defaulting output to bf16 is the established pattern in this codebase.
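For illustration, a minimal sketch of that detection rule; the helper name default_out_dtype is invented for this example and is not the actual prefill.py API:

```python
import torch

def default_out_dtype(q: torch.Tensor, o_dtype=None) -> torch.dtype:
    # An explicit o_dtype wins; otherwise FP8 inputs (itemsize == 1,
    # e.g. torch.float8_e4m3fn) default to a bfloat16 output.
    if o_dtype is not None:
        return o_dtype
    return torch.bfloat16 if q.dtype.itemsize == 1 else q.dtype

q_fp8 = torch.zeros(4, 8, 128).to(torch.float8_e4m3fn)
q_bf16 = torch.zeros(4, 8, 128, dtype=torch.bfloat16)
assert default_out_dtype(q_fp8) == torch.bfloat16   # FP8 in -> bf16 out
assert default_out_dtype(q_bf16) == torch.bfloat16  # non-FP8 keeps its own dtype
```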
3013-3016: Scale parameters correctly forwarded to fmha_varlen.
The scales are passed through consistently to the underlying CUTLASS implementation.

3198-3214: fmha_varlen signature consistently updated across all overloads.
The three function signatures (two @overload declarations and the implementation) are correctly synchronized with the new scale parameters.
3238-3245: Appropriate default scale values.
Setting all scales to 1.0 when None provides correct no-op behavior for non-FP8 paths and backward compatibility.
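A tiny sketch of that defaulting; resolve_scales is a made-up helper name used only for illustration:

```python
def resolve_scales(q_scale=None, k_scale=None, v_scale=None, o_scale=None):
    # None means "no quantization scale provided"; 1.0 keeps the scaling a no-op.
    return tuple(1.0 if s is None else float(s)
                 for s in (q_scale, k_scale, v_scale, o_scale))

assert resolve_scales() == (1.0, 1.0, 1.0, 1.0)                        # non-FP8 / legacy callers
assert resolve_scales(q_scale=0.5, v_scale=2.0) == (0.5, 1.0, 2.0, 1.0)
```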
3264-3272: Consistent FP8 output dtype handling in fmha_varlen.
The bf16 output for FP8 inputs matches the pattern used elsewhere in this file.

3294-3298: Scale parameters correctly passed to module.run.
The parameter order (sm_scale, q_scale, k_scale, v_scale, o_scale, max_qo_len) matches the C++ binding signature.
csrc/fmha_cutlass_sm100_binding.cu (1)
20-26: Binding signature correctly updated for FP8 scale parameters.
The signature replacement of head/dim parameters with scale parameters (scale_q, scale_k, scale_v, o_scale) is consistent with the implementation file. Head/dim values are now derived from tensor shapes at runtime.
csrc/fmha_cutlass_sm100.cu (2)
88-99: Clean API simplification by deriving dimensions from tensor shapes.
Deriving num_qo_heads, num_kv_heads, head_dim_qk, and head_dim_vo from tensor dimensions (q.size(1), k.size(1), q.size(2), v.size(2)) reduces API surface and eliminates potential mismatches between declared and actual dimensions.
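For intuition, a small Python sketch of the same shape-derivation rule, assuming the ragged [tokens, heads, head_dim] layout those indices imply (the C++ code does the equivalent with Tensor::size; the concrete sizes below are illustrative):

```python
import torch

# Hypothetical q/k/v tensors in the ragged [tokens, heads, head_dim] layout;
# sizes mimic the DeepSeek-like config (128 heads, 192/128 head dims).
q = torch.zeros(8, 128, 192, dtype=torch.bfloat16)   # [total_qo, num_qo_heads, head_dim_qk]
k = torch.zeros(16, 128, 192, dtype=torch.bfloat16)  # [total_kv, num_kv_heads, head_dim_qk]
v = torch.zeros(16, 128, 128, dtype=torch.bfloat16)  # [total_kv, num_kv_heads, head_dim_vo]

num_qo_heads = q.shape[1]  # mirrors q.size(1)
num_kv_heads = k.shape[1]  # mirrors k.size(1)
head_dim_qk = q.shape[2]   # mirrors q.size(2)
head_dim_vo = v.shape[2]   # mirrors v.size(2)
assert (num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo) == (128, 128, 192, 128)
```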
122-134: Scale parameters correctly forwarded to run_fmha_fwd.
The call correctly passes all scale parameters followed by the derived head/dimension values, maintaining consistency with the header file signature.
tests/attention/test_blackwell_fmha.py (3)
350-369: Good FP8 test coverage for the DeepSeek-R1 configuration.
The parameterization targets the specific DeepSeek-R1 configuration (128 heads, 192/128 head dims), which is the primary FP8 use case for this PR.
393-428: Well-structured FP8 test setup.
The test correctly:
- Creates FP8 inputs by converting from half precision
- Configures the wrapper with explicit o_data_type=dtype_out
- Verifies the output dtype assertion at line 431
437-481: Correct FP8 reference implementation with appropriate tolerances.
The reference correctly (see the sketch after this list):
- Upcasts FP8 to float32 for computation accuracy
- Converts output to bf16 matching kernel output
- Uses relaxed tolerances (5e-2) appropriate for FP8 quantization error
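For illustration, a condensed, self-contained sketch of that kind of reference check; the shapes, scale handling, and tolerance values here are assumptions, not the test's exact code:

```python
import math
import torch

def fp8_attention_reference(q_fp8, k_fp8, v_fp8, q_scale, k_scale, v_scale, sm_scale):
    # Upcast (dequantize) FP8 inputs to float32 for an accurate reference.
    q = q_fp8.float() * q_scale
    k = k_fp8.float() * k_scale
    v = v_fp8.float() * v_scale
    logits = torch.einsum("qhd,khd->hqk", q, k) * sm_scale
    p = torch.softmax(logits, dim=-1)
    o = torch.einsum("hqk,khd->qhd", p, v)
    return o.to(torch.bfloat16)  # match the kernel's bf16 output dtype

qo_len, kv_len, heads, d_qk, d_vo = 32, 64, 4, 192, 128
q = torch.randn(qo_len, heads, d_qk).to(torch.float8_e4m3fn)
k = torch.randn(kv_len, heads, d_qk).to(torch.float8_e4m3fn)
v = torch.randn(kv_len, heads, d_vo).to(torch.float8_e4m3fn)
ref = fp8_attention_reference(q, k, v, 1.0, 1.0, 1.0, 1.0 / math.sqrt(d_qk))

# A kernel output (o_kernel) would then be compared with relaxed tolerances, e.g.:
# torch.testing.assert_close(o_kernel, ref, rtol=5e-2, atol=5e-2)
```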
benchmarks/bench_blackwell_attention.py (1)
86-94:and
    return {
        "config_name": f"Blackwell-{config_name}",
        "batch_size": batch_size,
        "qkv_len": qkv_len,
        "num_qo_heads": num_qo_heads,
        "num_kv_heads": num_kv_heads,
        "head_dim_qk": head_dim_qk,
        "head_dim_vo": head_dim_vo,
        "causal": causal,
        "dtype": dtype,
        "time_ms": ms,
        "tflops": TFLOPS,
    }
Undefined config_name in return dict will crash benchmarks
bench_fmha_blackwell uses config_name in the returned dict, but config_name is neither a parameter nor a global; the first call will raise NameError at the return line. The caller already sets result["config_name"] = config_name after the call, so the cleanest fix is to drop this field from the function’s return value.
Apply this diff inside bench_fmha_blackwell:
- return {
- "config_name": f"Blackwell-{config_name}",
- "batch_size": batch_size,
+ return {
+ "batch_size": batch_size,
"qkv_len": qkv_len,
"num_qo_heads": num_qo_heads,
"num_kv_heads": num_kv_heads,
"head_dim_qk": head_dim_qk,
"head_dim_vo": head_dim_vo,
"causal": causal,
"dtype": dtype,
"time_ms": ms,
"tflops": TFLOPS,
  }

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
    return {
        "batch_size": batch_size,
        "qkv_len": qkv_len,
        "num_qo_heads": num_qo_heads,
        "num_kv_heads": num_kv_heads,
        "head_dim_qk": head_dim_qk,
        "head_dim_vo": head_dim_vo,
        "causal": causal,
        "dtype": dtype,
        "time_ms": ms,
        "tflops": TFLOPS,
    }
🤖 Prompt for AI Agents
In benchmarks/bench_blackwell_attention.py around lines 98 to 110, the return
dict from bench_fmha_blackwell references an undefined variable config_name
which will raise a NameError; remove the "config_name" key from the returned
dict (since the caller sets result["config_name"] already) so the function only
returns the valid fields (batch_size, qkv_len, num_qo_heads, num_kv_heads,
head_dim_qk, head_dim_vo, causal, dtype, time_ms, tflops).
static const int NumRegsCorrection = 64;  // 96 - (kDebugUsingPrintf ? 16 : 0);
static const int NumRegsOther = 64;  // 32 + (kDebugUsingPrintf ? 16 : 0);
🧩 Analysis chain
🏁 Script executed:
# First, let's read the file around lines 69-70 to confirm the current state
head -80 include/flashinfer/attention/blackwell/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp | tail -20

Repository: flashinfer-ai/flashinfer
Length of output: 192
🏁 Script executed:
# Check for any comments or documentation nearby that might explain the tuning
sed -n '50,100p' include/flashinfer/attention/blackwell/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp

Repository: flashinfer-ai/flashinfer
Length of output: 192
🏁 Script executed:
# Look for similar register tuning patterns in related FMHA files
fd "fmha.*kernel" include/flashinfer/attention/blackwell/kernel/ --type fRepository: flashinfer-ai/flashinfer
Length of output: 192
Stale comments for updated register budgets
Lines 69-70 have hard-coded both NumRegsCorrection and NumRegsOther to 64, but the trailing comments still describe the old conditional formulas, which is misleading for future maintenance.
Update the comments to clarify the tuning choice:
- static const int NumRegsCorrection = 64; // 96 - (kDebugUsingPrintf ? 16 : 0);
- static const int NumRegsOther = 64; // 32 + (kDebugUsingPrintf ? 16 : 0);
+ // Tuned register budgets for Correction and Other warp roles for SM100 FMHA.
+ static const int NumRegsCorrection = 64;
+ static const int NumRegsOther = 64;

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
// Tuned register budgets for Correction and Other warp roles for SM100 FMHA.
static const int NumRegsCorrection = 64;
static const int NumRegsOther = 64;
🤖 Prompt for AI Agents
In
include/flashinfer/attention/blackwell/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp
around lines 69-70, the trailing comments next to NumRegsCorrection and
NumRegsOther are stale (they show old conditional formulas) and should be
updated to reflect the current fixed value; replace those misleading comments
with a brief explanation that both values are intentionally fixed to 64 for the
current SM100 register budget/tuning (and note that the previous conditional on
kDebugUsingPrintf was removed), or remove the comments entirely if unnecessary,
ensuring the comment clearly documents the tuning choice and rationale for
future maintainers.
@yzh119 Sorry for the delayed follow-up on this. I updated it similarly to the fa3 calls and updated the cutlass Arguments for q_scale, k_scale, v_scale, and o_scale in sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp#L173-L184. Could you please review and provide feedback? Thanks!
📌 Description
Just does a refresh of the FP8 attention and adds benchmarks for DeepSeek FMHA sizes. Original PR: #1238
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed pre-commit by running pip install pre-commit (or used your preferred method).
- Installed the hooks with pre-commit install.
- Ran pre-commit run --all-files and fixed any reported issues.

🧪 Tests
- Tests are passing (unittest, etc.).

Reviewer Notes
Summary by CodeRabbit
Release Notes
New Features
Tests