
Conversation

@nvpohanh (Contributor) commented Nov 28, 2025

📌 Description

#2111 already enabled Hopper FA3 FP8 attention in prefill.py. This is just a follow-up PR to make the same change in decode.py because decode.py actually uses prefill kernels.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added output data type parameter for improved control over batch decode output precision.
  • Improvements

    • Enhanced FP8 support and automatic backend selection for optimal performance across hardware configurations.
    • Added dtype validation to ensure consistency between planning and execution phases.
    • Optimized scaling operations to improve efficiency.


coderabbitai bot commented Nov 28, 2025

Walkthrough

The changes introduce output data type (o_data_type) parameter support across batch decode planning and execution, enabling explicit control over output tensor dtypes with automatic defaulting from query dtype when not specified. FP8 scale tensors are now properly threaded through backend paths, and dtype validation ensures output tensors match planned specifications.
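To illustrate the defaulting behavior described above, here is a minimal sketch (not the actual flashinfer code; the helper below only stands in for flashinfer.utils.canonicalize_torch_dtype):

```python
import torch

def canonicalize_dtype_stub(dtype) -> torch.dtype:
    # Stand-in for flashinfer.utils.canonicalize_torch_dtype: accept either a
    # torch.dtype or a string like "float16" and return a torch.dtype.
    return getattr(torch, dtype) if isinstance(dtype, str) else dtype

def resolve_o_data_type(q_data_type, o_data_type=None) -> torch.dtype:
    # If o_data_type is not given, the output dtype defaults to the query dtype.
    return canonicalize_dtype_stub(o_data_type if o_data_type is not None else q_data_type)

assert resolve_o_data_type(torch.float16) == torch.float16
# FP8 inputs typically pair with an explicit fp16/bf16 output dtype.
assert resolve_o_data_type(torch.float8_e4m3fn, torch.bfloat16) == torch.bfloat16
```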

Changes

  • Batch decode output type support (flashinfer/decode.py): Added an o_data_type parameter to the public plan() signature, defaulting to q_data_type when not specified. Internal logic threads o_data_type through module creation and the backend execution paths (fa2/fa3), FP8 scale tensors are now passed explicitly, and the out tensor allocation uses o_data_type to determine its dtype. Backend selection logic now imports determine_attention_backend for automatic backend selection.
  • Module validation & FP8 handling (flashinfer/jit/attention/modules.py): Added validation assertions in gen_batch_prefill_module enforcing that the backend is "fa2" or "fa3" and that the output dtype is not an FP8 variant, so unsupported configurations fail early.
  • Prefill dtype consistency & documentation (flashinfer/prefill.py): Expanded FP8 data type guidance in docstrings (torch.float16 or torch.bfloat16 for FP8 inputs), added q_scale parameter documentation for run(), implemented a dtype consistency check ensuring out.dtype matches the planned o_data_type, and made v_scale application conditional, skipping scaling when it is None or exactly 1.0.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

  • flashinfer/decode.py: Pay close attention to o_data_type parameter threading through all code paths, FP8 scale tensor handling, backend selection logic, and out tensor allocation updates.
  • Dtype validation logic: Verify consistency checks in run() methods across decode.py and prefill.py properly validate against _cached_o_data_type and raise appropriate errors.
  • FP8 scale tensor passing: Confirm fp8_scale_q/k/v are correctly passed through backend-specific paths and not left as None where required.

Poem

🐰 Output types now shine so clear,
With dtype precision, far and near!
FP8 scales flow through every path,
Validation's guard prevents the wrath.
O_data_type threads with graceful care,
Tensors shaped with perfect flair! ✨

Pre-merge checks

❌ Failed checks (1 warning)
  • Description check (⚠️ Warning): The PR description contains only the template placeholder text, with no actual content describing the changes, objectives, or rationale for the modifications. Resolution: add a clear description of the changes made to enable Hopper FA3 FP8 attention in decode.py, explain why it was needed, and reference the related PR #2111.
✅ Passed checks (2 passed)
  • Docstring Coverage (✅ Passed): Docstring coverage is 87.50%, which meets the required threshold of 80.00%.
  • Title check (✅ Passed): The title accurately reflects the main change in the PR: enabling FP8 attention support for Hopper FA3 architecture in decode operations.


yzh119 (Collaborator) commented Nov 28, 2025

Please refer to #2111, where we refactored FA3 and exposed the FP8 interface to Python.

@nvpohanh force-pushed the dev-nvpohanh-hopper-fp8-attention branch 4 times, most recently from ede67a3 to a8d9e6a on December 2, 2025 03:55
@nvpohanh force-pushed the dev-nvpohanh-hopper-fp8-attention branch 2 times, most recently from ee77217 to 83cfce9 on December 8, 2025 08:44
@nvpohanh marked this pull request as ready for review on December 11, 2025 06:17
@nvpohanh changed the title from "Enable Hopper FA3 FP8 attention" to "Enable Hopper FA3 FP8 attention in decode.py" on December 11, 2025
@nvpohanh force-pushed the dev-nvpohanh-hopper-fp8-attention branch from 83cfce9 to 09a1ece on December 11, 2025 06:34
coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (2)
flashinfer/prefill.py (1)

2110-2114: Stronger out-dtype validation; consider tightening the error message

The explicit check that out.dtype matches the planned o_data_type enforces the plan/run contract and prevents silent dtype mismatches when callers reuse an out buffer. This is a solid safety improvement.

Ruff’s TRY003 warning about long exception messages could be addressed by shortening the message slightly or moving it into a shared constant/helper, but that’s stylistic and not functionally required.
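For reference, the kind of check being discussed could look roughly like this (an illustrative sketch with assumed names, not the literal prefill.py code):

```python
import torch

def check_out_dtype(out: torch.Tensor, planned_o_dtype: torch.dtype) -> None:
    # Enforce the plan/run contract: a caller-provided `out` buffer must match
    # the output dtype fixed at plan() time.
    if out.dtype != planned_o_dtype:
        raise ValueError(
            f"out.dtype ({out.dtype}) does not match the o_data_type "
            f"passed to plan() ({planned_o_dtype})"
        )

check_out_dtype(torch.empty(8, 128, dtype=torch.bfloat16), torch.bfloat16)   # OK
# check_out_dtype(torch.empty(8, 128, dtype=torch.float16), torch.bfloat16)  # raises
```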

flashinfer/decode.py (1)

1306-1313: Minor suggestion: align out shape check with allocation expression

Right now out is allocated with q.shape[:-1] + v_cache.shape[-1:] but validated against q.shape. Those are equal for today’s kernels (q and v share head_dim), but if q/v head dims ever diverge, the validation would become inconsistent. Consider using the same expression in both places for future-proofing.
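A small sketch of the suggested alignment (illustrative only; tensor names follow the comment above, not the exact decode.py code):

```python
import torch

q = torch.randn(16, 8, 128, dtype=torch.float16)            # [tokens, num_qo_heads, head_dim_qk]
v_cache = torch.randn(64, 16, 8, 128, dtype=torch.float16)  # last dim is head_dim_vo

# Allocate and validate against the same shape expression, so the check stays
# correct even if head_dim_qk and head_dim_vo ever diverge.
expected_out_shape = q.shape[:-1] + v_cache.shape[-1:]
out = torch.empty(expected_out_shape, dtype=torch.float16)
assert out.shape == expected_out_shape
```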

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between dc0ade7 and 09a1ece.

📒 Files selected for processing (3)
  • flashinfer/decode.py (17 hunks)
  • flashinfer/jit/attention/modules.py (1 hunks)
  • flashinfer/prefill.py (5 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/decode.py (5)
flashinfer/utils.py (3)
  • determine_attention_backend (450-492)
  • canonicalize_torch_dtype (241-249)
  • is_float8 (158-159)
flashinfer/logits_processor/types.py (2)
  • dtype (126-130)
  • device (119-123)
include/flashinfer/trtllm/common.h (1)
  • device (83-90)
flashinfer/attention.py (1)
  • plan (71-136)
flashinfer/pod.py (2)
  • plan (265-434)
  • plan (800-1014)
🪛 Ruff (0.14.8)
flashinfer/prefill.py

2112-2114: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (10)
flashinfer/jit/attention/modules.py (1)

987-992: Good preflight validation for backend and FP8 output dtype

The explicit backend whitelist and dtype_o FP8 guard make the FA2/FA3 path fail fast on unsupported configurations and align with the FP8 constraints used elsewhere in prefill/decode. Looks correct and consistent with the intended Hopper FA3 FP8 support.
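For illustration, the preflight guard presumably looks something like the following (a sketch under assumptions; the function and argument names are simplified, not the exact gen_batch_prefill_module signature):

```python
import torch

FP8_DTYPES = (torch.float8_e4m3fn, torch.float8_e5m2)

def validate_prefill_module_config(backend: str, dtype_o: torch.dtype) -> None:
    # Fail fast at module-generation time rather than deep inside the JIT path.
    assert backend in ("fa2", "fa3"), f"unsupported backend: {backend}"
    assert dtype_o not in FP8_DTYPES, "FP8 output dtypes are unsupported; use fp16/bf16"

validate_prefill_module_config("fa3", torch.bfloat16)         # passes
# validate_prefill_module_config("fa3", torch.float8_e4m3fn)  # would assert
```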

flashinfer/prefill.py (3)

1704-1705: FP8 output dtype guidance in docstrings is clear and aligned

The added note that o_data_type for FP8 inputs should typically be torch.float16 or torch.bfloat16 matches the actual capabilities and helps users avoid unsupported FP8 outputs. No further changes needed here.

Also applies to: 2656-2657


2080-2082: q_scale documentation matches runtime usage

The new q_scale description correctly reflects how it’s applied (folded into sm_scale for FP8 BMM1). This keeps the public API understandable for FP8 users.
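The arithmetic behind folding q_scale into sm_scale can be checked with a tiny self-contained example (illustrative only):

```python
import math
import torch

head_dim, q_scale = 128, 0.25
sm_scale = 1.0 / math.sqrt(head_dim)

q = torch.randn(4, head_dim, dtype=torch.float64)
k = torch.randn(6, head_dim, dtype=torch.float64)

# Dequantizing q by q_scale before the QK^T matmul ...
logits_a = (q * q_scale) @ k.T * sm_scale
# ... is equivalent to folding q_scale into sm_scale, leaving q untouched.
logits_b = q @ k.T * (sm_scale * q_scale)

assert torch.allclose(logits_a, logits_b)
```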


2268-2273: Conditional v_scale application is correct and avoids unnecessary work

Applying v_scale only when it’s not None and not 1.0 preserves previous behavior while saving a redundant multiply on the common default. The FP8 branch (cast to fp32, scale, cast back) is also consistent with typical scaling practice.
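A minimal sketch of the conditional scaling pattern described (variable names are assumptions, not the literal prefill.py code):

```python
import torch

def apply_v_scale(out: torch.Tensor, v_scale) -> torch.Tensor:
    # Skip the multiply entirely for the common default (None or exactly 1.0).
    if v_scale is None or v_scale == 1.0:
        return out
    if out.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        # FP8 branch as described above: upcast to fp32, scale, cast back.
        return (out.to(torch.float32) * v_scale).to(out.dtype)
    return out * v_scale

out = torch.randn(4, 8, dtype=torch.float16)
assert apply_v_scale(out, None) is out           # no-op path
assert apply_v_scale(out, 2.0).dtype == out.dtype
```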

flashinfer/decode.py (6)

51-76: Backend auto-selection and FA2/FA3 plan args look consistent and safe

Importing and using determine_attention_backend to specialize self._backend when it is "auto" in the tensor-core path, and then branching FA2-only plan arguments (fixed_split_size, disable_split_kv, num_colocated_ctas) while keeping FA3 with the shorter signature, matches the expected separation of FA2/FA3 interfaces and mirrors the logic in fast_decode_plan. This keeps decode aligned with prefill and should enable FA3 on Hopper cleanly.

Also applies to: 1042-1050, 2635-2660
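A rough sketch of the dispatch pattern described above (hypothetical helper and argument names; the real determine_attention_backend in flashinfer/utils.py may have a different signature):

```python
import torch

def determine_attention_backend_stub(device: torch.device) -> str:
    # Hypothetical stand-in: prefer FA3 on SM90 (Hopper), otherwise fall back to FA2.
    if device.type == "cuda" and torch.cuda.is_available():
        major, _ = torch.cuda.get_device_capability(device)
        return "fa3" if major == 9 else "fa2"
    return "fa2"

def build_plan_args(backend: str, base_args: list,
                    fixed_split_size=None, disable_split_kv=False,
                    num_colocated_ctas=0) -> list:
    # FA2 takes extra tuning arguments; FA3 keeps the shorter signature.
    if backend == "fa2":
        return base_args + [fixed_split_size, disable_split_kv, num_colocated_ctas]
    return base_args

print(build_plan_args("fa3", ["<shared plan args>"]))
```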


720-733: Passing backend into gen_customize_batch_prefill_module is the right direction

Wiring backend explicitly into gen_customize_batch_prefill_module for the tensor-core JIT path aligns decode with the prefill side and makes backend selection explicit at module generation time. Assuming the JIT generator’s signature matches this order, this is a straightforward and correct extension of the existing JIT flow.


838-840: o_data_type threading and validation are coherent

The new o_data_type parameter is:

  • Defaulted to q_data_type when not provided, then canonicalized.
  • Cached as _cached_o_data_type and threaded into get_trtllm_gen_decode_module, get_batch_prefill_module, and get_batch_decode_module.
  • Used at run time for both allocation and check_shape_dtype_device validation of out.

This is consistent with the rest of the dtype handling and gives callers explicit control over output dtype (including FP8-input → FP16/BF16-output scenarios) without breaking older call sites that omit o_data_type.

Also applies to: 886-889, 964-977, 985-988, 1025-1035, 1095-1104, 1306-1313
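The plan/run caching pattern being described might be sketched like this (a simplified illustration; attribute names such as _cached_o_data_type follow the review comment, the rest is assumed):

```python
from typing import Optional
import torch

class DecodePlanSketch:
    def plan(self, q_data_type: torch.dtype, o_data_type: Optional[torch.dtype] = None):
        # Cache the resolved output dtype at plan time ...
        self._cached_q_data_type = q_data_type
        self._cached_o_data_type = o_data_type if o_data_type is not None else q_data_type

    def run(self, q: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
        # ... then use it at run time for both allocation and validation.
        if out is None:
            out = torch.empty_like(q, dtype=self._cached_o_data_type)
        elif out.dtype != self._cached_o_data_type:
            raise ValueError("out dtype does not match the dtype planned via o_data_type")
        return out

w = DecodePlanSketch()
w.plan(torch.float8_e4m3fn, o_data_type=torch.bfloat16)
assert w.run(torch.empty(4, 128, dtype=torch.float8_e4m3fn)).dtype == torch.bfloat16
```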


1341-1361: FP8 scale wiring and v_scale optimization look correct

Extracting fp8_scale_q/k/v from *args only when q is float8 and only for the non-JIT tensor-core path keeps the existing API surface intact while enabling FA3/FA2 FP8 usage. Passing these explicitly into paged_run matches the extended kernel signature, and the updated v_scale guard (v_scale is not None and v_scale != 1.0) plus is_float8(out)-based cast behavior are sensible micro-optimizations that preserve numerical behavior.

Also applies to: 1413-1418
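Roughly, the conditional extraction could look like this (a sketch; the position of the scale tensors in *args is an assumption, not the real paged_run layout):

```python
import torch

def is_float8(x: torch.Tensor) -> bool:
    # Mirrors the helper listed in the code-graph section above.
    return x.dtype in (torch.float8_e4m3fn, torch.float8_e5m2)

def split_fp8_scales(q: torch.Tensor, *args):
    # Only FP8 queries carry (fp8_scale_q, fp8_scale_k, fp8_scale_v); for
    # fp16/bf16 inputs the positional args are passed through untouched.
    if is_float8(q) and len(args) >= 3:
        *rest, fp8_scale_q, fp8_scale_k, fp8_scale_v = args
        return tuple(rest), (fp8_scale_q, fp8_scale_k, fp8_scale_v)
    return args, (None, None, None)
```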


998-1005: trtllm-gen decode integration with o_data_type remains consistent

Adding o_data_type to the get_trtllm_gen_decode_module cache key while still using the same paged_run argument layout (including workspace_size, block_tables, kv_lens_buffer, etc.) keeps the trtllm-gen path coherent with FA2/FA3 and ensures different output dtypes don’t collide in the module cache. The subsequent _paged_run call continues to receive the expected scales and workspace sizing.

Also applies to: 1364-1373
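Why including o_data_type in the module cache key matters can be shown with a tiny caching sketch (illustrative only; the real cache lives in flashinfer's JIT layer):

```python
from functools import lru_cache
import torch

@lru_cache(maxsize=None)
def get_decode_module_stub(backend: str, dtype_q: torch.dtype,
                           dtype_kv: torch.dtype, dtype_o: torch.dtype) -> str:
    # Stand-in for module compilation: the returned "module" is just a label.
    return f"{backend}/{dtype_q}/{dtype_kv}/{dtype_o}"

# Without dtype_o in the key, these two configurations would collide and the
# second call would silently reuse a module built for the wrong output dtype.
m_fp16 = get_decode_module_stub("fa3", torch.float8_e4m3fn, torch.float8_e4m3fn, torch.float16)
m_bf16 = get_decode_module_stub("fa3", torch.float8_e4m3fn, torch.float8_e4m3fn, torch.bfloat16)
assert m_fp16 != m_bf16
```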


2523-2590: fast_decode_plan argument construction now mirrors main plan

The tensor-core branch in fast_decode_plan now builds the same base args list as BatchDecodeWithPagedKVCacheWrapper.plan and appends FA2-only arguments under self._backend == "fa2". This keeps the “fast” path in sync with the standard planner and reduces the risk of FA3/FA2 divergence for multistep decode.

Also applies to: 2635-2660

@nvpohanh (Contributor, Author) commented:

@yzh119 Could you review this PR? Thanks!

Let me know if you think I should add some tests for this. If so, please point me to the test file where I should add/extend the tests. Thanks!

@bkryu (Collaborator) left a comment:

Hi @nvpohanh, I'd say the appropriate place to put unit tests should be test_hopper.py or test_hopper_fp8_attention.py.

Since we are adding fa3 as a backend for the first time in BatchDecodeWithPagedKVCacheWrapper, we may need to write new tests analogous to the prefill ones in there.

Quoted docstring context:

    Only needed when ``use_cuda_graph`` is ``True``.
    backend : str
        The implementation backend, could be ``auto``/``fa2`` or ``trtllm-gen``. Defaults to ``auto``.

Should we add fa3 here?
