Conversation

@nvpohanh (Contributor) commented Nov 28, 2025

Purpose

Enable the best-performing Hopper FP8 attention kernels in the FlashInfer backend.

In the prefill stage, the FlashInfer FA3 backend will be used; in the decode stage, the FlashInfer XQA backend (through the trtllm interface) will be used.

However, things are complicated because the two backends have different requirements. The most noticeable difference is that when the kv-cache dtype is FP8, the FA3 backend requires the query to also be in FP8, while the XQA backend requires the query to stay in FP16/BF16. Therefore, we cannot apply query quantization outside of the attention custom op; instead, we must apply it inside the attention custom op's forward().
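A minimal sketch of the dtype handling this implies (hypothetical helper, not the actual op): with an FP8 KV cache, the prefill (FA3) path quantizes the query to FP8 inside the op, while the decode (XQA/trtllm) path keeps the query in FP16/BF16 and leaves q_scale to be passed to the kernel separately.

```python
# Hypothetical sketch: per-stage query dtype handling inside the attention op.
import torch


def query_for_stage(q: torch.Tensor, kv_cache_dtype: torch.dtype,
                    is_prefill: bool, q_scale: float) -> torch.Tensor:
    """Return the query in the dtype the selected kernel expects."""
    if is_prefill and kv_cache_dtype == torch.float8_e4m3fn:
        # FA3 prefill with an FP8 KV cache: the query must be FP8 too,
        # so quantize it here, inside forward(), not in the caller.
        return (q / q_scale).to(torch.float8_e4m3fn)
    # XQA decode (and non-FP8 caches): the query stays FP16/BF16;
    # q_scale is handed to the kernel as a separate argument instead.
    return q


q = torch.randn(8, 16, 128, dtype=torch.bfloat16)
print(query_for_stage(q, torch.float8_e4m3fn, True, 1.0).dtype)   # torch.float8_e4m3fn
print(query_for_stage(q, torch.float8_e4m3fn, False, 1.0).dtype)  # torch.bfloat16
```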

Changes

In vllm/utils/flashinfer.py:

  • Extend all the SM100 checks to SM100/103 (is_sm100f_supported())
  • Rename supports_trtllm_attention() to check_trtllm_attention_support(), which now returns:
    • True when we must use TRTLLM, along with the reason.
    • False when we must not use TRTLLM, along with the reason.
    • None when TRTLLM can be used but is not required.
  • Change use_trtllm_attention() such that:
    • It decides whether to use TRTLLM based on check_trtllm_attention_support() and force_use_trtllm_attention().
    • If the decision does not match force_use_trtllm_attention(), log a warning with the reason; otherwise, log an info message with the reason. (A sketch of this decision logic follows this list.)
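A hedged sketch of how the tri-state check and the final decision could fit together. Parameter names, reason strings, and the fallback default are illustrative, not the exact vllm/utils/flashinfer.py signatures.

```python
# Illustrative sketch only; not the actual vllm/utils/flashinfer.py API.
import logging
from typing import Optional

logger = logging.getLogger(__name__)


def check_trtllm_attention_support(
        is_sm100f: bool,
        needs_trtllm_only_feature: bool) -> tuple[Optional[bool], str]:
    """Tri-state: True = must use TRTLLM, False = must not, None = optional."""
    if not is_sm100f:
        return False, "TRTLLM attention kernels require SM100/SM103"
    if needs_trtllm_only_feature:
        return True, "a requested feature is only available via TRTLLM kernels"
    return None, "TRTLLM attention is supported but not required"


def use_trtllm_attention(is_sm100f: bool, needs_trtllm_only_feature: bool,
                         force_use: Optional[bool]) -> bool:
    must, reason = check_trtllm_attention_support(is_sm100f,
                                                  needs_trtllm_only_feature)
    if must is not None:
        decision = must            # the support check is binding
    elif force_use is not None:
        decision = force_use       # follow the user's preference
    else:
        decision = True            # default when optional (illustrative choice)
    if force_use is not None and decision != force_use:
        logger.warning("Ignoring forced TRTLLM setting (%s): %s", force_use, reason)
    else:
        logger.info("use_trtllm_attention=%s: %s", decision, reason)
    return decision


print(use_trtllm_attention(is_sm100f=False, needs_trtllm_only_feature=False,
                           force_use=True))  # False, with a warning logged
```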

In vllm/v1/attention/backends/flashinfer.py:

  • Move the decision of whether to use TRTLLM from runtime to initialization time.
  • Split q_data_type in the metadata into q_data_type_prefill and q_data_type_decode.
    • Generally, these two will be the same, except on SM90 when an FP8 kv-cache is used.
  • Add a few classmethods to avoid duplicated code between __init__() and get_cudagraph_support() (the latter is a classmethod).
  • Update the args passed to the FlashInfer interface to support out_dtype and to pass in q_scale. (A sketch of the per-stage query dtype selection follows this list.)
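A hedged sketch of the per-stage query dtype split and why it is factored into a classmethod, so that __init__() and the get_cudagraph_support() classmethod can share the logic. Class and method names are illustrative and do not match the real backend.

```python
# Illustrative sketch; names do not match the real backend classes.
import torch


class FlashInferBackendSketch:

    @classmethod
    def _q_data_types(cls, sm_version: int, kv_cache_dtype: torch.dtype,
                      model_dtype: torch.dtype) -> tuple[torch.dtype, torch.dtype]:
        """Return (q_data_type_prefill, q_data_type_decode)."""
        if sm_version == 90 and kv_cache_dtype == torch.float8_e4m3fn:
            # SM90 with an FP8 KV cache: FA3 prefill wants an FP8 query,
            # while XQA decode keeps the query in FP16/BF16.
            return torch.float8_e4m3fn, model_dtype
        # Everywhere else both stages use the same query dtype.
        return model_dtype, model_dtype

    def __init__(self, sm_version: int, kv_cache_dtype: torch.dtype,
                 model_dtype: torch.dtype):
        # Decided once at initialization time, not per forward() call.
        self.q_data_type_prefill, self.q_data_type_decode = self._q_data_types(
            sm_version, kv_cache_dtype, model_dtype)


backend = FlashInferBackendSketch(90, torch.float8_e4m3fn, torch.bfloat16)
print(backend.q_data_type_prefill, backend.q_data_type_decode)
# torch.float8_e4m3fn torch.bfloat16
```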

Requires flashinfer-ai/flashinfer#2148 from FlashInfer.

Test Plan

Run accuracy + performance tests on the cross-product of the following attributes:

  • {H200, B200}
  • {Qwen3-32B-FP8, GPT-OSS-120b}
  • {TP1 conc16, TP1 conc512}
  • {auto kv-cache dtype, fp8 kv-cache}
  • {default attn backend, FLASHINFER attn backend, FLASHINFER attn backend + disable TRTLLM attn}
Script to run: TBA

Test Result

TBA


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@nvpohanh force-pushed the dev-nvpohanh-hopper-fp8-attention branch 4 times, most recently from ec4797a to 2ded1b1 on December 8, 2025 05:09
@nvpohanh force-pushed the dev-nvpohanh-hopper-fp8-attention branch from 2ded1b1 to a691699 on December 8, 2025 08:44
mergify bot commented Dec 10, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @nvpohanh.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label Dec 10, 2025
@nvpohanh force-pushed the dev-nvpohanh-hopper-fp8-attention branch from a691699 to 94b2c00 on December 11, 2025 02:18
mergify bot removed the needs-rebase label Dec 11, 2025
@nvpohanh force-pushed the dev-nvpohanh-hopper-fp8-attention branch 3 times, most recently from 175da71 to bfabf86 on December 11, 2025 06:57
@nvpohanh force-pushed the dev-nvpohanh-hopper-fp8-attention branch from bfabf86 to 47f6d1e on December 11, 2025 07:35