
Conversation

@yzh119
Collaborator

@yzh119 yzh119 commented Nov 19, 2025

📌 Description

This PR refactors the outdated FA3 codebase. More specifically, for page_size > 1, the page offset calculation is now performed inside the kernel, removing the need for a standalone call to block_sparse_indices_to_vector_sparse_offsets, and the kv_offset calculation is optimized with prefetching and shuffling.

This PR also fixes the failing unit tests on Hopper.

However, the FA3 structure in our codebase is still terribly outdated and lacks important features such as IntraWGOverlap and RescaleOBeforeGemm; these will follow in a later PR.
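
As a rough illustration, the address arithmetic that now happens inside the kernel looks like the following Python reference (names mirror the diff snippets quoted later in this thread; the real code uses uint_fastdiv::divmod, cooperative prefetching into a rolling buffer, and __shfl_sync to share offsets between copy threads):

def kv_token_offset(kv_idx, kv_indices, page_size, page_stride, stride_n):
    page_iter, entry_idx = divmod(kv_idx, page_size)      # logical page and slot within it
    page_idx = kv_indices[page_iter]                      # page-table lookup
    return page_idx * page_stride + entry_idx * stride_n  # head/dim offsets are added separately

# e.g. token 37 of a request whose pages are [5, 2, 9], with page_size=16,
# 8 KV heads and head_dim 128 (all values here are made up for illustration):
print(kv_token_offset(37, [5, 2, 9], 16, page_stride=16 * 8 * 128, stride_n=8 * 128))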

🔍 Related Issues

This PR should fix #1647.

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Bug Fixes

    • Enforced consistent K/V page-strides and stride_n for sparse paged attention with runtime validation.
  • New Features

    • Exposed page-interval stride metadata for paged KV caches.
    • Expanded FP8 support: per-head/per-tensor scale propagation and configurable output dtype; added ragged-KV execution path and benchmarks.
  • Refactor

    • Unified sparse attention to a page-based addressing model and simplified workspace/epilogue behavior; removed legacy block-to-vector conversion helpers and related APIs.
  • Tests

    • Added FP8 ragged/paged tests, expanded page-size coverage, and removed obsolete vector-sparse unit test.


@coderabbitai
Contributor

coderabbitai bot commented Nov 19, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds page-level K/V page_stride fields and runtime stride validations for sparse paged KV cache; replaces block→vector-sparse conversion with page-table/manual page addressing; threads FP8 scale tensors and output dtype through JIT/host/kernel paths; removes block-sparse helper APIs and updates tests/benchmarks accordingly.

Changes

Cohort / File(s) Summary
Paged params struct additions
csrc/batch_prefill_sm90_customize_config.jinja, include/flashinfer/attention/hopper/default_params.cuh
Added k_page_stride and v_page_stride (int64_t) to paged params structs to record page-level strides for K and V.
Kernel/host wiring & runtime checks
csrc/batch_prefill_sm90.cu, include/.../prefill_sm90.cuh
Populate and propagate k_page_stride, v_page_stride (and page_size where applicable) into kernel/mainloop args; add runtime validations enforcing stride consistency on sparse paged paths.
Sparse mainloop & FP8 sparse-load refactor
include/flashinfer/attention/hopper/sparse_mainloop.cuh, include/.../quantization/mainloop_sparse_load.cuh
Replace block-sparse gather with page-table/manual page-based K/V addressing: add explicit k/v stride and page fields, implement page-aware load helpers, cooperative prefetching and cp_async-based page loads.
Removal of block→vector conversion (C/CUDA)
csrc/flashinfer_page_binding.cu, csrc/page.cu, include/flashinfer/page.cuh
Deleted declaration/implementation/FFI export for block_sparse_indices_to_vector_sparse_offsets (kernel and host launcher removed).
Python helper & tests removed
flashinfer/page.py, tests/utils/test_block_sparse_indices_to_vector_sparse_offsets.py
Removed Python wrapper block_sparse_indices_to_vector_sparse_offsets and its unit test.
Prefill & sparse Python unification & FP8 wiring
flashinfer/prefill.py, flashinfer/sparse.py
Dropped fa3 vector-sparse branches and buffers; unified on paged buffers; threaded FP8 scale tensors (tensor+scalar split) and added o_data_type propagation; updated reset_workspace_buffer signatures.
FP8 / Ragged KV additions
csrc/batch_prefill_fp8_sm90.cu, csrc/batch_prefill_fp8_ragged_sm90_kernel_inst.jinja, include/.../quantization/prefill_sm90.cuh
Added Ragged KV dispatch declarations/paths and FP8 ragged/paged kernel variants and instantiation wiring; forward-declared dispatched templates.
JIT module arg expansion & variant helpers
flashinfer/jit/attention/modules.py, include/flashinfer/attention/hopper/variants.cuh, include/flashinfer/attention/variant_helper.cuh, flashinfer/jit/attention/variants.py
Expanded JIT/module args to include optional maybe_scale_* tensors and scalar scales; added helpers (get_q/k/v/variant_scale_pv) and v-scale support in variant helpers and sinks.
Removed block-sparse helpers & upcast utilities
include/flashinfer/attention/hopper/block_sparse_gather.cuh
Entire header removed: BlockSparseIndexedGather, CustomStride, make_block_sparse_tensor and related upcast/layout utilities deleted.
Epilogue sync & kernel traits changes
include/flashinfer/attention/hopper/epilogue.cuh, include/flashinfer/attention/hopper/kernel_traits.cuh
Unified epilogue barriers, added write_warp_idx param to CollectiveEpilogue::store, and conditionalized NUM_PRODUCER_THREADS on USE_TMA_LOAD_KV.
Tests & benchmarks
tests/attention/test_batch_prefill_kernels.py, tests/attention/test_hopper.py, tests/attention/test_hopper_fp8_attention.py, benchmarks/bench_hopper_fp8_attention.py
Adjusted LSE buffering and pre-allocated paths, added page_size parametrization, removed kv_indices padding, introduced FP8 quantization helpers/tests and FP8 benchmarks.
Misc small fixes
flashinfer/triton/kernels/cascade.py, tests/attention/test_batch_prefill_kernels.py, scripts/task_jit_run_tests_part2.sh
Widened loop index to 64-bit in merge kernel; tests now capture/reuse LSE buffer; test script adjusted to run a different test module.

Sequence Diagram(s)

sequenceDiagram
    participant Tests as Python tests/bench
    participant Wrapper as Wrapper (flashinfer/prefill.py / sparse.py)
    participant JIT as JIT modules (flashinfer/jit/...)
    participant Host as Host planner (csrc/* + headers)
    participant Kernel as CUDA device mainloop

    Note over Tests,Kernel: New paged/ragged + FP8 flow
    Tests->>Wrapper: run(plan/run args, paged_kv_indptr, paged_kv_indices, o_data_type?, fp8_scales?)
    Wrapper->>JIT: plan/run(..., paged buffers, maybe_scale_*, o_data_type)
    JIT->>Host: module args include k_page_stride, v_page_stride, k_stride_n, v_stride_n, page_size, scale tensors/scalars
    Host->>Kernel: launch kernel with (K_ptr, V_ptr, kv_indices, k_page_stride, v_page_stride, page_size, scales...)
    Kernel->>Kernel: compute page_idx via divmod(page_size) → page offsets
    Kernel->>Kernel: prefetch / cp_async using page strides and stride_n
    Kernel->>Kernel: MMA / epilogue (barrier/sync) → write output (obey o_data_type & scales)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60–90 minutes

  • Areas needing extra attention:
    • include/flashinfer/attention/hopper/sparse_mainloop.cuh and include/.../quantization/mainloop_sparse_load.cuh — page divmod/fastdiv correctness, cp_async prefetching, bounds/predicate logic, barrier ordering.
    • csrc/batch_prefill_sm90.cu — propagation of k_page_stride/v_page_stride and added runtime consistency checks for sparse paths.
    • flashinfer/prefill.py and flashinfer/sparse.py — FP8 scale propagation, o_data_type handling, and removal of vector-sparse branches (ensure no dangling references).
    • Removals: ensure no remaining references to block_sparse_indices_to_vector_sparse_offsets or block_sparse_gather utilities.
    • Tests/benchmarks: validate FP8 quantization helpers, tolerances, and new parameterizations.

Possibly related PRs

Suggested reviewers

  • joker-eph
  • aleozlx
  • cyx-6
  • djmmoss
  • nvmbreughe

Poem

🐇 I hop through pages, counting stride by stride,
K and V march steady, always side by side.
Old vector tricks are gone, page tables hum and sing,
FP8 scales tucked in — now kernels leap and spring.
A rabbit cheers — the prefill bells ring!

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 54.35% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed The PR title clearly describes the main changes: refactoring FA3 codebase and fixing Hopper unit tests, which aligns with the provided description and file changes.
Description check ✅ Passed The PR description covers main objectives, links related issue #1647, and confirms pre-commit checks and testing were completed. However, some sections of the template are incomplete.
Linked Issues check ✅ Passed The PR addresses issue #1647 by removing the block_sparse_indices_to_vector_sparse_offsets function and moving page-offset calculation into the kernel, eliminating the _vector_sparse_indices_buffer sizing issue for page_size>1 cases.
Out of Scope Changes check ✅ Passed All changes are in-scope: FA3 refactoring (removing vector-sparse buffers), kernel updates for page-stride handling, test fixes, and benchmark additions directly support the stated objectives without unrelated modifications.


@gemini-code-assist
Contributor

Summary of Changes

Hello @yzh119, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on enhancing the efficiency and correctness of the Flash Attention v3 (FA3) implementation, particularly for paged Key-Value (KV) caches with page sizes greater than one. By integrating page offset calculations directly into the kernel and optimizing KV offset handling with prefetching and shuffling, the codebase becomes more streamlined and performant. A critical bug affecting Hopper unittests has also been resolved, ensuring robust operation on the target architecture. These changes collectively contribute to a more optimized and reliable sparse attention mechanism.

Highlights

  • FA3 Codebase Refactoring: The pull request refactors the Flash Attention v3 (FA3) codebase, specifically for page_size > 1, by moving the page offset calculation directly inside the kernel, eliminating the need for a separate block_sparse_indices_to_vector_sparse_offsets function.
  • KV Offset Optimization: Optimizations for KV offset calculation have been implemented, incorporating prefetching and shuffling techniques to enhance performance.
  • Hopper Unittest Fix: A previously failing unittest on Hopper architecture has been identified and fixed, ensuring correctness on this specific hardware.
  • Parameter Updates for Paged KV Cache: New parameters, k_page_stride and v_page_stride, have been introduced to store the stride between pages for sparse paged KV cache, along with assertions to ensure K and V have consistent page and stride_n values for efficiency.
  • Removal of Redundant Buffer Allocations: The _vector_sparse_indices_buffer and _vector_sparse_indptr_buffer have been removed from Python-side classes (BatchPrefillWrapper, BatchPrefillWrapperSparse, BatchDecodeWrapperSparse) as their functionality is now handled directly within the kernel.
  • Test Updates and File Management: The test_batch_prefill_kernels.py file has been updated to correctly handle lse (log-sum-exp) buffer allocation when return_lse is enabled. Additionally, a test file related to the removed block_sparse_indices_to_vector_sparse_offsets function has been removed, and another test file has been renamed for better organization.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@yzh119
Collaborator Author

yzh119 commented Nov 19, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !151 has been created, and the CI pipeline #38780168 is currently running. I'll report back once the pipeline job completes.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request is a significant refactoring of the FA3 codebase. It removes the standalone block_sparse_indices_to_vector_sparse_offsets function and moves the page offset calculation directly into the CUDA kernel, which is a great simplification. The changes also include an optimization for kv_offset calculation using prefetching and shuffling, which should improve performance. The code removal across C++, Python, and header files is consistent and clean. I've found a couple of minor areas for code improvement to reduce redundancy, but overall the changes look solid and well-implemented.

int d_idx = get<1>(coord);
int kv_idx = kv_base_idx + kv_offset;

bool guard = kv_idx < kv_len && kv_offset < valid_tile_size;
Contributor

medium

The guard condition can be simplified. The check kv_idx < kv_len is redundant when use_predicate is true, as it's already implied by kv_offset < valid_tile_size. When use_predicate is false, valid_tile_size is CTA_KV, and kv_offset is always less than CTA_KV, so the guard is not needed for non-last tiles anyway. You can simplify this to just kv_offset < valid_tile_size.

        bool guard = kv_offset < valid_tile_size;

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🧹 Nitpick comments (2)
flashinfer/prefill.py (1)

2109-2156: Paged KV run argument rewiring is reasonable; verify trtllm cum_seq_lens_kv semantics

Using _paged_kv_indptr_buf / _paged_kv_indices_buf directly in run_args keeps the Python wrapper aligned with the new paged-KV FFI signature, and _qo_indptr_buf is a natural fit for cum_seq_lens_q. The only subtle point is that _paged_kv_indptr_buf is in units of pages, while trtllm paged attention APIs traditionally expect cum_seq_lens_kv in tokens; if the trtllm-gen backend actually consumes those trailing args as cum-token lengths, it may need cumsum(seq_lens) instead of raw page indptr. Worth double-checking against the current trtllm kernel contract.
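
To make the unit mismatch concrete, here is a small, hypothetical illustration (tensor values invented for this example) of how page-level indptr differs from token-level cumulative lengths:

import torch

page_size = 16
paged_kv_indptr = torch.tensor([0, 3, 7])   # pages per request: 3 and 4
last_page_len = torch.tensor([5, 16])       # tokens used in each request's last page

pages_per_req = paged_kv_indptr[1:] - paged_kv_indptr[:-1]
seq_lens = (pages_per_req - 1) * page_size + last_page_len
cum_seq_lens_kv = torch.cat([torch.zeros(1, dtype=torch.long), seq_lens.cumsum(0)])
print(cum_seq_lens_kv)  # tensor([0, 37, 101]): token counts, not page counts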

tests/attention/test_batch_prefill_kernels.py (1)

147-157: Good coverage of preallocated LSE path; consider also checking LSE values

Using lse_buffer = torch.empty_like(lse) and rerunning with out=o_buffer, lse=lse_buffer now exercises the buffered LSE write path, which should catch the Hopper regression. To fully validate it, you may also want to assert torch.testing.assert_close(lse, lse_buffer, ...) alongside the existing o vs o_buffer check.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b9964cc and 0bbd445.

📒 Files selected for processing (15)
  • csrc/batch_prefill_sm90.cu (1 hunks)
  • csrc/batch_prefill_sm90_customize_config.jinja (1 hunks)
  • csrc/flashinfer_page_binding.cu (0 hunks)
  • csrc/page.cu (0 hunks)
  • flashinfer/page.py (0 hunks)
  • flashinfer/prefill.py (3 hunks)
  • flashinfer/sparse.py (2 hunks)
  • include/flashinfer/attention/hopper/default_params.cuh (1 hunks)
  • include/flashinfer/attention/hopper/prefill_sm90.cuh (1 hunks)
  • include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh (8 hunks)
  • include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh (1 hunks)
  • include/flashinfer/attention/hopper/sparse_mainloop.cuh (8 hunks)
  • include/flashinfer/page.cuh (0 hunks)
  • tests/attention/test_batch_prefill_kernels.py (1 hunks)
  • tests/utils/test_block_sparse_indices_to_vector_sparse_offsets.py (0 hunks)
💤 Files with no reviewable changes (5)
  • csrc/flashinfer_page_binding.cu
  • csrc/page.cu
  • flashinfer/page.py
  • tests/utils/test_block_sparse_indices_to_vector_sparse_offsets.py
  • include/flashinfer/page.cuh
🧰 Additional context used
🧬 Code graph analysis (3)
csrc/batch_prefill_sm90.cu (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (2)
  • TVM_FFI_ICHECK_EQ (167-171)
  • TVM_FFI_ICHECK_EQ (283-286)
tests/attention/test_batch_prefill_kernels.py (1)
flashinfer/prefill.py (6)
  • run (1924-1936)
  • run (1939-1951)
  • run (1953-2166)
  • run (2768-2778)
  • run (2781-2791)
  • run (2793-2939)
flashinfer/prefill.py (1)
flashinfer/page.py (1)
  • get_seq_lens (176-199)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (8)
include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh (1)

102-108: LGTM: Correct handling of separate K and V strides.

This implementation correctly supports different memory layouts for K and V:

  1. Parameterized design: The load_kv_tile lambda (lines 232-267) accepts stride_n and page_stride as parameters rather than hardcoding them
  2. Separate calls: K and V loads pass their respective strides:
    • K: load_kv_tile(k_base_ptr, k_stride_n, k_page_stride, ...) (line 275)
    • V: load_kv_tile(v_base_ptr, v_stride_n, v_page_stride, ...) (line 298)
  3. Flexible addressing: Line 259 computes offsets using the passed-in parameters

This is the correct pattern for page-based sparse loading and avoids the stride assumption issue present in sparse_mainloop.cuh.

Also applies to: 118-124, 232-267

include/flashinfer/attention/hopper/sparse_mainloop.cuh (1)

110-112: Stride equality is already validated on the host side; v_page_stride is intentionally passed through for API consistency.

The v_page_stride parameter, while unused in the non-quantized sparse_mainloop.cuh kernel, is not a bug. An assertion in csrc/batch_prefill_sm90.cu line 235 validates that K and V page strides are equal at runtime, and the comment in the sparse mainloop (line 281) explicitly documents this assumption. The prefetch_kv_offset lambda correctly reuses the same offset computation for both K and V loads.

The parameter exists for API consistency with the quantized variant (mainloop_sparse_load.cuh), which does use v_page_stride separately. If API unification across quantized and non-quantized paths is intentional, no action is needed.

include/flashinfer/attention/hopper/default_params.cuh (1)

157-160: k_page_stride / v_page_stride fields look consistent

Adding explicit page-stride fields after nnz_qo matches the other Hopper paged params structs and keeps types/ordering coherent with the new sparse mainloop arguments.

include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh (1)

337-344: FP8 sparse mainloop argument rewiring looks correct

Switching K/V from get_gmem_layout to explicit {k_stride_n, k_page_stride, v_stride_n, v_page_stride, kv_indices, page_size} matches the updated sparse mainloop API and keeps Q/O layout handling unchanged.

flashinfer/prefill.py (1)

36-36: Importing get_seq_lens is appropriate

This import matches the later use of get_seq_lens in BatchPrefillWithPagedKVCacheWrapper.plan to derive KV sequence lengths from paged metadata.

include/flashinfer/attention/hopper/prefill_sm90.cuh (1)

382-386: Sparse prefill mainloop now correctly receives KV indices and paging metadata

Passing kv_indices, window_left, k_page_stride, v_page_stride, and page_size into SparseCollectiveMainloop::to_underlying_arguments lines up with the new paged sparse mainloop contract and keeps Q/K/V layouts unchanged.

csrc/batch_prefill_sm90_customize_config.jinja (1)

107-111: PagedParams gains explicit K/V page strides in the right place

Adding k_page_stride / v_page_stride after nnz_qo keeps this JIT-generated PagedParams struct aligned with the Hopper default params and with how batch_prefill_sm90.cu now fills these fields from paged_{k,v}_cache.stride(0).

csrc/batch_prefill_sm90.cu (1)

221-238: Page-stride wiring and K/V stride consistency checks make sense

Recording k_page_stride / v_page_stride from stride(0) in both layouts and then asserting that K/V share the same page stride and stride_n is a good guardrail for the sparse paged mainloop; it will surface mis-laid-out KV caches early with clear error messages rather than letting the kernel access mismatched layouts.
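
The invariant being guarded can be sketched in Python as follows (the actual check lives in C++ via TVM_FFI_ICHECK_EQ; the layout handling below is an assumption for illustration only):

import torch

def kv_strides(cache: torch.Tensor, layout: str):
    # NHD: [num_pages, page_size, num_heads, head_dim]
    # HND: [num_pages, num_heads, page_size, head_dim]
    token_dim = 1 if layout == "NHD" else 2
    return cache.stride(0), cache.stride(token_dim)  # (page_stride, stride_n)

def check_kv_layout(k_cache, v_cache, layout="NHD"):
    k_page_stride, k_stride_n = kv_strides(k_cache, layout)
    v_page_stride, v_stride_n = kv_strides(v_cache, layout)
    # The sparse mainloop reuses one prefetched offset for both K and V,
    # so their layouts must match.
    assert k_page_stride == v_page_stride, "K/V page strides must match"
    assert k_stride_n == v_stride_n, "K/V per-token strides must match"

k = torch.zeros(32, 16, 8, 128)  # 32 pages, page_size 16, 8 KV heads, head_dim 128
check_kv_layout(k, torch.zeros_like(k))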

Comment on lines 268 to 300
int64_t my_kv_offset[2]; // Rolling buffer: page_idx * page_stride + entry_idx * stride_n

// Group organization based on partition strategy
constexpr int NUM_KV_PER_ITER = decltype(size<1>(tKcK))::value; // e.g., 12
constexpr int KV_STRIDE = CTA_KV / NUM_KV_PER_ITER; // 96/12 = 8
constexpr int NUM_GROUPS = KV_STRIDE; // 8 groups (one per lane)
constexpr int THREADS_PER_GROUP = NUM_COPY_THREADS / NUM_GROUPS; // 128/8 = 16
constexpr int NUM_ITERS_PER_GROUP = NUM_KV_PER_ITER; // 12 iterations per group

int group_id = thread_idx / THREADS_PER_GROUP; // 0-7
int thread_in_group = thread_idx % THREADS_PER_GROUP; // 0-15

// Prefetch: compute page_idx * page_stride + entry_idx * stride_n
// NOTE: Assumes K and V have same strides (asserted on host side)
auto prefetch_kv_offset = [&](int kv_tile_idx, bool use_predicate) {
int kv_base_idx = kv_tile_idx * CTA_KV;
int buf_idx = kv_tile_idx % 2;

int kv_idx_read = kv_base_idx + group_id + thread_in_group * KV_STRIDE;
bool valid_read =
thread_in_group < NUM_ITERS_PER_GROUP && (!use_predicate || kv_idx_read < kv_len);

if (valid_read) {
// Use divmod to find page and offset within page
uint32_t page_iter, entry_idx;
mainloop_params.page_size.divmod(kv_idx_read, page_iter, entry_idx);
IdType page_idx = kv_indices_ptr[page_iter];
// Pre-compute: page_idx * page_stride + entry_idx * stride_n
my_kv_offset[buf_idx] = page_idx * k_page_stride + entry_idx * k_stride_n;
} else {
my_kv_offset[buf_idx] = 0;
}
};
Contributor

⚠️ Potential issue | 🟠 Major

Prefetch logic assumes K and V have identical strides.

The prefetch_kv_offset lambda computes my_kv_offset using only K strides (k_page_stride and k_stride_n on line 296), but this offset is later reused for both K and V loads in load_kv_with_gather. This hardcodes the assumption that K and V have identical memory layouts.

Compare with mainloop_sparse_load.cuh (lines 232-267), which correctly uses separate stride parameters in its load_kv_tile lambda, allowing K and V to have different layouts.

Consider refactoring to either:

  • Option 1: Compute separate offsets for K and V if they can differ
  • Option 2: Use a single set of stride parameters if layouts must be identical
🤖 Prompt for AI Agents
In include/flashinfer/attention/hopper/sparse_mainloop.cuh around lines 268-300,
the prefetch lambda computes my_kv_offset using only K strides but the same
offset is later used for both K and V loads, incorrectly assuming identical K/V
layouts; fix by computing distinct offsets for K and V (or enforce
identical-layout at compile/runtime). Update the lambda to accept/use separate
stride parameters (e.g., k_page_stride/k_stride_n and v_page_stride/v_stride_n)
and write into two rolling buffers (my_kv_offset_k[2] and my_kv_offset_v[2]) so
load_kv_with_gather can use the correct offset for each tensor, or alternatively
add a clear static_assert/runtime check and comment that K and V must share
strides and keep single offset.

Comment on lines 303 to 335
auto load_kv_with_gather = [&](auto&& tXsX, auto&& tXcX, DTypeKV* base_ptr, int kv_tile_idx,
int stage_idx, bool use_predicate) {
using Vec = AlignmentTypeKV;
constexpr int VecSize = sizeof(Vec) / sizeof(DTypeKV);

int kv_base_idx = kv_tile_idx * CTA_KV;
int buf_idx = kv_tile_idx % 2;

auto dst = recast<Vec>(flatten(tXsX(_, _, _, stage_idx)));
auto c = flatten(tXcX(_, _, _, kv_tile_idx));

constexpr unsigned FULL_MASK = 0xffffffff;

// Load using FA3-style shuffle with pre-computed offsets
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(dst); ++i) {
auto coord = c(VecSize * i);
int kv_offset = get<0>(coord);
int d_idx = get<1>(coord);
int kv_idx = kv_base_idx + kv_offset;
bool guard = !use_predicate || kv_idx < kv_len;

// Shuffle the pre-computed offset (page_idx * page_stride + entry_idx * stride_n)
int src_thread = group_id * THREADS_PER_GROUP + kv_offset / KV_STRIDE;
int64_t base_offset = __shfl_sync(FULL_MASK, my_kv_offset[buf_idx], src_thread);

// Final address: base_ptr + base_offset + d_idx
// where base_offset = page_idx * page_stride + entry_idx * stride_n
Vec const* src_ptr = reinterpret_cast<Vec const*>(base_ptr + base_offset + d_idx);
cutlass::arch::cp_async_zfill<sizeof(Vec), cutlass::arch::CacheOperation::Global>(
&dst(i), src_ptr, guard);
}
};
Contributor

⚠️ Potential issue | 🟠 Major

load_kv_with_gather reuses K offsets for V loads.

The load_kv_with_gather helper shuffles and reuses my_kv_offset (computed using K strides in prefetch_kv_offset) for both K and V loads:

  • Line 341: load_kv_with_gather(..., K_ptr_base, ...)
  • Line 367: load_kv_with_gather(..., V_ptr_base, ...)

This shuffle-based optimization is effective for performance but requires K and V to have identical page strides and per-token strides. If this constraint is enforced elsewhere, add an assertion or comment clarifying why separate v_page_stride parameters exist but are unused.

For reference, mainloop_sparse_load.cuh avoids this issue by passing stride parameters explicitly to its load_kv_tile helper.

🤖 Prompt for AI Agents
In include/flashinfer/attention/hopper/sparse_mainloop.cuh around lines 303 to
335, load_kv_with_gather reuses the K offsets (my_kv_offset) for V loads which
is only valid if K and V have identical page and per-token strides; update the
code to either (A) assert at runtime (or static_assert / debug check) that
v_page_stride == k_page_stride and per-token strides match and add a clear
comment explaining why v_page_stride parameter is unused, or (B) change the
caller/implementation so V uses its own computed offsets (compute a separate
my_v_offset in prefetch_v_offset and shuffle that for V loads) so K and V can
have different strides—pick one approach and apply consistently (add the
assertion/comment if you choose A; implement separate offset computation and use
it in the shuffle and cp_async_zfill calls if you choose B).

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🧹 Nitpick comments (2)
include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh (1)

476-591: LGTM!

The new ragged KV dispatch functions are correctly implemented:

  • Uses TMA load for contiguous ragged memory (consistent with single prefill)
  • Proper layout construction and scheduler setup
  • Head dimension dispatch covers all supported values (64, 128, 256)

Minor: The comment on line 497 ("NOTE(Zihao): nnz was useless here, we can just pass 0") reads as a debug/TODO note. Consider removing or rephrasing if the implementation is finalized.

flashinfer/prefill.py (1)

416-472: Consider aligning FP8 detection with tensor dtype check.

The FP8 detection here uses scale_q is not None (line 421), while other places in the codebase use is_float8(q). This could lead to inconsistency if:

  1. FP8 input is provided without scale tensors
  2. Non-FP8 input is accidentally provided with scale tensors

Consider using is_float8(q) for consistency, or add a validation that ensures FP8 inputs always have scale tensors.

-        # Check if FP8 by presence of scale tensors
-        is_fp8 = scale_q is not None
+        # Check if FP8 by tensor dtype
+        is_fp8 = is_float8(q)
+        if is_fp8 and scale_q is None:
+            raise ValueError("FP8 inputs require scale_q, scale_k, scale_v tensors")
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7886b7d and 876c386.

📒 Files selected for processing (6)
  • csrc/batch_prefill_fp8_ragged_sm90_kernel_inst.jinja (1 hunks)
  • csrc/batch_prefill_fp8_sm90.cu (3 hunks)
  • flashinfer/prefill.py (21 hunks)
  • include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh (8 hunks)
  • include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh (2 hunks)
  • tests/attention/test_hopper_fp8_attention.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/prefill.py (2)
flashinfer/page.py (1)
  • get_seq_lens (176-199)
flashinfer/utils.py (3)
  • canonicalize_torch_dtype (240-248)
  • check_shape_dtype_device (519-537)
  • is_float8 (157-158)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (17)
tests/attention/test_hopper_fp8_attention.py (3)

186-280: LGTM!

The test function is well-structured, following the established pattern for FP8 testing. It correctly:

  • Creates variable-length sequences for batch prefill
  • Generates FP16 reference output
  • Quantizes inputs to FP8
  • Compares MSE between FP16 and FP8 paths

283-403: LGTM!

The paged KV cache test is correctly implemented:

  • Proper page allocation and indptr/indices construction
  • Appropriate reshape-quantize-reshape pattern for paged KV tensors
  • Consistent with the ragged test structure

406-426: LGTM!

The __main__ block updates provide convenient local test execution with a reasonable subset of parameters for quick validation.

include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh (1)

337-344: LGTM!

The parameter changes correctly pass stride and page-size information directly to the sparse mainloop, aligning with the PR's objective of moving page offset calculation into the kernel.

csrc/batch_prefill_fp8_sm90.cu (2)

86-173: LGTM!

The BatchPrefillWithRaggedKVCacheSM90Run implementation is well-structured:

  • Proper plan info initialization and LSE validation
  • Correct layout-aware stride handling for NHD/HND
  • Appropriate static assertions for FP8 constraints
  • Consistent error handling pattern

231-243: LGTM!

The page stride handling is correct. Using stride(0) consistently retrieves the stride between pages regardless of the internal layout (NHD or HND), which is the intended behavior for sparse paged KV cache addressing.

flashinfer/prefill.py (5)

1566-1567: LGTM!

The o_data_type parameter addition is well-implemented with proper canonicalization and caching for use in the run method.


2092-2102: LGTM!

The output allocation correctly uses the cached output data type with a safe fallback to q.dtype for backward compatibility.


2950-2959: LGTM!

The FP8 handling correctly bypasses the FP16 conversion for FA3 backend while maintaining backward compatibility with FA2 backend (which still shows a deprecation warning and converts to FP16).


3001-3003: LGTM!

The FP8 scale tensor extension follows the established pattern from the paged path.


2170-2189: I'll verify the FP8 scale tensor extraction from *args by examining how callers pass FP8 scale tensors to the run() method and checking the documentation.

include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh (6)

102-111: LGTM! Clear page-based addressing structure.

The new stride and page fields provide a clean interface for page-based K/V tensor addressing, replacing the previous layout-based approach.


118-127: LGTM! Efficient use of fast division.

Using uint_fastdiv for page_size enables efficient divmod operations in the hot path.


134-137: LGTM! Parameter forwarding is correct.

All new stride and page parameters are properly forwarded from Arguments to Params.


212-231: LGTM! Clean setup for page-based loading.

The coordinate tensor partitioning and parameter extraction properly prepare for the manual K/V loading path.


273-372: LGTM! Proper tile loading sequence and synchronization.

The tile loading pattern correctly applies predication only to the last tile while intermediate tiles load without bounds checking. Pipeline synchronization, V transpose coordination, and barrier usage are all properly structured.


232-266: I encountered a repository clone failure and cannot access the codebase to verify the bounds checking concern. However, I can provide a rewritten review comment based on the code snippet provided and sound software engineering analysis:

Verify bounds check for page table access before accessing kv_indices_ptr.

The lambda correctly implements page-based addressing for K/V tiles, but the code at line 257 lacks validation that page_iter is within the bounds of kv_indices_ptr before array access. When page_iter is computed via divmod(kv_idx, page_size), the result could potentially exceed the allocated size of the page table if:

  • The page table was sized based on an incorrect upper bound for KV entries
  • Concurrent modifications affect the array size
  • Off-by-one errors exist in the page table allocation logic

To resolve this:

  1. Add an assertion or bounds check: CUTE_ASSERT(page_iter < num_pages) before line 257, or
  2. Verify that the page table allocation logic guarantees sufficient capacity for all possible page_iter values derived from valid kv_idx values
  3. Document the invariant that kv_indices_ptr must accommodate ceil_div(kv_len, page_size) entries (a quick sanity check of this invariant follows below)

Regarding the guard condition at line 252 (!use_predicate || kv_idx < kv_len): This appears correct for intermediate vs. final tiles, but validate that kv_len accurately reflects the maximum valid index.
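
As a quick, self-contained sanity check of the ceil_div invariant from item 3 above (pure Python, values invented):

import math

kv_len, page_size = 1000, 16
num_pages_needed = math.ceil(kv_len / page_size)  # 63
# every valid kv_idx in [0, kv_len) maps to page_iter = kv_idx // page_size < num_pages_needed
assert max(i // page_size for i in range(kv_len)) < num_pages_needed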

}
};

int valid_last_kv_tile_size = std::min<int>(kv_len - kv_tile_idx * CTA_KV, CTA_KV);
Contributor

⚠️ Potential issue | 🟡 Minor

Remove or utilize the unused variable.

The variable valid_last_kv_tile_size is computed but never used in the subsequent code. If the intention is to optimize the guard condition in load_kv_tile (as suggested in the past review comment), this value should be passed to the lambda. Otherwise, this line should be removed.

Apply this diff if the variable is not needed:

-    int valid_last_kv_tile_size = std::min<int>(kv_len - kv_tile_idx * CTA_KV, CTA_KV);
-

Or, if you want to use it for optimized bounds checking, update the load_kv_tile signature to accept it:

-    auto load_kv_tile = [&](DTypeKV* base_ptr, int64_t stride_n, int64_t page_stride, auto& tXsX,
-                            int tile_idx, int pipe_idx, bool use_predicate) {
+    auto load_kv_tile = [&](DTypeKV* base_ptr, int64_t stride_n, int64_t page_stride, auto& tXsX,
+                            int tile_idx, int pipe_idx, int valid_tile_size) {

And update the guard condition accordingly.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
int valid_last_kv_tile_size = std::min<int>(kv_len - kv_tile_idx * CTA_KV, CTA_KV);
🤖 Prompt for AI Agents
In include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh
around line 268, the local variable valid_last_kv_tile_size is computed but
never used; either remove this line if not needed, or update load_kv_tile to
accept an extra parameter for the valid tile size and pass
valid_last_kv_tile_size into the lambda so it can replace the existing generic
guard logic. If choosing the second option, change the load_kv_tile signature to
take an int valid_tile_size, update all call sites, and use that value inside
the lambda for optimized bounds checking; otherwise simply delete the unused
variable declaration.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh (1)

318-425: Critical: V loads reuse K's prefetch offsets with potentially different strides.

The V loading at lines 345, 363, and 390 reuses prefetch offsets computed for K (with k_stride_n and k_page_stride), but V should use v_stride_n and v_page_stride. This is evident from line 411, which explicitly prefetches V with v_stride_n and v_page_stride.

If K and V have different strides or page strides, V will be loaded from incorrect addresses, causing data corruption.

The API explicitly provides separate stride parameters for K and V (Arguments and Params structs), suggesting they can differ. Either:

  1. Add prefetch calls for V before each V load (lines 345, 363, 390) using v_stride_n and v_page_stride, OR
  2. Document and assert that k_stride_n == v_stride_n and k_page_stride == v_page_stride must hold

Apply this pattern to fix the V loads:

     if (kv_tile_idx == swa_begin_kv_tile_idx) {
-      // first tile is the last tile, reuse kv_tile_idx prefetch for V
+      // first tile is the last tile, prefetch for V
+      prefetch_kv_offset(kv_tile_idx, v_stride_n, v_page_stride, true);
       pipeline_v.producer_acquire(smem_pipe_write);
       load_kv_with_prefetch(v_base_ptr, tVsV, kv_tile_idx, smem_pipe_write.index(), true);
     } else {
       // load second last k-tile and last v-tile
       // Prefetch for next K tile (kv_tile_idx - 1)
       prefetch_kv_offset(kv_tile_idx - 1, k_stride_n, k_page_stride, false);
 
-      // Load V using prefetch from last K load (kv_tile_idx)
+      // Prefetch and load V for kv_tile_idx
+      prefetch_kv_offset(kv_tile_idx, v_stride_n, v_page_stride, true);
       pipeline_v.producer_acquire(smem_pipe_write);
       load_kv_with_prefetch(v_base_ptr, tVsV, kv_tile_idx, smem_pipe_write.index(), true);
       for (; kv_tile_idx > swa_begin_kv_tile_idx; --kv_tile_idx) {
         // Prefetch for next K tile
         prefetch_kv_offset(kv_tile_idx - 1, k_stride_n, k_page_stride, false);
 
-        // Load V using prefetch from previous K prefetch
+        // Prefetch and load V for kv_tile_idx
+        prefetch_kv_offset(kv_tile_idx, v_stride_n, v_page_stride, false);
         pipeline_v.producer_acquire(smem_pipe_write);
         load_kv_with_prefetch(v_base_ptr, tVsV, kv_tile_idx, smem_pipe_write.index(), false);
♻️ Duplicate comments (1)
include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh (1)

316-316: Remove the unused variable.

As noted in the previous review, valid_last_kv_tile_size is computed but never used.

🧹 Nitpick comments (4)
benchmarks/bench_hopper_fp8_attention.py (2)

216-216: Document or validate page_size divisibility assumption.

Line 216 assumes seq_len is perfectly divisible by page_size. While the current test cases satisfy this (seq_len ∈ {1024, 2048, 4096, 8192} with page_size=16), the function might be called with other parameters in the future.

Consider adding a validation check:

+    assert seq_len % page_size == 0, f"seq_len ({seq_len}) must be divisible by page_size ({page_size})"
     num_pages = batch_size * seq_len // page_size

250-251: Consider making workspace buffer size configurable.

The 256MB workspace buffer is hardcoded for both FP16 and FP8 wrappers. While sufficient for current benchmark sizes, this might be inadequate for larger workloads or future test expansions.

Consider either:

  1. Making workspace size a parameter with a reasonable default
  2. Adding a comment documenting the size assumption
  3. Having the wrappers handle workspace allocation internally if supported

This is a minor point since the current sizes work for the benchmarks being run.

Also applies to: 268-269

include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh (2)

331-344: Paged KV mainloop param wiring looks consistent

The new argument list (k/v strides, page stride, kv_indices, page_size) lines up with a paged/sparse K/V mainloop and matches the scheduler/block_coord usage in this kernel. From this file’s perspective the wiring looks correct; no blocking issues.

If Params::page_size is not already a 32‑bit type, consider documenting or static‑asserting the expected range to make the uint32_t cast here self‑evident to future readers.


477-550: Ragged KV kernel‑traits dispatch wiring looks correct; stale comment

The ragged‑KV kernel‑traits dispatch correctly switches to FP8CollectiveMainloop and reuses the BatchPrefill schedulers/arguments in the same way as the paged path, with Q/K/V layouts built via get_gmem_layout, so the host→device params plumbing looks coherent.

The comment on Line 499 saying “nnz was useless here, we can just pass 0” now contradicts the actual params.nnz_kv argument; consider updating or removing this note to avoid confusion about whether the first dimension of the K/V layout is semantically meaningful.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 876c386 and f89ae64.

📒 Files selected for processing (3)
  • benchmarks/bench_hopper_fp8_attention.py (4 hunks)
  • include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh (7 hunks)
  • include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
benchmarks/bench_hopper_fp8_attention.py (3)
flashinfer/testing/utils.py (2)
  • bench_gpu_time (985-1046)
  • attention_tflops_per_sec_with_actual_seq_lens (421-454)
benchmarks/bench_block_sparse_attention.py (1)
  • flops (125-134)
benchmarks/bench_hopper_attention.py (3)
  • flops (46-55)
  • flops (107-116)
  • flops (187-196)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (16)
benchmarks/bench_hopper_fp8_attention.py (7)

27-38: LGTM: Correct per-head symmetric quantization implementation.

The quantization logic correctly handles both FP8 formats with appropriate ranges, computes per-head scales by taking max over dimensions (0, 2), and includes defensive clamping to prevent division by zero.
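
For reference, per-head symmetric quantization along these lines can be sketched as follows (assuming a [seq_len, num_heads, head_dim] layout; the benchmark's helper may differ in details):

import torch

def quantize_per_head_fp8(x: torch.Tensor, dtype=torch.float8_e4m3fn):
    fp8_max = torch.finfo(dtype).max                  # 448 for e4m3fn, 57344 for e5m2
    amax = x.abs().amax(dim=(0, 2)).clamp(min=1e-6)   # one amax per head; clamp avoids div-by-zero
    scale = (amax / fp8_max).float()                  # dequant scale: x ≈ x_fp8 * scale
    x_fp8 = (x.float() / scale.view(1, -1, 1)).clamp(-fp8_max, fp8_max).to(dtype)
    return x_fp8, scale

q = torch.randn(1024, 8, 128, dtype=torch.float16)
q_fp8, s_q = quantize_per_head_fp8(q)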


41-108: LGTM: Well-structured FP8 single prefill benchmark.

The benchmark correctly creates FP16 baseline tensors, quantizes them to FP8 with per-head scales, measures both paths using median GPU time, and reports meaningful performance metrics with speedup calculations.


111-201: LGTM: Correct batch ragged prefill benchmark implementation.

The ragged batch benchmark properly constructs indptr arrays for batch boundaries, configures wrappers with appropriate data types, and correctly passes quantization scales to the FP8 execution path.


233-238: LGTM: Correct paged KV quantization strategy.

Flattening the paged KV cache for quantization and then reshaping back is the right approach to maintain per-head quantization semantics across all pages while preserving the paged memory layout.


240-247: LGTM: Correct indptr and page table setup.

The indptr arrays and page indices are correctly constructed (a sketch of this setup follows the list):

  • qo_indptr marks query batch boundaries (every seq_len tokens)
  • kv_indptr marks page batch boundaries (every seq_len // page_size pages)
  • kv_indices provides sequential page mapping
  • last_page_len assumes full pages, which is appropriate for uniform benchmark workloads
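
A minimal sketch of this uniform setup (names follow the benchmark's conventions; concrete sizes are assumptions):

import torch

batch_size, seq_len, page_size = 4, 1024, 16
pages_per_seq = seq_len // page_size

qo_indptr = torch.arange(0, (batch_size + 1) * seq_len, seq_len, dtype=torch.int32)
kv_indptr = torch.arange(0, (batch_size + 1) * pages_per_seq, pages_per_seq, dtype=torch.int32)
kv_indices = torch.arange(batch_size * pages_per_seq, dtype=torch.int32)   # sequential page mapping
last_page_len = torch.full((batch_size,), page_size, dtype=torch.int32)    # all pages full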

330-336: Clarify status of skipped single prefill benchmarks.

The single prefill benchmarks are commented out due to "compilation issues." Given the PR objectives mention fixing a failing Hopper unittest, is this related?

Please clarify:

  1. Are these compilation issues expected to be resolved in this PR or a follow-up?
  2. Should this be tracked with a TODO or issue reference?
  3. Is this related to the unittest fixes mentioned in the PR description?

342-356: LGTM: Comprehensive benchmark coverage.

The test configurations provide good coverage across different:

  • Head dimensions (128, 256)
  • Batch sizes (16-128)
  • Sequence lengths (1024-8192)
  • Both ragged and paged KV cache layouts

The parameter combinations maintain roughly constant total token counts, which is sensible for comparing performance across configurations.

include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh (7)

102-108: LGTM: Page-based KV cache parameters added.

The addition of separate stride and page_stride parameters for K and V tensors, along with page_size, correctly supports the refactored page-based KV loading scheme.


118-124: LGTM: Efficient fastdiv used for page_size.

Using uint_fastdiv for page_size enables efficient divmod operations in the kernel hot path.


134-137: LGTM: Parameter forwarding is correct.

All new page-based parameters are correctly forwarded from Arguments to Params.


212-220: LGTM: Manual K/V loading setup is complete.

All required parameters for page-based K/V loading are correctly extracted and prepared.


232-259: LGTM: Well-documented thread organization for FA3-style prefetch.

The rolling buffer prefetch scheme and detailed thread organization comments are helpful for understanding this complex optimization. The NUM_KV_PER_ITER calculations appear correct.


260-280: LGTM: Page-based offset prefetch is correctly implemented.

The divmod-based page addressing and rolling buffer management are correctly implemented. The offset computation properly combines page-level and entry-level addressing.


282-314: LGTM: Shuffle-based offset loading is correctly implemented.

The shuffle-based offset sharing and cp_async_zfill with guard correctly implement the FA3-style optimized loading pattern.

include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh (2)

462-476: CTA_KV=64 for HEAD_DIM=256 paged path seems reasonable; please benchmark

Reducing CTA_KV from 128→64 for the sparse paged path (with the accompanying comment about 64×64 FP8 transpose minimum) is a plausible trade‑off to cut page‑table lookups; launch shape and error handling remain consistent with other HEAD_DIM branches.

Please sanity‑check perf/occupancy for HEAD_DIM=256 on Hopper (especially long‑seq FA3 workloads) to ensure this smaller CTA_KV doesn’t introduce regressions compared to the previous configuration.


552-592: New BatchFP8PrefillWithRaggedKVCacheDispatched entrypoint matches existing patterns

This wrapper mirrors the single‑batch FP8 dispatch: HEAD_DIM specializations, USE_TMA_LOAD_KV=true for ragged K/V, and the same error‑reporting pattern as the paged variant. The trait choices (CTA_Q/CTA_KV/NUM_STAGES) are consistent with the non‑ragged FP8 paths.

Once the ragged‑KV tests are in place, it’d be good to run them for all HEAD_DIM (64/128/256) with large nnz_qo/nnz_kv configurations comparable to issue #1647 to confirm this new batch entrypoint behaves as expected on Hopper.

@yzh119 yzh119 requested a review from yongwww as a code owner November 27, 2025 08:17
@yzh119
Collaborator Author

yzh119 commented Nov 27, 2025

/bot run

@flashinfer-bot
Collaborator

[CANCELING] Pipeline #39439259: canceled

enable_pdl = device_support_pdl(q.device)
k_cache, v_cache = _unpack_paged_kv_cache(paged_kv_cache, self._kv_layout)
_check_cached_qkv_data_type(
q, k_cache, self._cached_q_data_type, self._cached_kv_data_type
Contributor

We should add docs for q_scale

fp8_scale_q = None
fp8_scale_k = None
fp8_scale_v = None
if is_float8(q) and len(args) >= 3:
Contributor

I am not a fan of making these implicit positional args. Instead, I think we should document them clearly and make them explicit.

For example, we can make k_scale: Optional[Union[torch.Tensor, float]], or have separate keyword args like k_scale_device. Also, we should document clearly that these scale tensors must be per-head; otherwise, an illegal memory access will occur.

None, # scale_q, not supported yet
None, # scale_k
None, # scale_v
fp8_scale_q,
Contributor

In the FA3 FP8 kernels, I suggest that we change scale_q/scale_k/scale_v to maybe_scale_q/maybe_scale_k/maybe_scale_v

See a8d9e6a as an example of how to do this.

o_data_type=o_dtype,
causal=causal,
)
o_fp8 = wrapper_fp8.run(q_fp8, k_fp8, v_fp8, s_q, s_k, s_v)
Contributor

We should have tests that cover both types of scaling factors (see the sketch after this list):

  1. host-side per-tensor scales
  2. device-side per-head scale tensors
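
A possible parametrization covering both flavors (helper and wrapper names are placeholders; the real fixtures may differ):

import pytest
import torch

@pytest.mark.parametrize("scale_mode", ["host_per_tensor", "device_per_head"])
def test_fp8_prefill_scale_modes(scale_mode):
    num_qo_heads = num_kv_heads = 8
    if scale_mode == "host_per_tensor":
        s_q = s_k = s_v = 1.0 / 448.0  # plain Python floats
    else:
        s_q = torch.full((num_qo_heads,), 1.0 / 448.0, device="cuda")
        s_k = torch.full((num_kv_heads,), 1.0 / 448.0, device="cuda")
        s_v = torch.full((num_kv_heads,), 1.0 / 448.0, device="cuda")
    # ...quantize q/k/v, plan the wrapper, then run and compare against the FP16 reference:
    # o_fp8 = wrapper_fp8.run(q_fp8, k_fp8, v_fp8, s_q, s_k, s_v)
    # torch.testing.assert_close(o_fp8.float(), o_ref, rtol=1e-1, atol=1e-1)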

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (2)
include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh (2)

227-234: Verify head offset calculation is correct for all KV cache layouts.

A past reviewer flagged the base pointer calculation as an accuracy bug. The current implementation uses kv_head_idx * k_stride_h for the head offset. Please confirm that k_stride_h correctly represents the stride between heads in elements for all supported KV cache layouts (NHD vs HND), and that this matches the stride values passed from the Python/dispatch layer.

Run the following script to verify how stride_h is computed and passed to this kernel:

#!/bin/bash
# Search for where k_stride_h is computed and passed to FP8SparseCollectiveMainloop
rg -n -C3 'k_stride_h|v_stride_h' --type=cpp --type=cu

331-331: Remove unused variable.

The variable valid_last_kv_tile_size is computed but never used in the subsequent code. Remove this line to eliminate dead code.

Apply this diff:

-    int valid_last_kv_tile_size = std::min<int>(kv_len - kv_tile_idx * CTA_KV, CTA_KV);
-
🧹 Nitpick comments (4)
flashinfer/prefill.py (2)

2100-2108: Inconsistent _cached_o_data_type access pattern.

Using getattr(self, "_cached_o_data_type", None) suggests _cached_o_data_type may not always be set, but plan() always sets it (lines 1711-1713, 2645-2647). Consider using self._cached_o_data_type directly for consistency with the rest of the codebase, or document why the defensive access is needed.

-            # Use cached output data type if available (for FP8 attention with FP16 output)
-            out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype
+            # Use cached output data type (for FP8 attention with FP16 output)
+            out_dtype = self._cached_o_data_type
             out = torch.empty(
                 q.shape[:-1] + v_cache.shape[-1:], dtype=out_dtype, device=q.device
             )
         else:
-            out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype
+            out_dtype = self._cached_o_data_type

2177-2196: FP8 scale tensor extraction from *args is fragile.

The code assumes args[0:3] are scale tensors when is_float8(q) is true, but there's no validation that these are actually tensor types or have the expected shapes. A mismatch could cause silent incorrect behavior or cryptic errors.

Consider adding explicit keyword arguments for FP8 scales in the paged run() method signature (similar to how q_scale, k_scale, v_scale are already present) instead of extracting from *args. This would make the API explicit and type-safe.
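
One possible shape for such an explicit API (illustrative only; parameter names and the surrounding wrapper signature are assumptions, not the actual code):

from typing import Optional

import torch

def run_fp8(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    *,
    scale_q: Optional[torch.Tensor] = None,  # per-head dequant scales, shape [num_qo_heads]
    scale_k: Optional[torch.Tensor] = None,  # shape [num_kv_heads]
    scale_v: Optional[torch.Tensor] = None,  # shape [num_kv_heads]
) -> None:
    is_fp8 = q.dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
    if is_fp8 and any(s is None for s in (scale_q, scale_k, scale_v)):
        raise ValueError("FP8 inputs require per-head scale_q/scale_k/scale_v tensors")
    if not is_fp8 and scale_q is not None:
        raise ValueError("scale tensors were provided but q is not an FP8 tensor")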

csrc/batch_prefill_fp8_sm90.cu (1)

103-104: Inconsistent device guard usage.

BatchPrefillWithRaggedKVCacheSM90Run uses cudaSetDevice() directly (line 103), while BatchPrefillWithPagedKVCacheSM90Run uses ffi::CUDADeviceGuard (line 206). The RAII pattern with CUDADeviceGuard is safer as it restores the device on scope exit.

-  cudaSetDevice(float_workspace_buffer.device().device_id);
+  ffi::CUDADeviceGuard device_guard(float_workspace_buffer.device().device_id);
tests/attention/test_hopper_fp8_attention.py (1)

347-388: Remove unused head_dim parameter.

The head_dim parameter is not used in the function body since the dimension information is already contained in the shape tuple. Consider removing this parameter to simplify the function signature.

Apply this diff to the function signature and update all call sites:

 def create_per_head_varying_kv(
     shape: Tuple[int, ...],
     num_heads: int,
-    head_dim: int,
     dtype: torch.dtype,
     device: str,
 ) -> torch.Tensor:

Then update call sites (lines 453, 460, 595, 602, 728, 735) by removing the head_dim argument.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 31f7d31 and 594a467.

📒 Files selected for processing (10)
  • csrc/batch_prefill_fp8_sm90.cu (3 hunks)
  • csrc/batch_prefill_sm90.cu (1 hunks)
  • csrc/page.cu (0 hunks)
  • flashinfer/page.py (0 hunks)
  • flashinfer/prefill.py (21 hunks)
  • flashinfer/sparse.py (2 hunks)
  • include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh (7 hunks)
  • include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh (2 hunks)
  • tests/attention/test_hopper.py (3 hunks)
  • tests/attention/test_hopper_fp8_attention.py (3 hunks)
💤 Files with no reviewable changes (2)
  • csrc/page.cu
  • flashinfer/page.py
🧰 Additional context used
🧬 Code graph analysis (2)
flashinfer/prefill.py (2)
flashinfer/page.py (1)
  • get_seq_lens (178-201)
flashinfer/utils.py (2)
  • canonicalize_torch_dtype (241-249)
  • is_float8 (158-159)
tests/attention/test_hopper_fp8_attention.py (2)
flashinfer/utils.py (1)
  • is_sm90a_supported (526-528)
flashinfer/prefill.py (9)
  • plan (1549-1937)
  • plan (2507-2804)
  • run (1968-1980)
  • run (1983-1995)
  • run (1998-2222)
  • run (2834-2844)
  • run (2847-2857)
  • run (2860-3017)
  • BatchPrefillWithPagedKVCacheWrapper (1260-2254)
🪛 Ruff (0.14.7)
tests/attention/test_hopper_fp8_attention.py

19-19: Avoid specifying long messages outside the exception class

(TRY003)


350-350: Unused function argument: head_dim

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (15)
flashinfer/prefill.py (2)

418-423: FP8 scale parameters added correctly.

The optional scale tensors for FP8 quantization are properly added with None defaults, maintaining backward compatibility.


2960-2969: Approve FA3 FP8 bypass of conversion.

Correctly skips the FP8→FP16 conversion when the FA3 backend handles FP8 natively.

tests/attention/test_hopper.py (3)

210-210: Good addition of page_size=16 test coverage.

This extends test coverage to non-unit page sizes, which is important given the page-based addressing changes in this PR.


270-270: Correct removal of padding from kv_indices.

The previous version added 256 extra indices as padding, which was unnecessary. The new version generates exactly batch_size * num_pages_per_request indices, matching the actual number of pages allocated.
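
For reference, a sketch of how the unpadded page table can be assembled in such a test (values are illustrative):

import torch

batch_size, num_pages_per_request, page_size = 4, 32, 16

kv_indptr = torch.arange(0, batch_size + 1, dtype=torch.int32) * num_pages_per_request
# Exactly batch_size * num_pages_per_request page ids -- no extra padding entries.
kv_indices = torch.arange(0, batch_size * num_pages_per_request, dtype=torch.int32)
kv_last_page_len = torch.full((batch_size,), page_size, dtype=torch.int32)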


437-437: Test data adjustment for multi-item scoring.

The padding distribution change (* 17 and * 5 instead of previous values) aligns with the test case's token_pos_in_items_len=97 and the actual data lengths.

csrc/batch_prefill_sm90.cu (1)

221-238: Page stride fields and validation correctly implemented.

The page stride extraction and runtime validation are well-implemented:

  • Correctly extracts stride(0) for page-level stride in both NHD and HND layouts
  • Runtime checks ensure K and V have matching strides, which is required for the sparse mainloop optimization
include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh (4)

337-346: Page stride parameters correctly wired to CollectiveMainloop.

The mainloop arguments now include k_page_stride, v_page_stride, kv_indices, and page_size, enabling sparse paged KV cache support.


464-477: CTA_KV optimization for HEAD_DIM=256 is well-documented.

The reduction from CTA_KV=128 to CTA_KV=64 for sparse paged loading with HEAD_DIM=256 is explained in the comment. The constraint that the FP8 transpose requires at least 64x64 blocks justifies not reducing further.


479-552: Ragged KV cache function correctly mirrors paged structure.

The new BatchFP8PrefillWithRaggedKVCacheKernelTraitsDispatched follows the same pattern as the paged version but uses FP8CollectiveMainloop instead of FP8SparseCollectiveMainloop, and uses standard gmem layouts instead of page-based addressing.


554-594: Ragged dispatch function uses TMA for all head dimensions.

The ragged path uses USE_TMA_LOAD_KV=true for all head dimensions (64, 128, 256), while the paged path uses USE_TMA_LOAD_KV=false. This is correct since ragged KV has contiguous memory that can benefit from TMA, whereas paged KV has non-contiguous pages.

csrc/batch_prefill_fp8_sm90.cu (2)

86-173: Ragged KV cache SM90 run implementation looks correct.

The function properly:

  • Validates LSE tensor dimensions
  • Populates RaggedParams with correct strides for both NHD and HND layouts
  • Uses static assertions for head dimension and dtype consistency
  • Dispatches to the appropriate kernel based on scheduler configuration

231-242: Page stride fields correctly added for FP8 paged path.

Consistent with the changes in batch_prefill_sm90.cu, the FP8 path now also stores page strides for sparse paged KV cache support.

include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh (1)

246-329: LGTM: Well-designed FA3-style prefetch optimization.

The prefetch offset calculation and load implementation with shuffle-based offset sharing is well-designed. The detailed comments explain the group organization, thread distribution, and double-buffering strategy clearly. The use of divmod for efficient page table lookup and shuffle for offset sharing across threads is appropriate for FA3.
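
In scalar form, the per-token address computation that now happens inside the kernel (and that previously required the standalone block_sparse_indices_to_vector_sparse_offsets call) looks roughly like the following sketch, assuming element-count strides:

def kv_element_offset(kv_pos, kv_head_idx, kv_indices, page_size,
                      k_page_stride, k_stride_n, k_stride_h):
    # Which logical page the token lives in, and where inside that page.
    page_idx, entry_idx = divmod(kv_pos, page_size)
    # Translate the logical page to a physical page via the page table, then
    # add the in-page row offset and the head offset.
    return (kv_indices[page_idx] * k_page_stride
            + entry_idx * k_stride_n
            + kv_head_idx * k_stride_h)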

tests/attention/test_hopper_fp8_attention.py (2)

250-811: Excellent test coverage for FP8 attention paths.

The new tests comprehensively cover FP8 quantization scenarios:

  • Ragged KV cache (test_batch_prefill_ragged)
  • Paged KV cache (test_batch_prefill_paged)
  • GQA support (test_batch_prefill_paged_gqa) - addresses past review comment
  • Per-tensor vs per-head scale types (test_batch_prefill_paged_scale_types) - addresses past review comment

The use of create_per_head_varying_kv to surface head-mapping bugs is a clever testing strategy.
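
A minimal sketch of the idea behind such a helper (not the exact implementation in the test file): each head gets values at a distinct magnitude, so any head-index mix-up produces a large, obvious error instead of a subtle one.

import torch

def per_head_varying(shape, num_heads, dtype, device, head_axis=-2):
    # Scale head h by (h + 1) so that swapping or mis-indexing heads
    # changes the attention output dramatically.
    x = torch.randn(shape, dtype=torch.float32, device=device)
    scale_shape = [1] * x.ndim
    scale_shape[head_axis] = num_heads
    head_scales = torch.arange(1, num_heads + 1, device=device,
                               dtype=torch.float32).reshape(scale_shape)
    return (x * head_scales).to(dtype)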


12-19: FP8 min/max values are correct per NVIDIA specification.

The hard-coded values match the NVIDIA FP8 format specification: float8_e4m3fn max finite normal is ±448, and float8_e5m2 max finite normal is ±57,344. No changes needed.
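
These limits can be cross-checked directly against PyTorch (assuming a build that ships the FP8 dtypes):

import torch

assert torch.finfo(torch.float8_e4m3fn).max == 448.0
assert torch.finfo(torch.float8_e5m2).max == 57344.0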

Comment on lines +3011 to +3013
# For FP8, append scale tensors
if is_float8(q):
    run_args.extend(list(args))  # scale_q, scale_k, scale_v
Contributor

⚠️ Potential issue | 🟡 Minor

FP8 scale tensors appended without validation.

When is_float8(q) is true, args is appended directly. There's no check that args actually contains exactly 3 scale tensors. If the user forgets to pass scales, this will silently pass incorrect data to the kernel.

Consider adding validation:

             # For FP8, append scale tensors
             if is_float8(q):
+                if len(args) < 3:
+                    raise ValueError(
+                        "FP8 attention requires scale_q, scale_k, scale_v tensors to be passed"
+                    )
                 run_args.extend(list(args))  # scale_q, scale_k, scale_v
🤖 Prompt for AI Agents
In flashinfer/prefill.py around lines 3011 to 3013, the code appends args when
is_float8(q) without validating that args contains exactly the three FP8 scale
tensors; add a guard that checks len(args) == 3 and that each element is a
tensor of expected dtype/shape (or at least a tensor-like object), and raise a
clear ValueError/TypeError if the check fails so the kernel never receives
missing/malformed scale arguments; update the error message to indicate which
scale(s) are missing or invalid.

@yzh119
Collaborator Author

yzh119 commented Dec 3, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !151 has been updated with the latest changes, and the CI pipeline #39546502 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[CANCELING] Pipeline #39546502: canceled

@yzh119 yzh119 requested a review from Anerudhan as a code owner December 5, 2025 10:42
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (2)
flashinfer/prefill.py (2)

1702-1704: Documentation improvement needed.

The docstring mentions "For FP8 inputs, this should typically be set to torch.float16" but bfloat16 is also a valid option. Consider updating to: "For FP8 inputs, this should typically be set to torch.float16 or torch.bfloat16."

This was also noted in a past review comment.


2222-2229: Add validation for FP8 scale tensors.

A past review flagged that when is_float8(q) is true, the code should validate that args contains exactly 3 scale tensors. Currently, if len(args) < 3, the scales remain None and are silently passed to the kernel, which may cause incorrect results or crashes.

Consider adding validation as suggested:

             # Extract FP8 scale tensors from *args if q is FP8
             fp8_scale_q = None
             fp8_scale_k = None
             fp8_scale_v = None
-            if is_float8(q) and len(args) >= 3:
+            if is_float8(q):
+                if len(args) < 3:
+                    raise ValueError(
+                        "FP8 attention requires scale_q, scale_k, scale_v tensors to be passed"
+                    )
                 fp8_scale_q = args[0]
                 fp8_scale_k = args[1]
                 fp8_scale_v = args[2]
🧹 Nitpick comments (3)
include/flashinfer/attention/hopper/variants.cuh (1)

28-35: Consider consolidating SFINAE trait definitions.

Both this file and variant_helper.cuh define DEFINE_HAS_MEMBER traits: this file adds traits for maybe_scale_v/q/k and scale_v/q/k_scalar, while variant_helper.cuh defines v_scale. Having the trait definitions split across files may lead to maintenance overhead.

Consider centralizing all SFINAE trait definitions in a single header (e.g., utils.cuh or a dedicated traits.cuh) to improve discoverability and reduce duplication risk.

tests/attention/test_hopper_fp8_attention.py (2)

347-370: Unused head_dim parameter.

The head_dim parameter is declared but never used in the function body. Consider removing it or using it for validation.

 def create_per_head_varying_kv(
     shape: Tuple[int, ...],
     num_heads: int,
-    head_dim: int,
     dtype: torch.dtype,
     device: str,
 ) -> torch.Tensor:

Note: This would require updating all call sites (lines 455-468, 462-468, 598-604, 605-611, 731-737, 738-744).


520-520: Remove debug print statement.

This print statement appears to be leftover debugging code that should be removed for cleaner test output.

     o_fp8 = wrapper_fp8.run(q_fp8, (paged_k_fp8, paged_v_fp8), s_q, s_k, s_v)
-    print(o_ref, o_fp8)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 594a467 and 37b20a8.

📒 Files selected for processing (7)
  • csrc/batch_prefill_fp8_ragged_sm90_kernel_inst.jinja (1 hunks)
  • flashinfer/jit/attention/modules.py (2 hunks)
  • flashinfer/prefill.py (29 hunks)
  • include/flashinfer/attention/hopper/mainloop_mma.cuh (1 hunks)
  • include/flashinfer/attention/hopper/variants.cuh (3 hunks)
  • include/flashinfer/attention/variant_helper.cuh (2 hunks)
  • tests/attention/test_hopper_fp8_attention.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
flashinfer/prefill.py (2)
flashinfer/page.py (1)
  • get_seq_lens (178-201)
flashinfer/utils.py (2)
  • canonicalize_torch_dtype (241-249)
  • is_float8 (158-159)
tests/attention/test_hopper_fp8_attention.py (1)
flashinfer/utils.py (1)
  • is_sm90a_supported (526-528)
🪛 Ruff (0.14.7)
tests/attention/test_hopper_fp8_attention.py

19-19: Avoid specifying long messages outside the exception class

(TRY003)


350-350: Unused function argument: head_dim

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (17)
include/flashinfer/attention/variant_helper.cuh (2)

23-28: LGTM! Clean SFINAE-based trait detection for optional v_scale.

The pattern using DEFINE_HAS_MEMBER and the get_v_scale helper enables backward-compatible optional scaling.


85-99: LGTM! Output scaling logic is correct.

The get_v_scale helper properly defaults to 1.0f when v_scale is not present in Params, maintaining backward compatibility while enabling FP8 scaling paths.

flashinfer/jit/attention/modules.py (2)

531-554: LGTM! Parameter lists correctly expanded for FP8 scale handling.

The non-FP8 path adds maybe_scale_v/scale_v_scalar, while the FP8 path adds all three scale pairs (maybe_scale_q/k/v and scale_q/k/v_scalar). The tensor/scalar counts and dtypes are correctly aligned.


1016-1048: LGTM! Batch prefill module parameters mirror single prefill updates.

The FA3 batch prefill path correctly mirrors the single prefill parameter expansions for both non-FP8 and FP8 paths.

csrc/batch_prefill_fp8_ragged_sm90_kernel_inst.jinja (1)

1-15: LGTM! Template instantiation structure is correct.

The template properly generates explicit instantiations for both SAME_SCHEDULER_FOR_ALL_HEADS variants. The namespace closing syntax has been corrected per the previous review.

include/flashinfer/attention/hopper/variants.cuh (3)

37-40: LGTM! Null-safe scale accessor.

The get_scale helper correctly handles the case where tensor_ptr is nullptr by falling back to scalar_val.


72-91: LGTM! StandardAttention correctly initializes scale_pv.

The structured binding extracts block coordinates, and scale_pv is initialized using the get_v_scale helper with proper fallback behavior.


117-131: LGTM! StandardFP8Attention scale computation is correct.

The FP8 attention variant properly computes:

  • p_scale from the FP8 type's max value
  • scale_pv = v_scale / p_scale for PV dequantization
  • sm_scale_with_qk_log2 incorporating q_scale, k_scale, and sm_scale

This correctly handles the FP8 quantization/dequantization flow.
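
Numerically, the scale handling described above amounts to roughly the following sketch (the log2(e) factor reflects an exp2-based online softmax; names are illustrative and the exact folding is an assumption, not a quote of the kernel):

import math

FP8_E4M3_MAX = 448.0  # quantization range used for the P = softmax(QK^T) matrix

def fp8_attention_scales(sm_scale, q_scale, k_scale, v_scale):
    p_scale = FP8_E4M3_MAX
    scale_pv = v_scale / p_scale  # dequantizes the P @ V accumulator in finalize()
    sm_scale_with_qk_log2 = sm_scale * q_scale * k_scale * math.log2(math.e)
    return scale_pv, sm_scale_with_qk_log2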

include/flashinfer/attention/hopper/mainloop_mma.cuh (1)

319-319: Thread scale_pv into finalization for FP8 dequantization.

This change passes variant.scale_pv to the finalize call, enabling fused PV dequantization as part of online softmax finalization. Ensure the AttentionUpdater::finalize method signature accepts this parameter.

tests/attention/test_hopper_fp8_attention.py (3)

250-344: LGTM!

The ragged KV cache FP8 test is well-structured with clear setup, reference computation, and MSE validation. The variable length sequences provide good coverage.


528-545: LGTM!

This test addresses the previously requested GQA coverage. The parameterization with different head ratios (32:8, 16:4, 8:2) provides good coverage for head mapping logic verification.


668-691: LGTM!

This test addresses the previously requested coverage for both per-tensor and per-head scale types. The test correctly broadcasts per-tensor scales to per-head format before passing to the kernel.

flashinfer/prefill.py (5)

69-86: LGTM!

The _split_scale_param helper cleanly handles the three cases (None, tensor, scalar) and provides a consistent interface for FP8 scale handling throughout the module.


310-349: LGTM!

The FP8 scale handling correctly uses _split_scale_param to decompose scales into tensor and scalar components, supporting both per-head tensor scales and scalar scales.


446-451: Consider validating FP8 scale tensor consistency.

The FP8 detection relies solely on scale_q is not None. If a user passes scale_q but omits scale_k or scale_v, _split_scale_param will return (None, 1.0) which may silently produce incorrect results instead of raising an error.

Consider adding validation:

         # Check if FP8 by presence of scale tensors
         is_fp8 = scale_q is not None
+        if is_fp8 and (scale_k is None or scale_v is None):
+            raise ValueError(
+                "FP8 attention requires all scale tensors (scale_q, scale_k, scale_v)"
+            )

2144-2154: LGTM!

The output dtype handling correctly uses _cached_o_data_type when available, with a safe fallback to q.dtype for backward compatibility.


3056-3058: Add validation for FP8 scale tensors in ragged run.

Similar to the paged run path, this code appends args as FP8 scales without validating that the required 3 scale tensors are present. If is_float8(q) but args is empty or has fewer than 3 elements, the kernel may receive incorrect arguments.

Consider adding validation:

             # For FP8, append scale tensors
             if is_float8(q):
+                if len(args) < 3:
+                    raise ValueError(
+                        "FP8 attention requires scale_q, scale_k, scale_v tensors to be passed"
+                    )
                 run_args.extend(list(args))  # scale_q, scale_k, scale_v

@yzh119
Collaborator Author

yzh119 commented Dec 5, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !151 has been updated with the latest changes, and the CI pipeline #39670861 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[CANCELING] Pipeline #39670861: canceled

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
include/flashinfer/attention/hopper/variants.cuh (1)

21-82: Scale helper and SFINAE layer is sound; be aware of the “all-or-nothing” detection

The DEFINE_HAS_MEMBER traits plus get_scale / get_{v,q,k}_scale helpers provide a clean way to support optional per-head scales while defaulting to 1.0f when scale fields are absent, which keeps non-FP8 / legacy paths unchanged. Note that each getter requires both maybe_scale_* and scale_*_scalar to exist; if a future AdditionalParams defines only one of these, it will silently fall back to 1.0f. That’s reasonable, but worth keeping in mind when extending AdditionalParams types.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 37b20a8 and d06ea1a.

📒 Files selected for processing (3)
  • flashinfer/jit/attention/variants.py (1 hunks)
  • include/flashinfer/attention/hopper/mainloop_mma.cuh (2 hunks)
  • include/flashinfer/attention/hopper/variants.cuh (3 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (3)
flashinfer/jit/attention/variants.py (1)

145-159: FA3 AttentionSink: scale_pv initialization looks consistent

Destructuring block_coord to expose kv_head_idx and initializing scale_pv via get_v_scale(params.additional_params, kv_head_idx) matches the new Hopper variant pattern and keeps behavior no-op when v-scales are absent (fallback 1.0f). Within this snippet scale_pv is only stored; assuming downstream code reads it (e.g., via a helper similar to get_variant_scale_pv), this is a clean extension with no behavior change for setups without per-head v-scales.

include/flashinfer/attention/hopper/mainloop_mma.cuh (1)

15-16: Plumbing pv-scale into finalize is correct; ensure all updaters support the new signature

Including "variants.cuh" here and calling attention_updater.finalize(tSrS, get_variant_scale_pv(variant)) correctly threads variant-specific scale_pv into the online softmax path while preserving old behavior for variants without a scale_pv field (fallback 1.0f). Please double-check that every AttentionUpdater used with this mma_f16 now implements a compatible finalize(Tensor0&, float) (or has a defaulted second parameter) so no template instantiation breaks for non-sink / non-FP8 paths.

Also applies to: 321-321

include/flashinfer/attention/hopper/variants.cuh (1)

83-165: Variant scale_pv wiring aligns with FP8 and non-FP8 semantics; confirm no remaining ODequantize users

The additions of scale_pv to StandardAttention and LogitsSoftCap (initialized via get_v_scale) and to StandardFP8Attention (as v_scale / p_scale) line up with the new get_variant_scale_pv(variant) + finalize(..., pv_scale) path and keep behavior unchanged when no scales are provided (helpers return 1.0f). For StandardFP8Attention, using q_scale * k_scale to fold Q/K dequantization into sm_scale_with_qk_log2 and moving PV dequantization into finalize via scale_pv matches the P-quantization setup.

Given ODequantize is now intentionally a no-op with a comment about fusing into finalize, please ensure there are no remaining call sites that still expect it to perform the output dequantization directly.

@yzh119
Collaborator Author

yzh119 commented Dec 5, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !151 has been updated with the latest changes, and the CI pipeline #39692664 is currently running. I'll report back once the pipeline job completes.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (2)
flashinfer/prefill.py (2)

2145-2154: Run path now respects cached output dtype; ensure FP8 scale args are clearly specified and q_scale documented.

  • Using _cached_o_data_type (falling back to q.dtype via getattr) for both allocation and dtype checks of out makes the run phase honor the o_data_type chosen at plan time, which is important for FP8 with FP16 output.
  • For non‑JIT FA3 FP8 paged kernels, the convention if is_float8(q) and len(args) >= 3: fp8_scale_q, fp8_scale_k, fp8_scale_v = args[:3] is reasonable, but it’s implicit. Consider documenting in the wrapper docstring that, in FP8 mode, the first three extra positional args are interpreted as scale_q, scale_k, scale_v and any additional args are ignored. Making them explicit keyword arguments on this API surface would be even clearer long‑term.

Also, q_scale is now used to fold into sm_scale here but is still not described in the run docstring, which has been called out previously.

Also applies to: 2222-2242, 2255-2256


2977-2980: Ragged run: output dtype wiring is good; still missing validation for FP8 scale tensors in *args.

  • Using _cached_o_data_type for ragged outputs (both allocation and checks) is aligned with the new plan API and is necessary for FP8 inputs with FP16 outputs.
  • The updated cast guard if is_float8(q) and self._backend != "fa3": ... correctly skips the FP8→FP16 conversion only for FA3, which now has native FP8 support.

The remaining issue is the FP8 scale handling at the bottom:

# For FP8, append scale tensors
if is_float8(q):
    run_args.extend(list(args))  # scale_q, scale_k, scale_v

As previously noted, there is still no validation that:

  • args contains exactly the three expected scale tensors, and
  • each element has a reasonable type/shape for a per‑head scale.

A malformed or incomplete args tuple will silently propagate garbage into the kernel arguments.

Consider tightening this by, for the non‑JIT path:

-        # For FP8, append scale tensors
-        if is_float8(q):
-            run_args.extend(list(args))  # scale_q, scale_k, scale_v
+        # For FP8, append (scale_q, scale_k, scale_v) explicitly
+        if is_float8(q):
+            if len(args) != 3:
+                raise ValueError(
+                    f"FP8 ragged prefill expects 3 scale tensors "
+                    f"(scale_q, scale_k, scale_v), got {len(args)}"
+                )
+            scale_q, scale_k, scale_v = args
+            for name, scale in (("scale_q", scale_q), ("scale_k", scale_k), ("scale_v", scale_v)):
+                if not isinstance(scale, torch.Tensor):
+                    raise TypeError(f"{name} must be a torch.Tensor, got {type(scale)}")
+            run_args.extend([scale_q, scale_k, scale_v])

This keeps the interface the same for correct callers but fails fast and clearly for misconfigured FP8 runs.

Also applies to: 2983-2987, 3005-3015, 3056-3058

🧹 Nitpick comments (3)
flashinfer/prefill.py (3)

69-86: Helper _split_scale_param behavior is sound; consider tightening typing/validation.

The tensor/scalar split logic looks correct and is reused consistently for FA3 FP16/FP8, but the helper currently accepts any non‑tensor object and blindly calls float(scale). You could make failures clearer and improve static checking by:

  • Adding a type hint, e.g. scale: Optional[Union[torch.Tensor, float, int]] -> Tuple[Optional[torch.Tensor], float].
  • Raising a TypeError with a clear message if scale is of an unsupported type instead of relying on float(scale) to fail.
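
A sketch of the suggested tightening, keeping the documented (tensor, scalar) contract (illustrative, not the shipped helper):

from typing import Optional, Tuple, Union
import torch

def _split_scale_param(
    scale: Optional[Union[torch.Tensor, float, int]],
) -> Tuple[Optional[torch.Tensor], float]:
    # None          -> no scale tensor, neutral scalar factor
    # torch.Tensor  -> pass the tensor through, neutral scalar factor
    # float / int   -> no tensor, use the value as the scalar factor
    if scale is None:
        return None, 1.0
    if isinstance(scale, torch.Tensor):
        return scale, 1.0
    if isinstance(scale, (float, int)):
        return None, float(scale)
    raise TypeError(f"scale must be a torch.Tensor, float, or int, got {type(scale)}")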

446-452: FA3 ragged FP8 path: clarify feature support and tighten FP8 gating.

The new FA3 FP8 branch correctly reuses _split_scale_param and routes scale tensors/scalars separately, but two details are worth tightening:

  1. FP8 mode detection.
    is_fp8 = scale_q is not None means merely passing a non‑None scale_q flips to the FP8 kernel, even if q is not actually float8. It would be safer to gate on is_float8(q) (or both conditions) to avoid accidentally invoking the FP8 variant with FP16 inputs.

  2. Dropped mask / multi‑item args in FP8 branch.
    In the FP8 case you no longer thread maybe_custom_mask, maybe_mask_indptr, maybe_alibi_slopes, maybe_prefix_len_ptr, maybe_token_pos_in_items_ptr, maybe_max_item_len_ptr, logits_soft_cap, rope_scale, rope_theta, or token_pos_in_items_len to ragged_run_func. If FA3 FP8 truly doesn’t support custom masks or multi‑item scoring yet, it would be safer to:

    • Either raise a NotImplementedError when those arguments are non‑None in FP8 mode, or
    • Document explicitly that these features are unsupported for FA3 FP8 ragged prefill and are ignored.

This avoids silent behavior differences between FP16 and FP8 runs.

Also applies to: 481-508, 510-513, 531-535


1617-1617: New o_data_type and get_seq_lens usage in paged plan are reasonable; watch kv_lens scope for TRTLLM.

  • Introducing o_data_type (defaulting to q_data_type then canonicalized) and caching it as _cached_o_data_type is a good way to support FP8 inputs with FP16 outputs while keeping FP16/BF16 behavior unchanged.
  • Switching the KV‑length computation to get_seq_lens(paged_kv_indptr_host, paged_kv_last_page_len_host, page_size) and copying into _kv_lens_buffer is consistent with the existing formula and should help with robust sizing for large block counts.
  • get_module_args now passes o_data_type into get_batch_prefill_module, so the compiled kernel sees the intended output dtype.

One small edge case: in the TRTLLM‑GEN branch you rely on kv_lens_arr_host (for blocks_per_seq), which is only defined in the max_sequence_kv is None path. The docs say max_sequence_kv is required only for the cuDNN backend, but if a caller ever provides it together with a TRTLLM‑GEN backend, kv_lens_arr_host will be undefined. Consider either:

  • Guarding with an explicit error if self._backend == "trtllm-gen" and max_sequence_kv is not None, or
  • Recomputing kv_lens_arr_host from seq_lens in that branch.

Also applies to: 1702-1705, 1709-1710, 1719-1720, 1756-1758, 1805-1814, 1889-1890, 1907-1908
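
For context, the paged-KV length formula referred to above is, as a sketch (this mirrors the usual indptr/last-page convention rather than quoting get_seq_lens verbatim):

import torch

def seq_lens_from_page_table(kv_indptr, kv_last_page_len, page_size):
    # Each request occupies (num_pages - 1) full pages plus a partially
    # filled last page of length kv_last_page_len[i].
    num_pages = kv_indptr[1:] - kv_indptr[:-1]
    return torch.clamp(num_pages - 1, min=0) * page_size + kv_last_page_len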

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d06ea1a and 436c443.

📒 Files selected for processing (1)
  • flashinfer/prefill.py (32 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/prefill.py (2)
flashinfer/page.py (1)
  • get_seq_lens (178-201)
flashinfer/utils.py (2)
  • canonicalize_torch_dtype (241-249)
  • is_float8 (158-159)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (2)
flashinfer/prefill.py (2)

311-349: FA3 single‑prefill scale splitting looks consistent with new kernel API.

For FA3, splitting scale_v (FP16) and scale_q/scale_k/scale_v (FP8) into (tensor, scalar) pairs via _split_scale_param and passing them as separate arguments matches the intended “per‑head tensor vs scalar factor” design and keeps non‑FA3 backends untouched. I don’t see correctness issues here.


692-747: Paged FA3 FP16/FP8 scale handling via _split_scale_param looks correct.

For the paged backend, splitting scale_v for FP16 and scale_q/scale_k/scale_v for FP8 into tensor/scalar components and threading them into paged_run_func keeps the call signature aligned with the FA3 kernels while preserving existing behavior for FA2/TRTLLM/CuDNN. I don’t see functional issues in this hunk.

Comment on lines +2572 to 2573
o_data_type: Optional[Union[str, torch.dtype]] = None,
non_blocking: bool = True,
Contributor

⚠️ Potential issue | 🟡 Minor

Pass dtype_o parameter to fmha_varlen for CUTLASS backend to honor the requested output dtype.

Currently, fmha_varlen() is called without the dtype_o parameter (line 3049), so it defaults to v.dtype. This differs from the FA2/FA3 path (line 3055), which explicitly passes dtype_o=self._cached_o_data_type. Since out is allocated with _cached_o_data_type (line 3029), there's a potential dtype mismatch when o_data_type differs from v.dtype in the CUTLASS path. Either enforce o_data_type == v.dtype for CUTLASS, or pass dtype_o=self._cached_o_data_type to fmha_varlen to ensure consistency with the FA2/FA3 path and the buffer allocation.

🤖 Prompt for AI Agents
In flashinfer/prefill.py around lines 2572-2573 (see related calls around lines
3029, 3049, 3055), the CUTLASS fmha_varlen call omits the dtype_o argument
causing a potential dtype mismatch with the preallocated out buffer; modify the
CUTLASS path call to pass dtype_o=self._cached_o_data_type to fmha_varlen so the
output dtype matches the allocated out tensor (or alternatively add an explicit
assert that self._cached_o_data_type == v.dtype before calling fmha_varlen if
you intend to enforce identical dtypes).

@yzh119 yzh119 enabled auto-merge (squash) December 6, 2025 03:31
@yzh119 yzh119 disabled auto-merge December 6, 2025 03:31
@yzh119 yzh119 merged commit 09872a1 into flashinfer-ai:main Dec 6, 2025
4 checks passed

Development

Successfully merging this pull request may close these issues.

Error "_vector_sparse_indices_buffer is not large enough" for VariableBlockSparseAttention

4 participants