Enable Hopper FA3 FP8 attention in decode.py #2148
base: main
Conversation
Walkthrough

The changes introduce output data type (`o_data_type`) parameter support across batch decode planning and execution, enabling explicit control over output tensor dtypes with automatic defaulting from the query dtype when not specified. FP8 scale tensors are now properly threaded through backend paths, and dtype validation ensures output tensors match planned specifications.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Please refer to #2111, where we refactored FA3 and exposed the FP8 interface to Python.
Actionable comments posted: 0
🧹 Nitpick comments (2)
flashinfer/prefill.py (1)
2110-2114: Stronger out-dtype validation; consider tightening the error message

The explicit check that `out.dtype` matches the planned `o_data_type` enforces the plan/run contract and prevents silent dtype mismatches when callers reuse an `out` buffer. This is a solid safety improvement. Ruff's TRY003 warning about long exception messages could be addressed by shortening the message slightly or moving it into a shared constant/helper, but that is stylistic and not functionally required.
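A minimal sketch of the kind of check being endorsed here, assuming the names `out` and `o_data_type` from the comment (not the actual prefill.py code):

```python
import torch

def check_out_dtype(out: torch.Tensor, planned_o_dtype: torch.dtype) -> None:
    # Reject a reused `out` buffer whose dtype differs from what plan() recorded.
    if out.dtype != planned_o_dtype:
        raise ValueError(
            f"out.dtype ({out.dtype}) does not match the o_data_type "
            f"passed at plan time ({planned_o_dtype})"
        )

check_out_dtype(torch.empty(2, 2, dtype=torch.float16), torch.float16)  # passes silently
```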
flashinfer/decode.py (1)
1306-1313: Minor suggestion: align `out` shape check with allocation expression

Right now `out` is allocated with `q.shape[:-1] + v_cache.shape[-1:]` but validated against `q.shape`. Those are equal for today's kernels (q and v share head_dim), but if q/v head dims ever diverge, the validation would become inconsistent. Consider using the same expression in both places for future-proofing.
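A small sketch of the suggestion, with tensor shapes invented purely for illustration:

```python
import torch

q = torch.randn(32, 8, 128)             # (tokens, num_heads, qk_head_dim)
v_cache = torch.randn(64, 16, 8, 128)   # last dim is the v head_dim

# Use one expression for both allocation and validation so they can never diverge.
expected_out_shape = q.shape[:-1] + v_cache.shape[-1:]
out = torch.empty(expected_out_shape, dtype=torch.float16)
assert out.shape == expected_out_shape
```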
📒 Files selected for processing (3)
- flashinfer/decode.py (17 hunks)
- flashinfer/jit/attention/modules.py (1 hunk)
- flashinfer/prefill.py (5 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/decode.py (5)
flashinfer/utils.py (3)
- `determine_attention_backend` (450-492)
- `canonicalize_torch_dtype` (241-249)
- `is_float8` (158-159)

flashinfer/logits_processor/types.py (2)
- `dtype` (126-130)
- `device` (119-123)

include/flashinfer/trtllm/common.h (1)
- `device` (83-90)

flashinfer/attention.py (1)
- `plan` (71-136)

flashinfer/pod.py (2)
- `plan` (265-434)
- `plan` (800-1014)
🪛 Ruff (0.14.8)
flashinfer/prefill.py
2112-2114: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (10)
flashinfer/jit/attention/modules.py (1)
987-992: Good preflight validation for backend and FP8 output dtype

The explicit `backend` whitelist and `dtype_o` FP8 guard make the FA2/FA3 path fail fast on unsupported configurations and align with the FP8 constraints used elsewhere in prefill/decode. Looks correct and consistent with the intended Hopper FA3 FP8 support.

flashinfer/prefill.py (3)
1704-1705: FP8 output dtype guidance in docstrings is clear and aligned

The added note that `o_data_type` for FP8 inputs should typically be `torch.float16` or `torch.bfloat16` matches the actual capabilities and helps users avoid unsupported FP8 outputs. No further changes needed here.

Also applies to: 2656-2657
2080-2082: q_scale documentation matches runtime usage

The new `q_scale` description correctly reflects how it is applied (folded into `sm_scale` for FP8 BMM1). This keeps the public API understandable for FP8 users.
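As a tiny illustration of what "folded into `sm_scale`" means; the helper and values are hypothetical, not FlashInfer internals:

```python
import math

def effective_sm_scale(head_dim: int, q_scale: float = 1.0) -> float:
    # The FP8 query dequantization scale multiplies the softmax scaling factor
    # instead of being applied to q element-wise before BMM1.
    return q_scale / math.sqrt(head_dim)

print(effective_sm_scale(128, q_scale=0.5))
```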
2268-2273: Conditional v_scale application is correct and avoids unnecessary work

Applying `v_scale` only when it is not `None` and not `1.0` preserves previous behavior while saving a redundant multiply on the common default. The FP8 branch (cast to fp32, scale, cast back) is also consistent with typical scaling practice.
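Roughly, the pattern being described looks like this (illustrative only, not the exact prefill.py code):

```python
import torch

def apply_v_scale(out: torch.Tensor, v_scale):
    # Skip the multiply entirely for the common default (None or 1.0).
    if v_scale is not None and v_scale != 1.0:
        if out.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
            # FP8 tensors are upcast, scaled, then cast back.
            out = (out.to(torch.float32) * v_scale).to(out.dtype)
        else:
            out = out * v_scale
    return out
```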
flashinfer/decode.py (6)

51-76: Backend auto-selection and FA2/FA3 plan args look consistent and safe

Importing and using `determine_attention_backend` to specialize `self._backend` when it is `"auto"` in the tensor-core path, and then branching FA2-only plan arguments (`fixed_split_size`, `disable_split_kv`, `num_colocated_ctas`) while keeping FA3 with the shorter signature, matches the expected separation of FA2/FA3 interfaces and mirrors the logic in `fast_decode_plan`. This keeps decode aligned with prefill and should enable FA3 on Hopper cleanly.

Also applies to: 1042-1050, 2635-2660
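A toy sketch of the FA2-only branching described above; the argument names come from the review text, and the function itself is illustrative rather than the real planner:

```python
def build_plan_args(backend, base_args, fixed_split_size=-1,
                    disable_split_kv=False, num_colocated_ctas=0):
    args = list(base_args)
    if backend == "fa2":
        # FA2 keeps the longer plan signature; FA3 uses the shorter one.
        args += [fixed_split_size, disable_split_kv, num_colocated_ctas]
    return args

assert len(build_plan_args("fa2", [1, 2, 3])) == 6
assert len(build_plan_args("fa3", [1, 2, 3])) == 3
```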
720-733: Passing `backend` into `gen_customize_batch_prefill_module` is the right direction

Wiring `backend` explicitly into `gen_customize_batch_prefill_module` for the tensor-core JIT path aligns decode with the prefill side and makes backend selection explicit at module generation time. Assuming the JIT generator's signature matches this order, this is a straightforward and correct extension of the existing JIT flow.
838-840: `o_data_type` threading and validation are coherent

The new `o_data_type` parameter is:

- Defaulted to `q_data_type` when not provided, then canonicalized.
- Cached as `_cached_o_data_type` and threaded into `get_trtllm_gen_decode_module`, `get_batch_prefill_module`, and `get_batch_decode_module`.
- Used at run time for both allocation and `check_shape_dtype_device` validation of `out`.

This is consistent with the rest of the dtype handling and gives callers explicit control over output dtype (including FP8-input → FP16/BF16-output scenarios) without breaking older call sites that omit `o_data_type`.

Also applies to: 886-889, 964-977, 985-988, 1025-1035, 1095-1104, 1306-1313
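A condensed sketch of the defaulting behavior summarized above (the helper is illustrative, not the actual decode.py internals):

```python
import torch

def resolve_o_data_type(q_data_type: torch.dtype, o_data_type=None) -> torch.dtype:
    # Older call sites that omit o_data_type keep their previous behavior.
    return q_data_type if o_data_type is None else o_data_type

# FP8 inputs with an explicit higher-precision output dtype:
assert resolve_o_data_type(torch.float8_e4m3fn, torch.float16) == torch.float16
# Omitted -> defaults to the query dtype:
assert resolve_o_data_type(torch.bfloat16) == torch.bfloat16
```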
1341-1361: FP8 scale wiring and v_scale optimization look correct

Extracting `fp8_scale_q/k/v` from `*args` only when `q` is float8, and only for the non-JIT tensor-core path, keeps the existing API surface intact while enabling FA3/FA2 FP8 usage. Passing these explicitly into `paged_run` matches the extended kernel signature, and the updated `v_scale` guard (`v_scale is not None and v_scale != 1.0`) plus `is_float8(out)`-based cast behavior are sensible micro-optimizations that preserve numerical behavior.

Also applies to: 1413-1418
998-1005: trtllm-gen decode integration with `o_data_type` remains consistent

Adding `o_data_type` to the `get_trtllm_gen_decode_module` cache key while still using the same `paged_run` argument layout (including `workspace_size`, `block_tables`, `kv_lens_buffer`, etc.) keeps the trtllm-gen path coherent with FA2/FA3 and ensures different output dtypes don't collide in the module cache. The subsequent `_paged_run` call continues to receive the expected scales and workspace sizing.

Also applies to: 1364-1373
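A minimal sketch of why the output dtype belongs in the module cache key; `get_decode_module` and its return value are placeholders, not FlashInfer's API:

```python
import functools
import torch

@functools.lru_cache(maxsize=None)
def get_decode_module(q_dtype: torch.dtype, kv_dtype: torch.dtype, o_dtype: torch.dtype):
    # Stand-in for JIT-compiling / looking up a kernel module for this configuration.
    return f"decode_module[{q_dtype}/{kv_dtype}->{o_dtype}]"

m_f16 = get_decode_module(torch.float8_e4m3fn, torch.float8_e4m3fn, torch.float16)
m_bf16 = get_decode_module(torch.float8_e4m3fn, torch.float8_e4m3fn, torch.bfloat16)
assert m_f16 != m_bf16  # distinct cache entries for distinct output dtypes
```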
2523-2590: `fast_decode_plan` argument construction now mirrors main `plan`

The tensor-core branch in `fast_decode_plan` now builds the same base `args` list as `BatchDecodeWithPagedKVCacheWrapper.plan` and appends FA2-only arguments under `self._backend == "fa2"`. This keeps the "fast" path in sync with the standard planner and reduces the risk of FA3/FA2 divergence for multistep decode.

Also applies to: 2635-2660
@yzh119 Could you review this PR? Let me know if you think I should add some tests for this; if so, please point me to the test file where I should add or extend the tests. Thanks!
bkryu left a comment:
Hi @nvpohanh, I'd say the appropriate place to put unit tests should be test_hopper.py or test_hopper_fp8_attention.py.
Since we are adding fa3 as a backend for the first time in BatchDecodeWithPagedKVCacheWrapper, we may need to write new tests analogous to the prefill ones in there.
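For illustration, a decode-side FP8 test in that file might look roughly like the sketch below; the constructor/`plan()`/`run()` arguments, shapes, and the absence of a reference-accuracy check are assumptions that would need adapting to the existing prefill test fixtures:

```python
import pytest
import torch
import flashinfer

@pytest.mark.parametrize("o_dtype", [torch.float16, torch.bfloat16])
def test_batch_decode_fa3_fp8(o_dtype):
    if not torch.cuda.is_available():
        pytest.skip("requires a Hopper GPU")
    torch.manual_seed(0)
    batch, num_heads, head_dim, page_size, pages_per_seq = 4, 8, 128, 16, 8

    q = torch.randn(batch, num_heads, head_dim, device="cuda", dtype=torch.float16)
    kv = torch.randn(batch * pages_per_seq, 2, page_size, num_heads, head_dim,
                     device="cuda", dtype=torch.float16)
    kv_indptr = torch.arange(batch + 1, device="cuda", dtype=torch.int32) * pages_per_seq
    kv_indices = torch.arange(batch * pages_per_seq, device="cuda", dtype=torch.int32)
    kv_last_page_len = torch.full((batch,), page_size, device="cuda", dtype=torch.int32)

    workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
    wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace, "NHD", use_tensor_cores=True, backend="fa3"  # backend arg per this PR
    )
    wrapper.plan(
        kv_indptr, kv_indices, kv_last_page_len,
        num_heads, num_heads, head_dim, page_size,
        q_data_type=torch.float8_e4m3fn,
        kv_data_type=torch.float8_e4m3fn,
        o_data_type=o_dtype,  # new parameter introduced by this PR
    )
    # FP8 scale tensors may also need to be passed to run(); omitted here for brevity.
    out = wrapper.run(q.to(torch.float8_e4m3fn), kv.to(torch.float8_e4m3fn))
    assert out.dtype == o_dtype
```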
Only needed when ``use_cuda_graph`` is ``True``.
backend : str
    The implementation backend, could be ``auto``/``fa2`` or ``trtllm-gen``. Defaults to ``auto``.
Should we add fa3 here?
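If it is added, one possible wording for that line (suggestion only):

    The implementation backend, could be ``auto``/``fa2``/``fa3`` or ``trtllm-gen``. Defaults to ``auto``.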
📌 Description
#2111 already enabled Hopper FA3 FP8 attention in `prefill.py`. This is just a follow-up PR to make the same change in `decode.py`, because `decode.py` actually uses prefill kernels.

🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Improvements