
Conversation

@pavanimajety (Contributor) commented Nov 5, 2025

📌 Description

Refreshes the FP8 attention support and adds benchmarks for the DeepSeek FMHA sizes. Original PR - #1238

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used my preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

Release Notes

  • New Features

    • Added FP8 quantization support to attention operations with independent per-activation scaling for query, key, value, and output tensors
    • Implemented configurable CLI-based benchmarking tool with CSV result export capability
    • Extended support for multi-head and grouped-query (separate query and KV head counts) attention configurations
  • Tests

    • Added comprehensive FP8-specific test coverage for attention operations with appropriate precision handling


coderabbitai bot (Contributor) commented Nov 5, 2025

Walkthrough

This PR adds FP8 (8-bit floating point) quantization support to Blackwell FMHA by introducing per-activation scaling parameters across the C++ kernels and Python APIs. It also derives head and dimension values from tensor shapes instead of explicit parameters, refactors the benchmarks to use CLI-driven configuration with CSV export, adjusts register allocations for kernel optimization, and adds comprehensive FP8 testing with relaxed numeric tolerances.
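
For orientation, here is a minimal sketch of how per-activation FP8 scales are typically derived on the Python side before calling such a kernel (illustrative only; the amax-based calibration and the helper name are assumptions, not code from this PR):

import torch

def quantize_to_fp8(x: torch.Tensor):
    # Map the activation's maximum magnitude onto the FP8 E4M3 max representable value.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = (x.abs().amax().float() / fp8_max).clamp(min=1e-12)
    x_fp8 = (x.float() / scale).to(torch.float8_e4m3fn)
    return x_fp8, scale.item()

# Example shapes loosely follow the DeepSeek FMHA configuration benchmarked in this PR.
q = torch.randn(1024, 128, 192, dtype=torch.bfloat16)
k = torch.randn(1024, 128, 192, dtype=torch.bfloat16)
v = torch.randn(1024, 128, 128, dtype=torch.bfloat16)
q_fp8, q_scale = quantize_to_fp8(q)
k_fp8, k_scale = quantize_to_fp8(k)
v_fp8, v_scale = quantize_to_fp8(v)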

Changes

Cohort / File(s) Summary
Benchmark Refactoring
benchmarks/bench_blackwell_attention.py
Replaced single-head configuration with multi-head (num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo), added o_data_type parameter, refactored main block from hard-coded invocations to CLI-driven flow with argparse and CSV export support, updated return value to dictionary containing config metadata and performance metrics.
C++ Kernel Signature Updates
csrc/fmha_cutlass_sm100.cu, csrc/fmha_cutlass_sm100_binding.cu, include/flashinfer/attention/blackwell/fmha_cutlass_sm100.cuh
Replaced explicit head/dimension parameters (num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo) with per-activation FP8 scaling parameters (scale_q, scale_k, scale_v, o_scale) in FMHACutlassSM100Run, FwdRunner::run, and run_fmha_fwd signatures; derive head/dimension values from tensor shapes internally; added FP8 path fallback in dtype dispatch with nv_bfloat16 output.
Python API FP8 Scaling
flashinfer/prefill.py
Added q_scale, k_scale, v_scale, o_scale as optional parameters to BatchPrefillWithPagedKVCacheWrapper.run, fmha_varlen, and single_prefill_with_kv_cache.run; default scales to 1.0 when None; updated output dtype logic to use bfloat16 for FP8 inputs; propagated scales through all attention kernel invocation paths.
Kernel Register & StageCount Tuning
include/flashinfer/attention/blackwell/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp, include/flashinfer/attention/blackwell/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp
Simplified register allocation constants: NumRegsCorrection (96 - debug offset → 64) and NumRegsOther (32 + debug offset → 64); reworked StageCountKV calculation to depend on both element size and tile shape instead of tile shape alone.
FP8 Test Coverage
tests/attention/test_blackwell_fmha.py
Added new parametric test_blackwell_cutlass_fmha_fp8 function covering batch_size, sequence lengths, head counts, head dimensions, and causal masking; constructs FP8 inputs, validates output dtype, computes reference via upcast to float32 with masking and log-sum-exp finalization, asserts results with FP8-appropriate tolerances.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

  • Signature propagation verification: Multiple parameter changes across Python API → C++ bindings → headers → kernel implementations require tracing all call sites for consistency (csrc files, include headers, prefill.py).
  • FP8 scaling logic: New per-activation scaling path in prefill.py with default handling and dtype selection logic needs careful verification of correctness across causal and non-causal branches.
  • Kernel register/StageCount tuning: Changes to NumRegsCorrection, NumRegsOther, and StageCountKV calculation are performance-critical; conditional logic based on element size and tile shape requires validation.
  • Test logic complexity: FP8 test reference computation includes upcast, masking, and lse finalization; relaxed tolerances require justification.

Possibly related PRs

  • #2081: Adds and propagates output scale parameters (o_scale/rcpOutScale) in decode/xqa paths, directly related to per-activation FP8 scaling infrastructure.
  • #2111: Adds per-activation FP8 scaling parameters (q_scale, k_scale, v_scale, o_scale) across FMHA signatures and call sites, shares identical API surface changes with this PR.

Suggested reviewers

  • joker-eph
  • aleozlx
  • djmmoss
  • yzh119
  • yongwww
  • cyx-6

Poem

🐰 With eight bits of focus, we quantize with care,
Scale-factors precise through the attention layer,
From Python to kernels, the FP8 way,
Blackwell's fast path now has its FP8 day! ✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning, 1 inconclusive)
Docstring Coverage — ⚠️ Warning: Docstring coverage is 12.50%, which is below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
Description check — ❓ Inconclusive: The description is largely incomplete; it lacks a detailed explanation of the changes, rationale, or implementation details. While it references the original PR and the benchmark additions, it provides minimal substance beyond those two points and leaves most checklist items unchecked. Resolution: expand the description with the specific changes made (e.g., per-head FP8 scales, signature updates, register count adjustments), explain why they are needed, and mark completed checklist items if applicable.
✅ Passed checks (1 passed)
Title check — ✅ Passed: The title clearly indicates this is a rebase of the FP8 SM100 Cutlass FMHA attention work; it directly corresponds to the file changes and objectives of adding FP8 support and benchmarks.


@yzh119 (Collaborator) commented Nov 5, 2025

Hi @pavanimajety, would you mind also supporting a user-specified bmm2 scale?

@pavanimajety force-pushed the fp8-attention-cutlass-fmha branch from 97d7f73 to 6697a97 on December 9, 2025 00:43
@pavanimajety marked this pull request as ready for review on December 9, 2025 00:43
coderabbitai bot (Contributor) left a comment

Actionable comments posted: 2

🧹 Nitpick comments (5)
include/flashinfer/attention/blackwell/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp (1)

67-70: LGTM! Stage count adjustment for FP8 is appropriate.

The logic correctly accounts for FP8's smaller element size (1 byte vs 2 bytes for FP16/BF16), allowing more pipeline stages to fit in shared memory and improve latency hiding. The fallback values (2 for FP8, 1 for non-FP8 with other tile dimensions) are reasonable defaults.

Consider adding a brief comment explaining the staging strategy, e.g.:

 static constexpr int StageCountQ = 2;
+  // FP8 (1-byte) elements allow more buffering stages in shared memory
   static constexpr int StageCountKV =
       (sizeof(Element_) == 1)
           ? (get<2>(TileShapeQK{}) == 128 ? 4 : 2)
           : (get<2>(TileShapeQK{}) == 128 || get<2>(TileShapeQK{}) == 64 ? 2 : 1);
csrc/fmha_cutlass_sm100.cu (1)

54-68: Unreachable code on line 67.

The return false; at line 67 is unreachable because both the if and else branches return. This is dead code.

Remove the unreachable return:

     } else {                                                                   \
       return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(in_dtype, c_type_in, [&] {     \
         using c_type_out = nv_bfloat16;                                        \
         return __VA_ARGS__();                                                  \
       });                                                                      \
     }                                                                          \
-    return false;                                                              \
   }()
tests/attention/test_blackwell_fmha.py (1)

484-495: Main block test parameters don't match pytest parameterization.

The __main__ block uses num_qo_heads=1, num_kv_heads=1, sm_scale=1 but the pytest parameterization uses (128, 128) for heads and 1.0 / math.sqrt(192) for sm_scale. This inconsistency could cause confusion during manual debugging.

Consider aligning the main block parameters with the parameterized values:

 if __name__ == "__main__":
     test_blackwell_cutlass_fmha_fp8(
         batch_size=9,
         qo_len=377,
         kv_len=977,
-        num_qo_heads=1,
-        num_kv_heads=1,
+        num_qo_heads=128,
+        num_kv_heads=128,
         head_dim_qk=192,
         head_dim_vo=128,
-        sm_scale=1,
+        sm_scale=1.0 / math.sqrt(192),
         causal=False,
     )
benchmarks/bench_blackwell_attention.py (2)

29-39: o_data_type parameter is redundant and can be removed for clarity

Inside bench_fmha_blackwell, the o_data_type argument is immediately overwritten based solely on dtype, and the call sites always pass the same effective value. This makes the parameter misleading noise in the API.

Consider simplifying the function signature and calls like this:

-def bench_fmha_blackwell(
-    batch_size,
-    qkv_len,
-    num_qo_heads,
-    num_kv_heads,
-    head_dim_qk,
-    head_dim_vo,
-    causal,
-    dtype,
-    o_data_type,
-):
+def bench_fmha_blackwell(
+    batch_size,
+    qkv_len,
+    num_qo_heads,
+    num_kv_heads,
+    head_dim_qk,
+    head_dim_vo,
+    causal,
+    dtype,
+):
@@
-    # For FP8 input, output must be bfloat16
-    o_data_type = torch.bfloat16 if dtype.itemsize == 1 else dtype
+    # For FP8 input, output must be bfloat16
+    o_data_type = torch.bfloat16 if dtype.itemsize == 1 else dtype
@@
-            result_bf16 = bench_fmha_blackwell(
+            result_bf16 = bench_fmha_blackwell(
                 batch_size,
                 qkv_len,
                 num_qo_heads,
                 num_kv_heads,
                 head_dim_qk,
                 head_dim_vo,
                 causal,
-                torch.bfloat16,
-                o_data_type=torch.bfloat16,
+                torch.bfloat16,
             )
@@
-            result_fp8 = bench_fmha_blackwell(
+            result_fp8 = bench_fmha_blackwell(
                 batch_size,
                 qkv_len,
                 num_qo_heads,
                 num_kv_heads,
                 head_dim_qk,
                 head_dim_vo,
                 causal,
-                torch.float8_e4m3fn,
-                o_data_type=torch.bfloat16,
+                torch.float8_e4m3fn,
             )

This keeps the “FP8 → BF16 output” rule in one place and makes the public API less confusing.

Also applies to: 64-66, 145-153, 158-168, 179-189


197-217: Optional: Normalize dtype for easier CSV post-processing

Right now result["dtype"] is a dtype object, which csv.DictWriter will stringify (e.g., torch.bfloat16, torch.float8_e4m3fn). That’s valid CSV, but downstream parsing may be simpler if you store a compact, stable string like "bf16" / "fp8_e4m3fn" instead.

For example, just before appending results you could normalize:

for result in results:
    result["dtype"] = str(result["dtype"]).replace("torch.", "")
    writer.writerow(result)

This keeps the benchmark CSV a bit more tool-friendly if you analyze it later with pandas, spreadsheets, etc.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5fe01a2 and 6697a97.

📒 Files selected for processing (8)
  • benchmarks/bench_blackwell_attention.py (4 hunks)
  • csrc/fmha_cutlass_sm100.cu (3 hunks)
  • csrc/fmha_cutlass_sm100_binding.cu (1 hunks)
  • flashinfer/prefill.py (10 hunks)
  • include/flashinfer/attention/blackwell/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp (1 hunks)
  • include/flashinfer/attention/blackwell/fmha_cutlass_sm100.cuh (3 hunks)
  • include/flashinfer/attention/blackwell/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp (1 hunks)
  • tests/attention/test_blackwell_fmha.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/attention/test_blackwell_fmha.py (2)
flashinfer/utils.py (2)
  • is_sm100a_supported (531-533)
  • is_sm110a_supported (541-543)
flashinfer/prefill.py (10)
  • wrapper (135-151)
  • BatchPrefillWithRaggedKVCacheWrapper (2316-3111)
  • plan (1594-1982)
  • plan (2552-2849)
  • run (2013-2025)
  • run (2028-2040)
  • run (2043-2267)
  • run (2879-2889)
  • run (2892-2902)
  • run (2905-3080)
csrc/fmha_cutlass_sm100.cu (2)
include/flashinfer/attention/blackwell/device/sm100_mla.hpp (4)
  • stream (324-324)
  • stream (324-324)
  • stream (328-328)
  • stream (328-328)
include/flashinfer/attention/blackwell/device/fmha.hpp (4)
  • stream (240-240)
  • stream (240-240)
  • stream (244-244)
  • stream (244-244)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (18)
include/flashinfer/attention/blackwell/fmha_cutlass_sm100.cuh (3)

79-80: LGTM! Per-activation scale parameters added consistently.

The new q_scale, k_scale, v_scale, and o_scale parameters are properly added to support FP8 quantization. The parameter ordering is consistent with the binding layer.


122-127: Operation::Arguments construction correctly propagates all scales.

The scales are passed to the mainloop arguments alongside the existing sm_scale, enabling FP8 calibration in the CUTLASS kernel.


167-177: Wrapper function signature and forwarding are consistent.

The run_fmha_fwd wrapper correctly forwards all scale parameters to FwdRunner::run, maintaining API consistency.

flashinfer/prefill.py (8)

2911-2914: FP8 scale parameters added to run() signature.

The new q_scale, k_scale, v_scale, o_scale parameters enable per-activation scaling for FP8 quantization. Good placement after existing positional args.
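
A hypothetical call site using the new keywords might look like the following, assuming wrapper is an already-planned prefill wrapper and the FP8 tensors and scales come from an earlier calibration step (names and values here are placeholders; only the parameter names come from this PR):

# Scales default to 1.0 when omitted, so existing bf16/fp16 callers are unaffected.
out = wrapper.run(
    q_fp8,
    k_fp8,
    v_fp8,
    q_scale=q_scale,  # dequantization factor for Q
    k_scale=k_scale,  # dequantization factor for K
    v_scale=v_scale,  # dequantization factor for V
    o_scale=1.0,      # output scale; 1.0 leaves the bf16 output unscaled
)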


2933-2940: Clear documentation for FP8 scale parameters.

The docstrings accurately describe the purpose of each scale parameter and their default behavior.


2988-2994: Correct FP8-to-bf16 output handling.

Using q.dtype.itemsize == 1 to detect FP8 inputs and defaulting output to bf16 is the established pattern in this codebase.
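
The detection idiom referenced here is roughly the following (illustrative sketch):

import torch

def pick_out_dtype(q: torch.Tensor) -> torch.dtype:
    # itemsize == 1 covers both float8_e4m3fn and float8_e5m2 inputs.
    return torch.bfloat16 if q.dtype.itemsize == 1 else q.dtype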


3013-3016: Scale parameters correctly forwarded to fmha_varlen.

The scales are passed through consistently to the underlying CUTLASS implementation.


3198-3214: fmha_varlen signature consistently updated across all overloads.

The three function signatures (two @overload declarations and the implementation) are correctly synchronized with the new scale parameters.


3238-3245: Appropriate default scale values.

Setting all scales to 1.0 when None provides correct no-op behavior for non-FP8 paths and backward compatibility.


3264-3272: Consistent FP8 output dtype handling in fmha_varlen.

The bf16 output for FP8 inputs matches the pattern used elsewhere in this file.


3294-3298: Scale parameters correctly passed to module.run.

The parameter order (sm_scale, q_scale, k_scale, v_scale, o_scale, max_qo_len) matches the C++ binding signature.

csrc/fmha_cutlass_sm100_binding.cu (1)

20-26: Binding signature correctly updated for FP8 scale parameters.

The signature replacement of head/dim parameters with scale parameters (scale_q, scale_k, scale_v, o_scale) is consistent with the implementation file. Head/dim values are now derived from tensor shapes at runtime.

csrc/fmha_cutlass_sm100.cu (2)

88-99: Clean API simplification by deriving dimensions from tensor shapes.

Deriving num_qo_heads, num_kv_heads, head_dim_qk, and head_dim_vo from tensor dimensions (q.size(1), k.size(1), q.size(2), v.size(2)) reduces API surface and eliminates potential mismatches between declared and actual dimensions.
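
In Python terms, and assuming the packed [tokens, heads, head_dim] layout these kernels expect, the equivalent derivation is simply:

# q: [total_q, num_qo_heads, head_dim_qk]
# k: [total_kv, num_kv_heads, head_dim_qk]
# v: [total_kv, num_kv_heads, head_dim_vo]
num_qo_heads = q.size(1)
num_kv_heads = k.size(1)
head_dim_qk = q.size(2)
head_dim_vo = v.size(2)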


122-134: Scale parameters correctly forwarded to run_fmha_fwd.

The call correctly passes all scale parameters followed by the derived head/dimension values, maintaining consistency with the header file signature.

tests/attention/test_blackwell_fmha.py (3)

350-369: Good FP8 test coverage for DeepSeek-R1 configuration.

The parameterization targets the specific DeepSeek-R1 configuration (128 heads, 192/128 head dims) which is the primary FP8 use case for this PR.


393-428: Well-structured FP8 test setup.

The test correctly:

  1. Creates FP8 inputs by converting from half precision
  2. Configures the wrapper with explicit o_data_type=dtype_out
  3. Verifies output dtype assertion at line 431

437-481: Correct FP8 reference implementation with appropriate tolerances.

The reference correctly:

  1. Upcasts FP8 to float32 for computation accuracy
  2. Converts output to bf16 matching kernel output
  3. Uses relaxed tolerances (5e-2) appropriate for FP8 quantization error
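
A minimal version of that reference pattern, assuming unit dequantization scales, kv_len >= qo_len, and single-sequence tensors shaped [len, heads, dim], and skipping the log-sum-exp bookkeeping, might look like:

import torch

def fp8_attention_reference(q_fp8, k_fp8, v_fp8, sm_scale, causal):
    # Upcast FP8 inputs to float32 so the reference is not limited by FP8 precision.
    q = q_fp8.float()  # [qo_len, num_heads, head_dim_qk]
    k = k_fp8.float()  # [kv_len, num_heads, head_dim_qk]
    v = v_fp8.float()  # [kv_len, num_heads, head_dim_vo]
    logits = torch.einsum("qhd,khd->hqk", q, k) * sm_scale
    if causal:
        qo_len, kv_len = q.size(0), k.size(0)
        # Query i may attend to keys up to position i + (kv_len - qo_len).
        mask = torch.arange(kv_len, device=q.device)[None, :] <= (
            torch.arange(qo_len, device=q.device)[:, None] + (kv_len - qo_len)
        )
        logits = logits.masked_fill(~mask[None], float("-inf"))
    probs = torch.softmax(logits, dim=-1)
    out = torch.einsum("hqk,khd->qhd", probs, v)
    return out.to(torch.bfloat16)  # match the kernel's bf16 output

# e.g. torch.testing.assert_close(out_kernel, ref, rtol=5e-2, atol=5e-2)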

Comment on lines +98 to +110
return {
    "config_name": f"Blackwell-{config_name}",
    "batch_size": batch_size,
    "qkv_len": qkv_len,
    "num_qo_heads": num_qo_heads,
    "num_kv_heads": num_kv_heads,
    "head_dim_qk": head_dim_qk,
    "head_dim_vo": head_dim_vo,
    "causal": causal,
    "dtype": dtype,
    "time_ms": ms,
    "tflops": TFLOPS,
}

⚠️ Potential issue | 🔴 Critical

Undefined config_name in return dict will crash benchmarks

bench_fmha_blackwell uses config_name in the returned dict, but config_name is neither a parameter nor a global; the first call will raise NameError at the return line. The caller already sets result["config_name"] = config_name after the call, so the cleanest fix is to drop this field from the function’s return value.

Apply this diff inside bench_fmha_blackwell:

-    return {
-        "config_name": f"Blackwell-{config_name}",
-        "batch_size": batch_size,
+    return {
+        "batch_size": batch_size,
         "qkv_len": qkv_len,
         "num_qo_heads": num_qo_heads,
         "num_kv_heads": num_kv_heads,
         "head_dim_qk": head_dim_qk,
         "head_dim_vo": head_dim_vo,
         "causal": causal,
         "dtype": dtype,
         "time_ms": ms,
         "tflops": TFLOPS,
     }
🤖 Prompt for AI Agents
In benchmarks/bench_blackwell_attention.py around lines 98 to 110, the return
dict from bench_fmha_blackwell references an undefined variable config_name
which will raise a NameError; remove the "config_name" key from the returned
dict (since the caller sets result["config_name"] already) so the function only
returns the valid fields (batch_size, qkv_len, num_qo_heads, num_kv_heads,
head_dim_qk, head_dim_vo, causal, dtype, time_ms, tflops).

Comment on lines +69 to +70
static const int NumRegsCorrection = 64; // 96 - (kDebugUsingPrintf ? 16 : 0);
static const int NumRegsOther = 64; // 32 + (kDebugUsingPrintf ? 16 : 0);

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

# First, let's read the file around lines 69-70 to confirm the current state
head -80 include/flashinfer/attention/blackwell/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp | tail -20

Repository: flashinfer-ai/flashinfer

Length of output: 192


🏁 Script executed:

# Check for any comments or documentation nearby that might explain the tuning
sed -n '50,100p' include/flashinfer/attention/blackwell/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp

Repository: flashinfer-ai/flashinfer

Length of output: 192


🏁 Script executed:

# Look for similar register tuning patterns in related FMHA files
fd "fmha.*kernel" include/flashinfer/attention/blackwell/kernel/ --type f

Repository: flashinfer-ai/flashinfer

Length of output: 192


Stale comments for updated register budgets

Lines 69-70 have hard-coded both NumRegsCorrection and NumRegsOther to 64, but the trailing comments still describe the old conditional formulas, which is misleading for future maintenance.

Update the comments to clarify the tuning choice:

-  static const int NumRegsCorrection = 64;  // 96 - (kDebugUsingPrintf ? 16 : 0);
-  static const int NumRegsOther = 64;       // 32 + (kDebugUsingPrintf ? 16 : 0);
+  // Tuned register budgets for Correction and Other warp roles for SM100 FMHA.
+  static const int NumRegsCorrection = 64;
+  static const int NumRegsOther = 64;
🤖 Prompt for AI Agents
In
include/flashinfer/attention/blackwell/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp
around lines 69-70, the trailing comments next to NumRegsCorrection and
NumRegsOther are stale (they show old conditional formulas) and should be
updated to reflect the current fixed value; replace those misleading comments
with a brief explanation that both values are intentionally fixed to 64 for the
current SM100 register budget/tuning (and note that the previous conditional on
kDebugUsingPrintf was removed), or remove the comments entirely if unnecessary,
ensuring the comment clearly documents the tuning choice and rationale for
future maintainers.

@pavanimajety (Contributor, Author) commented

@yzh119 Sorry for the delayed follow-up on this. I updated this similarly to the FA3 calls and the CUTLASS Arguments, adding q_scale, k_scale, v_scale, and o_scale in sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp#L173-L184:

 struct Arguments {
    typename Load::Arguments load;

    float scale_softmax;

    // scaling factors to dequantize QKV
    float scale_q = 1.0f;
    float scale_k = 1.0f;
    float scale_v = 1.0f;

    // scaling factor to quantize O
    float inv_scale_o = 1.0f;
  };
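
Roughly, these factors are intended to enter the computation as follows (the exact placement inside the kernel may differ):

$$
S = (\text{scale\_q}\cdot\text{scale\_k}\cdot\text{scale\_softmax})\, Q_{\mathrm{fp8}} K_{\mathrm{fp8}}^{\top},\qquad
P = \operatorname{softmax}(S),\qquad
O = \text{inv\_scale\_o}\cdot\text{scale\_v}\,\left(P\, V_{\mathrm{fp8}}\right)
$$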

Could you please review and provide feedback? Thanks!
