Added an initial implementation of Q and KV Cache in fp8 and to use t… #2035
Conversation
Walkthrough
Adds FP8-aware cuDNN prefill plumbing: per-device and dummy scale tensors, new UIDs for scale/descale/amax, threading of q/k/v scale tensors and optional output dtype through graph/runtime, and FP8-focused tests and benchmark/harness updates.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
d41b0c5 to 78ba024
Actionable comments posted: 7
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/cudnn/prefill.py (1)
73-126: Fix graph cache key: include dtypes (esp. `o_data_type`) or you can reuse an incompatible graph.
`_sdpa_prefill_key_fn` now accepts `o_data_type` but doesn't use it in `key`, and the key also doesn't include q/k/v dtypes even though `_build_prefill_graph` now depends on them. This can cause cached graphs to be reused across dtype configurations (BF16 vs FP8 vs output dtype), leading to wrong results or execution failures.
Also applies to: 94-95, 111-125
🧹 Nitpick comments (3)
tests/attention/test_cudnn_prefill.py (1)
43-46: Consider keeping some randomness in `test_cudnn_prefill` inputs (ones-only can reduce coverage).
Switching `q`/`kv_cache` to `torch.ones` can mask issues that only show up with realistic value distributions (e.g., scale/amax paths, numerical stability). If you changed this to reduce nondeterminism, consider using `torch.randn(...) * small_scale` instead to keep signal while staying stable.
Also applies to: 59-73
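As a rough illustration (a hypothetical helper, not code from this diff), seeded small-magnitude random inputs keep the test deterministic without collapsing the value distribution:

```python
import torch

def make_test_input(shape, device="cuda", dtype=torch.bfloat16, seed=0, scale=0.1):
    # Fixed generator seed keeps the test reproducible; the small scale keeps
    # BF16/FP8 accumulations numerically tame while preserving varied values.
    gen = torch.Generator(device=device).manual_seed(seed)
    return torch.randn(shape, generator=gen, device=device, dtype=dtype) * scale
```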
flashinfer/prefill.py (1)
2153-2159: `out = torch.zeros(...)` may be a perf regression if kernels fully overwrite `out`.
If this was added to fix partial-writes (e.g., masked regions), consider documenting why. Otherwise, prefer `torch.empty(...)` for speed.
flashinfer/cudnn/prefill.py (1)
170-177: Optimize `generate_stats` for FP8 to match the BFLOAT16 pattern.
The version guard is correct. However, `generate_stats=True` is unconditionally set for FP8 (line 382) while Stats is only output when `return_lse=True` (line 430). Since this is prefill-only (no backward pass), you can tie it to `return_lse` like the BFLOAT16 path (line 354: `generate_stats=return_lse`) to avoid unnecessary overhead when LSE is not requested. The amax tensors are produced independently of `generate_stats` and are handled correctly.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- flashinfer/cudnn/prefill.py (18 hunks)
- flashinfer/prefill.py (5 hunks)
- tests/attention/test_cudnn_prefill.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/attention/test_cudnn_prefill.py (1)
flashinfer/prefill.py (10)
BatchPrefillWithPagedKVCacheWrapper (1305-2312), plan (1595-1988), plan (2565-2862), run (2019-2031), run (2034-2046), run (2049-2280), run (2892-2902), run (2905-2915), run (2918-3075), wrapper (135-151)
flashinfer/cudnn/prefill.py (1)
flashinfer/cudnn/decode.py (2)
_create_cudnn_handle (21-26), UIDs (30-49)
🪛 Ruff (0.14.8)
tests/attention/test_cudnn_prefill.py
205-205: Unused function argument: return_lse
(ARG001)
206-206: Unused function argument: is_cuda_graph_compatible
(ARG001)
flashinfer/cudnn/prefill.py
94-94: Unused function argument: o_data_type
(ARG001)
162-162: Avoid specifying long messages outside the exception class
(TRY003)
174-176: Avoid specifying long messages outside the exception class
(TRY003)
337-337: Ambiguous variable name: O
(E741)
372-372: Ambiguous variable name: O
(E741)
🔇 Additional comments (2)
flashinfer/prefill.py (2)
69-87: `_split_scale_param` looks good for scalar-vs-tensor scale normalization.
Clear semantics and keeps FA3 callsites cleaner.
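For readers outside the diff, a hedged sketch of what scalar-vs-tensor scale splitting generally looks like (this is an illustration only; the real `_split_scale_param` signature and return contract are not shown here):

```python
from typing import Optional, Tuple, Union
import torch

def split_scale_param(
    scale: Optional[Union[float, torch.Tensor]],
) -> Tuple[float, Optional[torch.Tensor]]:
    # Returns (scalar_factor, tensor_scale): scalar factors can be folded into
    # sm_scale, while tensor scales are forwarded to kernels that accept them.
    if scale is None:
        return 1.0, None
    if isinstance(scale, torch.Tensor):
        if scale.numel() == 1:
            return float(scale.item()), None
        return 1.0, scale
    return float(scale), None
```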
2189-2210: cuDNN FP8 plumbing is correctly threaded through wrapper → cudnn call.
Forwarding `q_scale`/`k_scale`/`v_scale` and `o_data_type` matches the new FP8 SDPA graph inputs and enables the intended path. The `cudnn_batch_prefill_with_kv_cache` function enforces cuDNN's requirement that scale tensors have shape (1, 1, 1, 1) and properly maps them to the cuDNN graph's descale/scale nodes.
/bot run
…he cudnn implementation Making output to bf16. Debugging commit Fixed and cleaned up
78ba024 to 880afa2
/bot run
@Anerudhan is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww
[CANCELING] Pipeline #40107625: canceled
/bot run
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
flashinfer/cudnn/prefill.py (1)
75-128: Graph cache key must include dtypes (and `o_data_type`), otherwise the wrong graph can be reused.
`_sdpa_prefill_key_fn` doesn't include `q.dtype`/`k_cache.dtype`/`v_cache.dtype` nor `o_data_type` in `key`. Now that `_build_prefill_graph` selects different ops (`sdpa` vs `sdpa_fp8`) and sets different tensor `data_type`s based on dtype + output dtype, a cached graph can be invalid for a later call with different dtypes/output dtype.
```diff
 def _sdpa_prefill_key_fn(
@@
     lse: Optional[torch.Tensor] = None,
     o_data_type: Optional[torch.dtype] = None,
 ):
+    if o_data_type is None:
+        o_data_type = q.dtype
     graph_b = actual_seq_lens_q.shape[0]
@@
     key = (
         graph_b,
+        str(q.dtype),
+        str(k_cache.dtype),
+        str(v_cache.dtype),
+        str(o_data_type),
         q.dim(),
         k_cache.dim(),
@@
         page_size,
     )
     return key
```
Also: since `actual_seq_lens_q` is typed Optional but used unconditionally (`.shape[0]`), either make it required in the signature or handle `None`.
flashinfer/prefill.py (1)
2182-2210: cuDNN path needs scale normalization: `q_scale`/`k_scale`/`v_scale` can be float or per-head tensors, but `cudnn_batch_prefill_with_kv_cache` expects (1, 1, 1, 1) float32 tensors.
Currently, these scales are passed directly from the method signature (which allows `Union[float, torch.Tensor]`) to `cudnn_batch_prefill_with_kv_cache()` without transformation, causing failures when a Python float is provided and ambiguity with non-scalar tensors.
```diff
 if self._backend == "cudnn":
+    def _as_cudnn_scale(x):
+        if x is None:
+            return None
+        if isinstance(x, torch.Tensor):
+            if x.numel() != 1:
+                raise ValueError("cudnn backend expects scalar scale tensors (numel()==1).")
+            t = x.to(device=q.device, dtype=torch.float32)
+        else:
+            t = torch.tensor(float(x), device=q.device, dtype=torch.float32)
+        return t.reshape(1, 1, 1, 1)
+
     cudnn_batch_prefill_with_kv_cache(
@@
-        q_scale=q_scale,
-        k_scale=k_scale,
-        v_scale=v_scale,
+        q_scale=_as_cudnn_scale(q_scale),
+        k_scale=_as_cudnn_scale(k_scale),
+        v_scale=_as_cudnn_scale(v_scale),
@@
         o_data_type=out_dtype,
     )
```
♻️ Duplicate comments (4)
tests/attention/test_cudnn_prefill.py (2)
14-26: Parametrized `return_lse`/`is_cuda_graph_compatible` are unused.
Either remove those parametrizations or rename the args to `_return_lse`/`_is_cuda_graph_compatible` (or actually pass `return_lse=...` into `wrapper.run(...)` and assert on the returned tuple).
Also applies to: 190-203
225-272: FP8 test passes incorrect scale types/shapes (and a float `q_scale`), likely breaking cuDNN var_map expectations.
- `flashinfer/cudnn/prefill.py` builds scale tensors with `dim=(1,1,1,1)`; the test constructs scalar tensors for `k_scale_tensor`/`v_scale_tensor`.
- The test passes `q_scale=q_scale` where `q_scale` is a Python float (the tensor `q_scale = torch.tensor(...)` is created but not used).
Also clamp scales to avoid divide-by-zero when `amax == 0`.
```diff
-    q_scale = q.amax().item() / 256
-
-    q_scale = torch.tensor(q_scale, device=device, dtype=torch.float32)
-    q_fp8 = (q / q_scale).to(torch.float8_e4m3fn)
+    q_scale_val = (q.amax().float() / 256).clamp_min(1e-12)
+    q_scale = q_scale_val.to(device=device, dtype=torch.float32).reshape(1, 1, 1, 1)
+    q_fp8 = (q / q_scale_val).to(torch.float8_e4m3fn)
@@
-    k_scale = k_cache.amax().item() / 256
-    v_scale = v_cache.amax().item() / 256
-    k_cache_fp8 = (k_cache / k_scale).to(torch.float8_e4m3fn)
-    v_cache_fp8 = (v_cache / v_scale).to(torch.float8_e4m3fn)
-
-    k_scale_tensor = torch.tensor(k_scale, device=device, dtype=torch.float32)
-    v_scale_tensor = torch.tensor(v_scale, device=device, dtype=torch.float32)
+    k_scale_val = (k_cache.amax().float() / 256).clamp_min(1e-12)
+    v_scale_val = (v_cache.amax().float() / 256).clamp_min(1e-12)
+    k_cache_fp8 = (k_cache / k_scale_val).to(torch.float8_e4m3fn)
+    v_cache_fp8 = (v_cache / v_scale_val).to(torch.float8_e4m3fn)
+    k_scale_tensor = k_scale_val.to(device=device, dtype=torch.float32).reshape(1, 1, 1, 1)
+    v_scale_tensor = v_scale_val.to(device=device, dtype=torch.float32).reshape(1, 1, 1, 1)
@@
-        q_scale=q_scale,
+        q_scale=q_scale,
         k_scale=k_scale_tensor,
         v_scale=v_scale_tensor,
```
Also applies to: 340-346
flashinfer/cudnn/prefill.py (1)
648-651: `o_data_type` can still be None when allocating `out` (will crash).
`torch.empty(..., dtype=o_data_type)` will throw if `o_data_type` is None. This matches the prior review concern and still needs a default before allocation.
```diff
-    if out is None:
+    if out is None:
+        if o_data_type is None:
+            o_data_type = q.dtype
         out_shape = (num_tokens, h_qo, d_vo)
         out = torch.empty(out_shape, device=q.device, dtype=o_data_type)
```
flashinfer/prefill.py (1)
2054-2056: Avoid turning `sm_scale` into a Tensor for non-cuDNN backends.
`sm_scale *= q_scale` / `*= k_scale` will produce a Tensor if either scale is a Tensor, which can break kernels expecting a Python float.
```diff
-        if self._backend != "cudnn":
-            if q_scale is not None:
-                sm_scale *= q_scale
-            if k_scale is not None:
-                sm_scale *= k_scale
+        if self._backend != "cudnn":
+            def _to_scalar(x):
+                if isinstance(x, torch.Tensor):
+                    return float(x.item()) if x.numel() == 1 else None
+                return float(x) if isinstance(x, (int, float)) else None
+            qs = _to_scalar(q_scale)
+            ks = _to_scalar(k_scale)
+            if qs is not None:
+                sm_scale *= qs
+            if ks is not None:
+                sm_scale *= ks
```
Also applies to: 2134-2139
🧹 Nitpick comments (3)
flashinfer/cudnn/prefill.py (1)
20-28: Per-device dummy scale cache looks good; consider thread-safety.
Caching `_dummy_scale_tensors` by `torch.device` addresses the multi-GPU wrong-device risk. If this can be hit from multiple threads, consider guarding the dict with a lock (or using `setdefault`) to avoid races.
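If concurrency ever matters here, one hedged option is a lock-guarded per-device cache; this is a standalone sketch with assumed names, not the module's actual code:

```python
import threading
from typing import Dict
import torch

_dummy_scale_tensors: Dict[torch.device, torch.Tensor] = {}
_dummy_scale_lock = threading.Lock()

def get_dummy_scale(device: torch.device) -> torch.Tensor:
    # Fast path avoids the lock; setdefault under the lock makes creation race-free.
    cached = _dummy_scale_tensors.get(device)
    if cached is None:
        with _dummy_scale_lock:
            cached = _dummy_scale_tensors.setdefault(
                device, torch.ones(1, 1, 1, 1, device=device, dtype=torch.float32)
            )
    return cached
```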
flashinfer/prefill.py (1)
2154-2158: `torch.zeros` for `out` is potentially expensive if kernels fully overwrite the output.
If all backends always write the full `out`, `torch.empty` is typically preferable. If zeros are required for a backend/edge-case, consider scoping it narrowly to that backend.
tests/attention/test_cudnn_prefill.py (1)
44-61: Switching inputs to `torch.ones` reduces coverage of numerical edge cases.
This makes the test more deterministic, but it also weakens stress on masking/scale/accumulation behavior. If the intent is determinism, consider keeping random inputs with the fixed seed (already done) and perhaps using smaller tolerances instead.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- flashinfer/cudnn/prefill.py (18 hunks)
- flashinfer/prefill.py (5 hunks)
- tests/attention/test_cudnn_prefill.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/attention/test_cudnn_prefill.py (1)
flashinfer/prefill.py (10)
BatchPrefillWithPagedKVCacheWrapper (1305-2312), plan (1595-1988), plan (2565-2862), run (2019-2031), run (2034-2046), run (2049-2280), run (2892-2902), run (2905-2915), run (2918-3075), wrapper (135-151)
flashinfer/cudnn/prefill.py (1)
flashinfer/cudnn/decode.py (1)
UIDs(30-49)
🪛 Ruff (0.14.8)
tests/attention/test_cudnn_prefill.py
201-201: Unused function argument: return_lse
(ARG001)
202-202: Unused function argument: is_cuda_graph_compatible
(ARG001)
flashinfer/cudnn/prefill.py
96-96: Unused function argument: o_data_type
(ARG001)
164-164: Avoid specifying long messages outside the exception class
(TRY003)
179-181: Avoid specifying long messages outside the exception class
(TRY003)
342-342: Ambiguous variable name: O
(E741)
377-377: Ambiguous variable name: O
(E741)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
[CANCELING] Pipeline #40109767: canceled
Actionable comments posted: 1
🧹 Nitpick comments (1)
benchmarks/routines/attention.py (1)
957-964: Consider consolidating the duplicated `to_float8` helper.
This function is duplicated across multiple files (see `benchmarks/routines/gemm.py`, `tests/attention/test_trtllm_gen_attention.py`, `benchmarks/bench_trtllm_fmha.py`). While the implementations vary slightly (0.1 scaling factor vs. no scaling), consolidating into a shared utility would improve maintainability.
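A hedged sketch of what a shared utility might look like (the module path, signature, and return convention are suggestions; the existing per-file helpers may differ):

```python
# Hypothetical shared helper, e.g. benchmarks/routines/fp8_utils.py
import torch

def to_float8(x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn):
    """Per-tensor FP8 quantization: returns (x_fp8, descale) with x ~ x_fp8.float() * descale."""
    finfo = torch.finfo(dtype)
    amax = x.abs().amax().float().clamp(min=1e-12)
    quant_scale = finfo.max / amax
    x_fp8 = (x.float() * quant_scale).clamp(min=finfo.min, max=finfo.max).to(dtype)
    return x_fp8, quant_scale.reciprocal()
```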
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- benchmarks/README.md (1 hunks)
- benchmarks/routines/attention.py (10 hunks)
- benchmarks/routines/flashinfer_benchmark_utils.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
benchmarks/routines/attention.py (4)
tests/attention/test_trtllm_gen_attention.py (1)
to_float8 (39-45)
benchmarks/routines/gemm.py (1)
to_float8 (160-166)
benchmarks/bench_trtllm_fmha.py (1)
to_float8 (62-68)
flashinfer/cudnn/prefill.py (1)
cudnn_batch_prefill_with_kv_cache (554-720)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (9)
benchmarks/routines/flashinfer_benchmark_utils.py (1)
176-184: LGTM! Backend expansion is well-documented.
The addition of `cudnn-native` across compute capabilities 8.0–12.0 is consistent with the broader PR changes. The clarifying comment helps distinguish between the two cuDNN-related backends.
benchmarks/routines/attention.py (7)
7-22: LGTM! Robust cuDNN import handling.
The import logic correctly handles both missing imports and library loading failures, with appropriate filtering to re-raise unexpected OSErrors.
107-107: LGTM! Consistent backend addition.
The `cudnn-native` choice aligns with the backend expansion across the PR.
700-706: LGTM! Appropriate tolerance relaxation for FP8.
The relaxed tolerances (rtol=5e-1, atol=1e-1) are reasonable given the reduced precision of FP8 data types.
966-993: LGTM! Proper FP8 scale handling for multiple backends.
The code correctly maintains two conversion paths:
- Scalar scales for non-cuDNN backends (fa2, fa3, trtllm)
- Tensor scales for cuDNN-related backends
The separate handling of `k_cache_cudnn`/`v_cache_cudnn` vs. `kv_cache` accommodates different backend tensor layout requirements.
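Schematically, the two conversion paths described above could be condensed as follows (a hypothetical helper; the benchmark code's actual structure and backend names may differ slightly):

```python
import torch

def scales_for_backend(backend: str, k_descale: torch.Tensor, v_descale: torch.Tensor):
    # Non-cuDNN backends fold the descales in as Python floats; cuDNN-style
    # backends take (1, 1, 1, 1) float32 tensors on the data's device.
    if backend in ("fa2", "fa3", "trtllm"):
        return float(k_descale.item()), float(v_descale.item())
    return (
        k_descale.float().reshape(1, 1, 1, 1),
        v_descale.float().reshape(1, 1, 1, 1),
    )
```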
1027-1115: LGTM! Backend-specific handling is correct.
The code properly accommodates different backend requirements:
- `cudnn` uses the wrapper API with NHD layout and tensor scales
- `trtllm-native` computes combined BMM scales with defensive defaults
- `cudnn-native` uses the direct API with tensor scales
The defensive scale defaults (lines 1074-1076) prevent potential None-related errors.
1148-1161: LGTM! Flexible reference backend selection.
The fallback logic enables testing when FA2 is unavailable (e.g., for FP8 scenarios). The priority order (cudnn → cudnn-native → trtllm-gen → trtllm-native) is sensible, and the informative logging at line 1159 aids debugging.
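The priority order could be expressed roughly as below (a sketch only; the actual selection logic in `benchmarks/routines/attention.py` is more involved):

```python
def pick_reference_backend(available_backends):
    # FA2 stays the preferred reference; otherwise fall back in the order
    # described above: cudnn -> cudnn-native -> trtllm-gen -> trtllm-native.
    for candidate in ("fa2", "cudnn", "cudnn-native", "trtllm-gen", "trtllm-native"):
        if candidate in available_backends:
            return candidate
    raise RuntimeError("No usable reference backend for this configuration")
```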
1180-1180: LGTM! Error message now reflects the actual reference backend.
This change correctly reports the actual reference backend used (which may not always be FA2), improving error message accuracy.
benchmarks/README.md (1)
120-120: LGTM! Documentation updated to reflect the new backend.
The addition of `cudnn-native` to the backends documentation is consistent with the code changes across the PR.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/cudnn/prefill.py (1)
75-129: Fix graph-cache key: include `k_cache.dtype`, `v_cache.dtype`, and `o_data_type`.
Right now the cache key includes `q.dtype` (Line 116) but not `k_cache.dtype`/`v_cache.dtype`, and `o_data_type` is accepted but ignored, so `_build_prefill_graph` can be incorrectly reused across calls with different KV dtypes or output dtype.
```diff
 def _sdpa_prefill_key_fn(
@@
-    key = (
+    key = (
         graph_b,
         q.dim(),
         q.dtype,
+        k_cache.dtype,
+        v_cache.dtype,
         k_cache.dim(),
@@
         page_size,
+        o_data_type,
     )
     return key
```
♻️ Duplicate comments (2)
flashinfer/cudnn/prefill.py (2)
164-182: Don’t default FP8 output dtype to FP8 implicitly; validate/require fp16/bf16 for FP8 Q/KV.
`o_data_type = q.dtype` (Line 171-173; also Line 649-651) will make `cudnn_o_data_type` FP8 when `q` is FP8, and you then force `O.set_data_type(cudnn_o_data_type)` (Line 430-435). If cuDNN `sdpa_fp8` expects/produces fp16/bf16 outputs by design, this can be invalid or silently wrong.
cuDNN 9.18 SDPA FP8 (`sdpa_fp8`): what output tensor dtypes are supported (FP16/BF16 only vs FP8 output supported)? Is output dtype configurable via graph tensor data_type?
Also applies to: 430-435, 649-655
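One hedged way to make the expectation explicit, assuming fp16/bf16 turn out to be the only supported FP8 output dtypes (which is exactly what the question above is meant to confirm):

```python
from typing import Optional
import torch

_FP8_DTYPES = (torch.float8_e4m3fn, torch.float8_e5m2)

def resolve_output_dtype(q_dtype: torch.dtype, o_data_type: Optional[torch.dtype]) -> torch.dtype:
    # Hypothetical guard: FP8 inputs default to bf16 output and reject FP8 outputs
    # unless cuDNN support for FP8 outputs is confirmed.
    if q_dtype in _FP8_DTYPES:
        if o_data_type is None:
            return torch.bfloat16
        if o_data_type in _FP8_DTYPES:
            raise ValueError("FP8 prefill currently requires a float16/bfloat16 output dtype")
        return o_data_type
    return o_data_type if o_data_type is not None else q_dtype
```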
200-252: FP8 scale tensors must be required/validated; the current conditional `var_map` can crash graph execution.
When `q` is FP8, the graph always creates scale tensors and sets UIDs for `Q_SCALE_UID`/`K_SCALE_UID`/`V_SCALE_UID`/`S_SCALE_UID`/`S_DESCALE_UID`/`O_SCALE_UID` (Line 200-252), but `_batch_prefill_with_kv_cache` only populates those UIDs if the caller passed the corresponding tensors (Line 534-544). Missing any scale will typically fail inside `graph.execute()` with hard-to-debug cuDNN errors.
Suggested guard (place near the start of `_batch_prefill_with_kv_cache`, before `_build_prefill_graph`):
```diff
 def _batch_prefill_with_kv_cache(
@@
 ) -> tuple[torch.Tensor, torch.Tensor]:
+    # Validate required FP8 scales early to avoid opaque cuDNN execution failures.
+    if q.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+        missing = [name for name, t in (("q_scale", q_scale), ("k_scale", k_scale), ("v_scale", v_scale)) if t is None]
+        if missing:
+            raise ValueError(f"FP8 prefill requires scale tensors: missing {', '.join(missing)}.")
+
     graph, tensors = _build_prefill_graph(
         q=q,
@@
     )
```
Also consider verifying `q_scale.device == q.device` and shape `(1,1,1,1)` here.
Also applies to: 534-544, 656-681
🧹 Nitpick comments (2)
flashinfer/cudnn/prefill.py (2)
339-419: `sdpa_fp8(..., generate_stats=True)` ignores `return_lse`; consider aligning behavior or enforcing `return_lse=True` for FP8.
If `return_lse=False`, you still request stats/amax outputs, but you only mark/return `Stats` when `return_lse` is true later. This is likely fine but extra overhead and a potential mismatch with the public doc ("return_lse must be True").
96-97: Remove the unused `_sdpa_prefill_key_fn(..., o_data_type=...)` param if you won't key on it.
Right now it’s unused (ruff ARG001) and increases the chance callers assume it affects caching. If you apply the cache-key fix above, keep it; otherwise drop it.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- flashinfer/cudnn/prefill.py (19 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/cudnn/prefill.py (1)
flashinfer/cudnn/decode.py (2)
_create_cudnn_handle (21-26), UIDs (30-49)
🪛 Ruff (0.14.8)
flashinfer/cudnn/prefill.py
96-96: Unused function argument: o_data_type
(ARG001)
165-165: Avoid specifying long messages outside the exception class
(TRY003)
180-182: Avoid specifying long messages outside the exception class
(TRY003)
343-343: Ambiguous variable name: O
(E741)
378-378: Ambiguous variable name: O
(E741)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (3)
flashinfer/cudnn/prefill.py (3)
20-28: Per-device dummy scale cache looks good; consider device key normalization / thread-safety only if needed.
This resolves the multi-GPU wrong-device tensor risk from the old global singleton; the `(1,1,1,1)` shape also matches the graph tensors.
31-37: Handle return matches `decode.py`; good consistency.
Returning the handle makes the helper usable in more contexts and matches the pattern in `flashinfer/cudnn/decode.py`.
64-72: No UID range collision detected. The new scale/amax UIDs (150–155 and 160–161) in prefill.py do not overlap with existing UIDs in decode.py (which uses ranges 0–3, 50–52, 100–101, 200–202, 1000–1001) or other modules in the codebase. The numeric ranges are cleanly separated.
cudnn implementation for sdpa fp8
📌 Description
Allows cuDNN SDPA to be called when Q and KV are both fp8.
Requires cuDNN 9.18.0.
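Callers can gate the FP8 path on the runtime cuDNN version; a minimal sketch, assuming `torch.backends.cudnn.version()` is an acceptable source of truth for the build in use:

```python
import torch

def cudnn_supports_fp8_sdpa() -> bool:
    # cuDNN 9.18.0 reports 91800 (major * 10000 + minor * 100 + patch).
    version = torch.backends.cudnn.version()
    return version is not None and version >= 91800
```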
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- Installed the hooks with `pre-commit install`.
- Ran `pre-commit run --all-files` and fixed any reported issues.
🧪 Tests
- Added or updated tests as needed (e.g., `unittest`, etc.).
Reviewer Notes
Summary by CodeRabbit
New Features
Bug Fixes
Tests
Documentation