Conversation

@Anerudhan (Collaborator) commented Nov 4, 2025

cudnn implementation for sdpa fp8

📌 Description

Allows cuDNN SDPA to be called when q and kv are both FP8.

Requires cuDNN 9.18.0.
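
For context, a minimal sketch (not part of this PR) of how BF16 inputs could be quantized for this path, following the per-tensor amax/256 heuristic used in the new test; the helper name and example shape are assumptions, and the (1, 1, 1, 1) float32 scale layout is what the cuDNN backend expects:

import torch

def quantize_for_cudnn_fp8(x: torch.Tensor):
    # Per-tensor scale; amax/256 mirrors the heuristic in the new FP8 test,
    # clamped to avoid divide-by-zero as suggested in review.
    scale = (x.abs().amax().float() / 256).clamp(min=1e-12)
    x_fp8 = (x / scale).to(torch.float8_e4m3fn)
    # cuDNN expects scale tensors shaped (1, 1, 1, 1) in float32.
    return x_fp8, scale.reshape(1, 1, 1, 1)

q = torch.randn(32, 8, 128, device="cuda", dtype=torch.bfloat16)  # example shape only
q_fp8, q_scale = quantize_for_cudnn_fp8(q)
# q_fp8/q_scale (and the analogous k/v pairs) are then forwarded to the cuDNN
# prefill path, e.g. wrapper.run(..., q_scale=q_scale, k_scale=..., v_scale=...)
# with the remaining arguments elided here.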

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Expanded FP8 support with per-head calibration tensors and optional output data-type; added "cudnn-native" backend option.
  • Bug Fixes

    • Improved FP8 scale propagation across execution paths and more deterministic output initialization (now zeroed).
  • Tests

    • Added comprehensive FP8 validation test covering many configurations and backends; adjusted FP8 tolerances.
  • Documentation

    • Updated benchmark CLI/help and runtime checks to include cudnn-native and clarify backend/version behavior.


@coderabbitai bot (Contributor) commented Nov 4, 2025

Walkthrough

Adds FP8-aware cuDNN prefill plumbing: per-device and dummy scale tensors, new UIDs for scale/descale/amax, threading of q/k/v scale tensors and optional output dtype through graph/runtime, and FP8-focused tests and benchmark/harness updates.

Changes

  • cuDNN FP8 prefill plumbing (flashinfer/cudnn/prefill.py): Added _get_dummy_scale_tensor(device); _create_cudnn_handle(stream) now returns the handle; added UIDs for Q/K/V/S/S_DESCALE/O scales and S_AMAX/O_AMAX; extended signatures to accept q_scale, k_scale, v_scale, o_data_type; compute derived cuDNN data types; create and wire FP8 scale/descale/amax tensors into graph outputs; propagate scale tensors and dummy mappings into var_map.
  • Public API: prefill.run scale extensions (flashinfer/prefill.py): Extended the run signature to accept q_scale/k_scale/v_scale: Optional[Union[float, torch.Tensor]]; updated docstrings; apply sm_scale only for the non-cuDNN backend; initialize out as zeros; forward scale tensors through the cuDNN and paged_run paths to kernel invocations.
  • FP8 test coverage (tests/attention/test_cudnn_prefill.py): Replaced some test inputs with ones; removed the standalone output_ref init; added test_cudnn_prefill_fp8 (parametrized), which constructs FP8 q/k/v and per-head scales, runs the cuDNN and reference paths, and asserts outputs/amax/descale agreement.
  • Benchmark & harness updates (benchmarks/README.md, benchmarks/routines/attention.py, benchmarks/routines/flashinfer_benchmark_utils.py): Added the cudnn-native backend to the CLI and backend lists; added cuDNN availability/version checks and FP8 feasibility gating; introduced FP8 conversion/scale handling and adjusted FP8 tolerances and reference-backend selection logic.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

  • Pay extra attention to:
    • cuDNN graph construction: derived data-type selection, new UIDs, and wiring of scale/descale/amax tensors.
    • Correct device placement, caching and lifetime of dummy scale tensors from _get_dummy_scale_tensor.
    • Propagation, shapes and dtype semantics of q_scale/k_scale/v_scale through API → cudnn paths and kernel calls.
    • Tests: FP8 test correctness, chosen tolerances, and reference-path equivalence.

Possibly related PRs

Suggested reviewers

  • cyx-6
  • wenscarl
  • aleozlx
  • bkryu
  • nvmbreughe
  • jiahanc

Poem

🐰
I hopped through graphs with tiny scales,
Seeded tensors on device trails.
UIDs aligned and amax in tow,
FP8 whispers where cuDNN flows.
A carrot clap—small bytes aglow! 🥕

Pre-merge checks and finishing touches

❌ Failed checks (1 warning, 1 inconclusive)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 30.77%, which is below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
  • Title check (❓ Inconclusive): The title is truncated and vague, using '…' without fully specifying what was implemented, making it unclear whether it describes the main change adequately. Resolution: complete the title to clearly state the main change, such as 'Add FP8 support for cuDNN SDPA with Q and KV caches'.
✅ Passed checks (1 passed)
  • Description check (✅ Passed): The description covers the main objective (FP8 support for cuDNN SDPA requiring v9.18.0) and confirms pre-commit and tests completed, meeting the template requirements.

@Anerudhan force-pushed the feature/cudnn/sdpa_fp8_qkv branch 2 times, most recently from d41b0c5 to 78ba024 on December 12, 2025 05:28
@Anerudhan marked this pull request as ready for review on December 12, 2025 05:30
@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 7

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/cudnn/prefill.py (1)

73-126: Fix graph cache key: include dtypes (esp. o_data_type) or you can reuse an incompatible graph.
_sdpa_prefill_key_fn now accepts o_data_type but doesn't use it in the key, and the key also doesn't include the q/k/v dtypes even though _build_prefill_graph now depends on them. This can cause cached graphs to be reused across dtype configurations (BF16 vs FP8 vs output dtype), leading to wrong results or execution failures.

Also applies to: 94-95, 111-125
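
A minimal sketch of the kind of key extension meant here (a concrete diff is proposed in a later review pass; the helper name and argument list below are illustrative, not the actual _sdpa_prefill_key_fn):

import torch
from typing import Optional

def prefill_graph_cache_key(q, k_cache, v_cache, o_data_type: Optional[torch.dtype], *shape_fields):
    # Include every dtype the built graph depends on, so BF16 / FP8 / mixed
    # output-dtype configurations never share a cached graph.
    out_dtype = o_data_type if o_data_type is not None else q.dtype
    return (q.dtype, k_cache.dtype, v_cache.dtype, out_dtype, *shape_fields)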

🧹 Nitpick comments (3)
tests/attention/test_cudnn_prefill.py (1)

43-46: Consider keeping some randomness in test_cudnn_prefill inputs (ones-only can reduce coverage).
Switching q/kv_cache to torch.ones can mask issues that only show up with realistic value distributions (e.g., scale/amax paths, numerical stability). If you changed this to reduce nondeterminism, consider using torch.randn(...)*small_scale instead to keep signal while staying stable.

Also applies to: 59-73

flashinfer/prefill.py (1)

2153-2159: out = torch.zeros(...) may be a perf regression if kernels fully overwrite out.
If this was added to fix partial-writes (e.g., masked regions), consider documenting why. Otherwise, prefer torch.empty(...) for speed.

flashinfer/cudnn/prefill.py (1)

170-177: Optimize generate_stats for FP8 to match BFLOAT16 pattern.
The version guard is correct. However, generate_stats=True is unconditionally set for FP8 (line 382) while Stats is only output when return_lse=True (line 430). Since this is prefill-only (no backward pass), you can tie it to return_lse like the BFLOAT16 path (line 354: generate_stats=return_lse) to avoid unnecessary overhead when LSE is not requested. The amax tensors are produced independently of generate_stats and are handled correctly.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e9ef49b and 78ba024.

📒 Files selected for processing (3)
  • flashinfer/cudnn/prefill.py (18 hunks)
  • flashinfer/prefill.py (5 hunks)
  • tests/attention/test_cudnn_prefill.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/attention/test_cudnn_prefill.py (1)
flashinfer/prefill.py (10)
  • BatchPrefillWithPagedKVCacheWrapper (1305-2312)
  • plan (1595-1988)
  • plan (2565-2862)
  • run (2019-2031)
  • run (2034-2046)
  • run (2049-2280)
  • run (2892-2902)
  • run (2905-2915)
  • run (2918-3075)
  • wrapper (135-151)
flashinfer/cudnn/prefill.py (1)
flashinfer/cudnn/decode.py (2)
  • _create_cudnn_handle (21-26)
  • UIDs (30-49)
🪛 Ruff (0.14.8)
tests/attention/test_cudnn_prefill.py

205-205: Unused function argument: return_lse

(ARG001)


206-206: Unused function argument: is_cuda_graph_compatible

(ARG001)

flashinfer/cudnn/prefill.py

94-94: Unused function argument: o_data_type

(ARG001)


162-162: Avoid specifying long messages outside the exception class

(TRY003)


174-176: Avoid specifying long messages outside the exception class

(TRY003)


337-337: Ambiguous variable name: O

(E741)


372-372: Ambiguous variable name: O

(E741)

🔇 Additional comments (2)
flashinfer/prefill.py (2)

69-87: _split_scale_param looks good for scalar-vs-tensor scale normalization.
Clear semantics and keeps FA3 callsites cleaner.
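
For readers outside the diff, the described behavior is roughly the following (an illustrative sketch only, not the actual _split_scale_param body):

from typing import Optional, Tuple, Union
import torch

def split_scale_param(
    scale: Optional[Union[float, torch.Tensor]],
) -> Tuple[float, Optional[torch.Tensor]]:
    # Fold scalar scales into a plain float multiplier (e.g. for sm_scale)
    # and pass per-head tensor scales through unchanged.
    if scale is None:
        return 1.0, None
    if isinstance(scale, torch.Tensor):
        if scale.numel() == 1:
            return float(scale.item()), None
        return 1.0, scale
    return float(scale), None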


2189-2210: cuDNN FP8 plumbing is correctly threaded through wrapper → cudnn call.
Forwarding q_scale/k_scale/v_scale and o_data_type matches the new FP8 SDPA graph inputs and enables the intended path. The cudnn_batch_prefill_with_kv_cache function enforces cuDNN's requirement that scale tensors have shape (1, 1, 1, 1) and properly maps them to the cuDNN graph's descale/scale nodes.

@bkryu (Collaborator) commented Dec 12, 2025

/bot run

@flashinfer-bot (Collaborator) commented

GitLab MR !191 has been created, and the CI pipeline #40107625 is currently running. I'll report back once the pipeline job completes.

…he cudnn implementation

Making output to bf16. Debugging commit

Fixed and cleaned up
@Anerudhan force-pushed the feature/cudnn/sdpa_fp8_qkv branch from 78ba024 to 880afa2 on December 12, 2025 18:28
@Anerudhan (Collaborator, Author) commented

/bot run

@flashinfer-bot (Collaborator) commented

@Anerudhan is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww

@flashinfer-bot (Collaborator) commented

[CANCELING] Pipeline #40107625: canceled

@bkryu (Collaborator) commented Dec 12, 2025

/bot run

@flashinfer-bot (Collaborator) commented

GitLab MR !191 has been updated with latest changes, and the CI pipeline #40109767 is currently running. I'll report back once the pipeline job completes.

@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
flashinfer/cudnn/prefill.py (1)

75-128: Graph cache key must include dtypes (and o_data_type), otherwise wrong graph can be reused.
_sdpa_prefill_key_fn doesn’t include q.dtype/k_cache.dtype/v_cache.dtype nor o_data_type in key. Now that _build_prefill_graph selects different ops (sdpa vs sdpa_fp8) and sets different tensor data_types based on dtype + output dtype, a cached graph can be invalid for a later call with different dtypes/output dtype.

 def _sdpa_prefill_key_fn(
@@
     lse: Optional[torch.Tensor] = None,
     o_data_type: Optional[torch.dtype] = None,
 ):
+    if o_data_type is None:
+        o_data_type = q.dtype
     graph_b = actual_seq_lens_q.shape[0]
@@
     key = (
         graph_b,
+        str(q.dtype),
+        str(k_cache.dtype),
+        str(v_cache.dtype),
+        str(o_data_type),
         q.dim(),
         k_cache.dim(),
@@
         page_size,
     )
     return key

Also: since actual_seq_lens_q is typed Optional but used unconditionally (.shape[0]), either make it required in the signature or handle None.

flashinfer/prefill.py (1)

2182-2210: cuDNN path needs scale normalization: q_scale/k_scale/v_scale can be float or per-head tensors, but cudnn_batch_prefill_with_kv_cache expects (1,1,1,1) float32 tensors.
Currently, these scales are passed directly from the method signature (which allows Union[float, torch.Tensor]) to cudnn_batch_prefill_with_kv_cache() without transformation, causing failures when a Python float is provided and ambiguity with non-scalar tensors.

         if self._backend == "cudnn":
+            def _as_cudnn_scale(x):
+                if x is None:
+                    return None
+                if isinstance(x, torch.Tensor):
+                    if x.numel() != 1:
+                        raise ValueError("cudnn backend expects scalar scale tensors (numel()==1).")
+                    t = x.to(device=q.device, dtype=torch.float32)
+                else:
+                    t = torch.tensor(float(x), device=q.device, dtype=torch.float32)
+                return t.reshape(1, 1, 1, 1)
+
             cudnn_batch_prefill_with_kv_cache(
@@
-                q_scale=q_scale,
-                k_scale=k_scale,
-                v_scale=v_scale,
+                q_scale=_as_cudnn_scale(q_scale),
+                k_scale=_as_cudnn_scale(k_scale),
+                v_scale=_as_cudnn_scale(v_scale),
@@
                 o_data_type=out_dtype,
             )
♻️ Duplicate comments (4)
tests/attention/test_cudnn_prefill.py (2)

14-26: Parametrized return_lse / is_cuda_graph_compatible are unused.
Either remove those parametrizations or rename args to _return_lse / _is_cuda_graph_compatible (or actually pass return_lse=... into wrapper.run(...) and assert on the returned tuple).

Also applies to: 190-203


225-272: FP8 test passes incorrect scale types/shapes (and q_scale float), likely breaking cuDNN var_map expectations.

  • flashinfer/cudnn/prefill.py builds scale tensors with dim=(1,1,1,1); the test constructs scalar tensors for k_scale_tensor/v_scale_tensor.
  • The test passes q_scale=q_scale where q_scale is a Python float (the tensor q_scale = torch.tensor(...) is created but not used).
    Also clamp scales to avoid divide-by-zero when amax==0.
-    q_scale = q.amax().item() / 256
-
-    q_scale = torch.tensor(q_scale, device=device, dtype=torch.float32)
-    q_fp8 = (q / q_scale).to(torch.float8_e4m3fn)
+    q_scale_val = (q.amax().float() / 256).clamp_min(1e-12)
+    q_scale = q_scale_val.to(device=device, dtype=torch.float32).reshape(1, 1, 1, 1)
+    q_fp8 = (q / q_scale_val).to(torch.float8_e4m3fn)
@@
-    k_scale = k_cache.amax().item() / 256
-    v_scale = v_cache.amax().item() / 256
-    k_cache_fp8 = (k_cache / k_scale).to(torch.float8_e4m3fn)
-    v_cache_fp8 = (v_cache / v_scale).to(torch.float8_e4m3fn)
-
-    k_scale_tensor = torch.tensor(k_scale, device=device, dtype=torch.float32)
-    v_scale_tensor = torch.tensor(v_scale, device=device, dtype=torch.float32)
+    k_scale_val = (k_cache.amax().float() / 256).clamp_min(1e-12)
+    v_scale_val = (v_cache.amax().float() / 256).clamp_min(1e-12)
+    k_cache_fp8 = (k_cache / k_scale_val).to(torch.float8_e4m3fn)
+    v_cache_fp8 = (v_cache / v_scale_val).to(torch.float8_e4m3fn)
+    k_scale_tensor = k_scale_val.to(device=device, dtype=torch.float32).reshape(1, 1, 1, 1)
+    v_scale_tensor = v_scale_val.to(device=device, dtype=torch.float32).reshape(1, 1, 1, 1)
@@
-        q_scale=q_scale,
+        q_scale=q_scale,
         k_scale=k_scale_tensor,
         v_scale=v_scale_tensor,

Also applies to: 340-346

flashinfer/cudnn/prefill.py (1)

648-651: o_data_type can still be None when allocating out (will crash).
torch.empty(... dtype=o_data_type) will throw if o_data_type is None. This matches the prior review concern and still needs a default before allocation.

-    if out is None:
+    if out is None:
+        if o_data_type is None:
+            o_data_type = q.dtype
         out_shape = (num_tokens, h_qo, d_vo)
         out = torch.empty(out_shape, device=q.device, dtype=o_data_type)
flashinfer/prefill.py (1)

2054-2056: Avoid turning sm_scale into a Tensor for non-cuDNN backends.
sm_scale *= q_scale / *= k_scale will produce a Tensor if either scale is a Tensor, which can break kernels expecting a Python float.

-        if self._backend != "cudnn":
-            if q_scale is not None:
-                sm_scale *= q_scale
-            if k_scale is not None:
-                sm_scale *= k_scale
+        if self._backend != "cudnn":
+            def _to_scalar(x):
+                if isinstance(x, torch.Tensor):
+                    return float(x.item()) if x.numel() == 1 else None
+                return float(x) if isinstance(x, (int, float)) else None
+            qs = _to_scalar(q_scale)
+            ks = _to_scalar(k_scale)
+            if qs is not None:
+                sm_scale *= qs
+            if ks is not None:
+                sm_scale *= ks

Also applies to: 2134-2139

🧹 Nitpick comments (3)
flashinfer/cudnn/prefill.py (1)

20-28: Per-device dummy scale cache looks good; consider thread-safety.
Caching _dummy_scale_tensors by torch.device addresses the multi-GPU wrong-device risk. If this can be hit from multiple threads, consider guarding the dict with a lock (or using setdefault) to avoid races.
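
One way to add that guard (a sketch only; the dict and function names mirror the summary above, while the 1.0 fill value and the locking are assumptions, not the PR's code):

import threading
from typing import Dict
import torch

_dummy_scale_tensors: Dict[torch.device, torch.Tensor] = {}
_dummy_scale_lock = threading.Lock()

def _get_dummy_scale_tensor(device: torch.device) -> torch.Tensor:
    # Fast path stays lock-free once the per-device tensor exists;
    # setdefault under the lock keeps concurrent first calls safe.
    t = _dummy_scale_tensors.get(device)
    if t is None:
        with _dummy_scale_lock:
            t = _dummy_scale_tensors.setdefault(
                device, torch.ones(1, 1, 1, 1, device=device, dtype=torch.float32)
            )
    return t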

flashinfer/prefill.py (1)

2154-2158: torch.zeros for out is potentially expensive if kernels fully overwrite output.
If all backends always write the full out, torch.empty is typically preferable. If zeros are required for a backend/edge-case, consider scoping it narrowly to that backend.

tests/attention/test_cudnn_prefill.py (1)

44-61: Switching inputs to torch.ones reduces coverage of numerical edge cases.
This makes the test more deterministic, but it also weakens stress on masking/scale/accumulation behavior. If the intent is determinism, consider keeping random inputs but fixing seed (already done) and perhaps using smaller tolerances instead.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 78ba024 and 880afa2.

📒 Files selected for processing (3)
  • flashinfer/cudnn/prefill.py (18 hunks)
  • flashinfer/prefill.py (5 hunks)
  • tests/attention/test_cudnn_prefill.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/attention/test_cudnn_prefill.py (1)
flashinfer/prefill.py (10)
  • BatchPrefillWithPagedKVCacheWrapper (1305-2312)
  • plan (1595-1988)
  • plan (2565-2862)
  • run (2019-2031)
  • run (2034-2046)
  • run (2049-2280)
  • run (2892-2902)
  • run (2905-2915)
  • run (2918-3075)
  • wrapper (135-151)
flashinfer/cudnn/prefill.py (1)
flashinfer/cudnn/decode.py (1)
  • UIDs (30-49)
🪛 Ruff (0.14.8)
tests/attention/test_cudnn_prefill.py

201-201: Unused function argument: return_lse

(ARG001)


202-202: Unused function argument: is_cuda_graph_compatible

(ARG001)

flashinfer/cudnn/prefill.py

96-96: Unused function argument: o_data_type

(ARG001)


164-164: Avoid specifying long messages outside the exception class

(TRY003)


179-181: Avoid specifying long messages outside the exception class

(TRY003)


342-342: Ambiguous variable name: O

(E741)


377-377: Ambiguous variable name: O

(E741)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs

@flashinfer-bot (Collaborator) commented

[CANCELING] Pipeline #40109767: canceled

@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

🧹 Nitpick comments (1)
benchmarks/routines/attention.py (1)

957-964: Consider consolidating the duplicated to_float8 helper.

This function is duplicated across multiple files (see benchmarks/routines/gemm.py, tests/attention/test_trtllm_gen_attention.py, benchmarks/bench_trtllm_fmha.py). While the implementations vary slightly (0.1 scaling factor vs. no scaling), consolidating into a shared utility would improve maintainability.
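
If the copies are consolidated, the shared utility could look roughly like this (a sketch of the common pattern only; the existing variants differ in details such as the 0.1 scaling factor, so the exact body is an assumption):

import torch

def to_float8(x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn):
    # Quantize to FP8 and return (fp8_tensor, descale) so callers can recover
    # the original magnitude with fp8_tensor.to(torch.float32) * descale.
    finfo = torch.finfo(dtype)
    amax = x.abs().amax().clamp(min=1e-12)
    scale = finfo.max / amax
    x_fp8 = (x * scale).clamp(min=finfo.min, max=finfo.max).to(dtype)
    return x_fp8, scale.float().reciprocal()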

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 880afa2 and 74be215.

📒 Files selected for processing (3)
  • benchmarks/README.md (1 hunks)
  • benchmarks/routines/attention.py (10 hunks)
  • benchmarks/routines/flashinfer_benchmark_utils.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
benchmarks/routines/attention.py (4)
tests/attention/test_trtllm_gen_attention.py (1)
  • to_float8 (39-45)
benchmarks/routines/gemm.py (1)
  • to_float8 (160-166)
benchmarks/bench_trtllm_fmha.py (1)
  • to_float8 (62-68)
flashinfer/cudnn/prefill.py (1)
  • cudnn_batch_prefill_with_kv_cache (554-720)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (9)
benchmarks/routines/flashinfer_benchmark_utils.py (1)

176-184: LGTM! Backend expansion is well-documented.

The addition of cudnn-native across compute capabilities 8.0–12.0 is consistent with the broader PR changes. The clarifying comment helps distinguish between the two cuDNN-related backends.

benchmarks/routines/attention.py (7)

7-22: LGTM! Robust cuDNN import handling.

The import logic correctly handles both missing imports and library loading failures, with appropriate filtering to re-raise unexpected OSErrors.


107-107: LGTM! Consistent backend addition.

The cudnn-native choice aligns with the backend expansion across the PR.


700-706: LGTM! Appropriate tolerance relaxation for FP8.

The relaxed tolerances (rtol=5e-1, atol=1e-1) are reasonable given the reduced precision of FP8 data types.


966-993: LGTM! Proper FP8 scale handling for multiple backends.

The code correctly maintains two conversion paths:

  • Scalar scales for non-cuDNN backends (fa2, fa3, trtllm)
  • Tensor scales for cuDNN-related backends

The separate handling of k_cache_cudnn/v_cache_cudnn vs. kv_cache accommodates different backend tensor layout requirements.


1027-1115: LGTM! Backend-specific handling is correct.

The code properly accommodates different backend requirements:

  • cudnn uses the wrapper API with NHD layout and tensor scales
  • trtllm-native computes combined BMM scales with defensive defaults
  • cudnn-native uses the direct API with tensor scales

The defensive scale defaults (lines 1074-1076) prevent potential None-related errors.


1148-1161: LGTM! Flexible reference backend selection.

The fallback logic enables testing when FA2 is unavailable (e.g., for FP8 scenarios). The priority order (cudnn → cudnn-native → trtllm-gen → trtllm-native) is sensible, and the informative logging at line 1159 aids debugging.
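
As a sketch of the selection logic described (the function name and availability check are illustrative, not the benchmark harness's actual code):

from typing import Sequence

def pick_reference_backend(available: Sequence[str]) -> str:
    # Prefer fa2 when it can run; otherwise fall back in the priority
    # order described above.
    for candidate in ("fa2", "cudnn", "cudnn-native", "trtllm-gen", "trtllm-native"):
        if candidate in available:
            print(f"[INFO] Using '{candidate}' as the reference backend")
            return candidate
    raise RuntimeError("No backend available to use as a reference")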


1180-1180: LGTM! Error message now reflects actual reference backend.

This change correctly reports the actual reference backend used (which may not always be FA2), improving error message accuracy.

benchmarks/README.md (1)

120-120: LGTM! Documentation updated to reflect new backend.

The addition of cudnn-native to the backends documentation is consistent with the code changes across the PR.

@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/cudnn/prefill.py (1)

75-129: Fix graph-cache key: include k_cache.dtype, v_cache.dtype, and o_data_type.
Right now the cache key includes q.dtype (Line 116) but not k_cache.dtype / v_cache.dtype, and o_data_type is accepted but ignored—so _build_prefill_graph can be incorrectly reused across calls with different KV dtypes or output dtype.

 def _sdpa_prefill_key_fn(
@@
-    key = (
+    key = (
         graph_b,
         q.dim(),
         q.dtype,
+        k_cache.dtype,
+        v_cache.dtype,
         k_cache.dim(),
@@
         page_size,
+        o_data_type,
     )
     return key
♻️ Duplicate comments (2)
flashinfer/cudnn/prefill.py (2)

164-182: Don’t default FP8 output dtype to FP8 implicitly; validate/require fp16/bf16 for FP8 Q/KV.
o_data_type = q.dtype (Line 171-173; also Line 649-651) will make cudnn_o_data_type FP8 when q is FP8, and you then force O.set_data_type(cudnn_o_data_type) (Line 430-435). If cuDNN sdpa_fp8 expects/produces fp16/bf16 outputs by design, this can be invalid or silently wrong.

cuDNN 9.18 SDPA FP8 (sdpa_fp8): what output tensor dtypes are supported (FP16/BF16 only vs FP8 output supported)? Is output dtype configurable via graph tensor data_type?

Also applies to: 430-435, 649-655


200-252: FP8 scale tensors must be required/validated; current conditional var_map can crash graph execution.
When q is FP8, the graph always creates scale tensors and sets UIDs for Q_SCALE_UID/K_SCALE_UID/V_SCALE_UID/S_SCALE_UID/S_DESCALE_UID/O_SCALE_UID (Line 200-252), but _batch_prefill_with_kv_cache only populates those UIDs if the caller passed the corresponding tensors (Line 534-544). Missing any scale will typically fail inside graph.execute() with hard-to-debug cuDNN errors.

Suggested guard (place near the start of _batch_prefill_with_kv_cache, before _build_prefill_graph):

 def _batch_prefill_with_kv_cache(
@@
 ) -> tuple[torch.Tensor, torch.Tensor]:
+    # Validate required FP8 scales early to avoid opaque cuDNN execution failures.
+    if q.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+        missing = [name for name, t in (("q_scale", q_scale), ("k_scale", k_scale), ("v_scale", v_scale)) if t is None]
+        if missing:
+            raise ValueError(f"FP8 prefill requires scale tensors: missing {', '.join(missing)}.")
+
     graph, tensors = _build_prefill_graph(
         q=q,
@@
     )

Also consider verifying q_scale.device == q.device and shape (1,1,1,1) here.

Also applies to: 534-544, 656-681

🧹 Nitpick comments (2)
flashinfer/cudnn/prefill.py (2)

339-419: sdpa_fp8(..., generate_stats=True) ignores return_lse; consider aligning behavior or enforcing return_lse=True for FP8.
If return_lse=False, you still request stats/amax outputs, but you only mark/return Stats when return_lse is true later. This is likely fine, but it adds unnecessary overhead and a potential mismatch with the public doc ("return_lse must be True").


96-97: Remove unused _sdpa_prefill_key_fn(..., o_data_type=...) param if you won’t key on it.
Right now it’s unused (ruff ARG001) and increases the chance callers assume it affects caching. If you apply the cache-key fix above, keep it; otherwise drop it.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 74be215 and 0f2c303.

📒 Files selected for processing (1)
  • flashinfer/cudnn/prefill.py (19 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/cudnn/prefill.py (1)
flashinfer/cudnn/decode.py (2)
  • _create_cudnn_handle (21-26)
  • UIDs (30-49)
🪛 Ruff (0.14.8)
flashinfer/cudnn/prefill.py

96-96: Unused function argument: o_data_type

(ARG001)


165-165: Avoid specifying long messages outside the exception class

(TRY003)


180-182: Avoid specifying long messages outside the exception class

(TRY003)


343-343: Ambiguous variable name: O

(E741)


378-378: Ambiguous variable name: O

(E741)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (3)
flashinfer/cudnn/prefill.py (3)

20-28: Per-device dummy scale cache looks good; consider device key normalization / thread-safety only if needed.
This resolves the multi-GPU wrong-device tensor risk from the old global singleton; the (1,1,1,1) shape also matches the graph tensors.


31-37: Handle return matches decode.py; good consistency.
Returning the handle makes the helper usable in more contexts and matches the pattern in flashinfer/cudnn/decode.py.


64-72: No UID range collision detected. The new scale/amax UIDs (150–155 and 160–161) in prefill.py do not overlap with existing UIDs in decode.py (which uses ranges 0–3, 50–52, 100–101, 200–202, 1000–1001) or other modules in the codebase. The numeric ranges are cleanly separated.
