refactor: update fa3 codebase and fix hopper unittest [part 1] #2111
Conversation
Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.
Walkthrough
Adds page-level K/V page_stride fields and runtime stride validations for sparse paged KV cache; replaces block→vector-sparse conversion with page-table/manual page addressing; threads FP8 scale tensors and output dtype through JIT/host/kernel paths; removes block-sparse helper APIs and updates tests/benchmarks accordingly.
Sequence Diagram(s)
sequenceDiagram
participant Tests as Python tests/bench
participant Wrapper as Wrapper (flashinfer/prefill.py / sparse.py)
participant JIT as JIT modules (flashinfer/jit/...)
participant Host as Host planner (csrc/* + headers)
participant Kernel as CUDA device mainloop
Note over Tests,Kernel: New paged/ragged + FP8 flow
Tests->>Wrapper: run(plan/run args, paged_kv_indptr, paged_kv_indices, o_data_type?, fp8_scales?)
Wrapper->>JIT: plan/run(..., paged buffers, maybe_scale_*, o_data_type)
JIT->>Host: module args include k_page_stride, v_page_stride, k_stride_n, v_stride_n, page_size, scale tensors/scalars
Host->>Kernel: launch kernel with (K_ptr, V_ptr, kv_indices, k_page_stride, v_page_stride, page_size, scales...)
Kernel->>Kernel: compute page_idx via divmod(page_size) → page offsets
Kernel->>Kernel: prefetch / cp_async using page strides and stride_n
Kernel->>Kernel: MMA / epilogue (barrier/sync) → write output (obey o_data_type & scales)
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60–90 minutes
Summary of Changes
Hello @yzh119, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request focuses on enhancing the efficiency and correctness of the Flash Attention v3 (FA3) implementation, particularly for paged Key-Value (KV) caches with page sizes greater than one. By integrating page offset calculations directly into the kernel and optimizing KV offset handling with prefetching and shuffling, the codebase becomes more streamlined and performant. A critical bug affecting Hopper unittests has also been resolved, ensuring robust operation on the target architecture. These changes collectively contribute to a more optimized and reliable sparse attention mechanism.
/bot run
Code Review
This pull request is a significant refactoring of the FA3 codebase. It removes the standalone block_sparse_indices_to_vector_sparse_offsets function and moves the page offset calculation directly into the CUDA kernel, which is a great simplification. The changes also include an optimization for kv_offset calculation using prefetching and shuffling, which should improve performance. The code removal across C++, Python, and header files is consistent and clean. I've found a couple of minor areas for code improvement to reduce redundancy, but overall the changes look solid and well-implemented.
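As background for the kernel-level threads below, here is a minimal Python sketch of the page-table addressing that this refactor moves into the CUDA mainloop. It is illustrative only; the real computation happens in the kernel, and the names page_indices, page_stride, and stride_n are placeholders for the corresponding kernel parameters.

```python
# Illustrative sketch (not the kernel code): a flat KV token index is split into
# (page slot, in-page entry) via divmod and combined with the page table to form
# an element offset into the paged KV cache.
def kv_element_offset(kv_idx, page_size, page_indices, page_stride, stride_n):
    page_iter, entry_idx = divmod(kv_idx, page_size)  # kernel uses a fastdiv divmod
    page_idx = page_indices[page_iter]                # page-table lookup (kv_indices)
    return page_idx * page_stride + entry_idx * stride_n

# Example: page_size=16, page table [7, 3, 9]; token 20 lands in the second page
# slot (physical page 3), entry 4.
print(kv_element_offset(20, 16, [7, 3, 9], page_stride=4096, stride_n=256))
```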
int d_idx = get<1>(coord);
int kv_idx = kv_base_idx + kv_offset;
bool guard = kv_idx < kv_len && kv_offset < valid_tile_size;
The guard condition can be simplified. The check kv_idx < kv_len is redundant when use_predicate is true, as it's already implied by kv_offset < valid_tile_size. When use_predicate is false, valid_tile_size is CTA_KV, and kv_offset is always less than CTA_KV, so the guard is not needed for non-last tiles anyway. You can simplify this to just kv_offset < valid_tile_size.
bool guard = kv_offset < valid_tile_size;
Actionable comments posted: 2
🧹 Nitpick comments (2)
flashinfer/prefill.py (1)
2109-2156: Paged KV run argument rewiring is reasonable; verify trtllm cum_seq_lens_kv semantics
Using _paged_kv_indptr_buf/_paged_kv_indices_buf directly in run_args keeps the Python wrapper aligned with the new paged-KV FFI signature, and _qo_indptr_buf is a natural fit for cum_seq_lens_q. The only subtle point is that _paged_kv_indptr_buf is in units of pages, while trtllm paged attention APIs traditionally expect cum_seq_lens_kv in tokens; if the trtllm-gen backend actually consumes those trailing args as cum-token lengths, it may need cumsum(seq_lens) instead of raw page indptr. Worth double-checking against the current trtllm kernel contract.
tests/attention/test_batch_prefill_kernels.py (1)
147-157: Good coverage of preallocated LSE path; consider also checking LSE values
Using lse_buffer = torch.empty_like(lse) and rerunning with out=o_buffer, lse=lse_buffer now exercises the buffered LSE write path, which should catch the Hopper regression. To fully validate it, you may also want to assert torch.testing.assert_close(lse, lse_buffer, ...) alongside the existing o vs o_buffer check.
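A hedged sketch of that extra assertion; the names o, o_buffer, lse, and lse_buffer are assumed to be the reference outputs and preallocated buffers already used in this test, not verified against the current file.

```python
import torch

# Suggested additions alongside the existing output comparison:
torch.testing.assert_close(o, o_buffer, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(lse, lse_buffer, rtol=1e-3, atol=1e-3)
```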
📜 Review details
📒 Files selected for processing (15)
- csrc/batch_prefill_sm90.cu (1 hunks)
- csrc/batch_prefill_sm90_customize_config.jinja (1 hunks)
- csrc/flashinfer_page_binding.cu (0 hunks)
- csrc/page.cu (0 hunks)
- flashinfer/page.py (0 hunks)
- flashinfer/prefill.py (3 hunks)
- flashinfer/sparse.py (2 hunks)
- include/flashinfer/attention/hopper/default_params.cuh (1 hunks)
- include/flashinfer/attention/hopper/prefill_sm90.cuh (1 hunks)
- include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh (8 hunks)
- include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh (1 hunks)
- include/flashinfer/attention/hopper/sparse_mainloop.cuh (8 hunks)
- include/flashinfer/page.cuh (0 hunks)
- tests/attention/test_batch_prefill_kernels.py (1 hunks)
- tests/utils/test_block_sparse_indices_to_vector_sparse_offsets.py (0 hunks)
💤 Files with no reviewable changes (5)
- csrc/flashinfer_page_binding.cu
- csrc/page.cu
- flashinfer/page.py
- tests/utils/test_block_sparse_indices_to_vector_sparse_offsets.py
- include/flashinfer/page.cuh
🧰 Additional context used
🧬 Code graph analysis (3)
csrc/batch_prefill_sm90.cu (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (2)
TVM_FFI_ICHECK_EQ (167-171), TVM_FFI_ICHECK_EQ (283-286)
tests/attention/test_batch_prefill_kernels.py (1)
flashinfer/prefill.py (6)
run (1924-1936), run (1939-1951), run (1953-2166), run (2768-2778), run (2781-2791), run (2793-2939)
flashinfer/prefill.py (1)
flashinfer/page.py (1)
get_seq_lens(176-199)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (8)
include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh (1)
102-108: LGTM: Correct handling of separate K and V strides.
This implementation correctly supports different memory layouts for K and V:
- Parameterized design: The load_kv_tile lambda (lines 232-267) accepts stride_n and page_stride as parameters rather than hardcoding them
- Separate calls: K and V loads pass their respective strides:
  - K: load_kv_tile(k_base_ptr, k_stride_n, k_page_stride, ...) (line 275)
  - V: load_kv_tile(v_base_ptr, v_stride_n, v_page_stride, ...) (line 298)
- Flexible addressing: Line 259 computes offsets using the passed-in parameters
This is the correct pattern for page-based sparse loading and avoids the stride assumption issue present in sparse_mainloop.cuh.
Also applies to: 118-124, 232-267
include/flashinfer/attention/hopper/sparse_mainloop.cuh (1)
110-112: Stride equality is already validated on the host side; v_page_stride is intentionally passed through for API consistency.
The v_page_stride parameter, while unused in the non-quantized sparse_mainloop.cuh kernel, is not a bug. An assertion in csrc/batch_prefill_sm90.cu line 235 validates that K and V page strides are equal at runtime, and the comment in the sparse mainloop (line 281) explicitly documents this assumption. The prefetch_kv_offset lambda correctly reuses the same offset computation for both K and V loads. The parameter exists for API consistency with the quantized variant (mainloop_sparse_load.cuh), which does use v_page_stride separately. If API unification across quantized and non-quantized paths is intentional, no action is needed.
include/flashinfer/attention/hopper/default_params.cuh (1)
157-160: k_page_stride / v_page_stride fields look consistent
Adding explicit page-stride fields after nnz_qo matches the other Hopper paged params structs and keeps types/ordering coherent with the new sparse mainloop arguments.
include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh (1)
337-344: FP8 sparse mainloop argument rewiring looks correct
Switching K/V from get_gmem_layout to explicit {k_stride_n, k_page_stride, v_stride_n, v_page_stride, kv_indices, page_size} matches the updated sparse mainloop API and keeps Q/O layout handling unchanged.
flashinfer/prefill.py (1)
36-36: Importing get_seq_lens is appropriate
This import matches the later use of get_seq_lens in BatchPrefillWithPagedKVCacheWrapper.plan to derive KV sequence lengths from paged metadata.
include/flashinfer/attention/hopper/prefill_sm90.cuh (1)
382-386: Sparse prefill mainloop now correctly receives KV indices and paging metadata
Passing kv_indices, window_left, k_page_stride, v_page_stride, and page_size into SparseCollectiveMainloop::to_underlying_arguments lines up with the new paged sparse mainloop contract and keeps Q/K/V layouts unchanged.
csrc/batch_prefill_sm90_customize_config.jinja (1)
107-111: PagedParams gains explicit K/V page strides in the right place
Adding k_page_stride/v_page_stride after nnz_qo keeps this JIT-generated PagedParams struct aligned with the Hopper default params and with how batch_prefill_sm90.cu now fills these fields from paged_{k,v}_cache.stride(0).
csrc/batch_prefill_sm90.cu (1)
221-238: Page-stride wiring and K/V stride consistency checks make sense
Recording k_page_stride/v_page_stride from stride(0) in both layouts and then asserting that K/V share the same page stride and stride_n is a good guardrail for the sparse paged mainloop; it will surface mis-laid-out KV caches early with clear error messages rather than letting the kernel access mismatched layouts.
int64_t my_kv_offset[2];  // Rolling buffer: page_idx * page_stride + entry_idx * stride_n

// Group organization based on partition strategy
constexpr int NUM_KV_PER_ITER = decltype(size<1>(tKcK))::value;   // e.g., 12
constexpr int KV_STRIDE = CTA_KV / NUM_KV_PER_ITER;               // 96/12 = 8
constexpr int NUM_GROUPS = KV_STRIDE;                             // 8 groups (one per lane)
constexpr int THREADS_PER_GROUP = NUM_COPY_THREADS / NUM_GROUPS;  // 128/8 = 16
constexpr int NUM_ITERS_PER_GROUP = NUM_KV_PER_ITER;              // 12 iterations per group

int group_id = thread_idx / THREADS_PER_GROUP;         // 0-7
int thread_in_group = thread_idx % THREADS_PER_GROUP;  // 0-15

// Prefetch: compute page_idx * page_stride + entry_idx * stride_n
// NOTE: Assumes K and V have same strides (asserted on host side)
auto prefetch_kv_offset = [&](int kv_tile_idx, bool use_predicate) {
  int kv_base_idx = kv_tile_idx * CTA_KV;
  int buf_idx = kv_tile_idx % 2;

  int kv_idx_read = kv_base_idx + group_id + thread_in_group * KV_STRIDE;
  bool valid_read =
      thread_in_group < NUM_ITERS_PER_GROUP && (!use_predicate || kv_idx_read < kv_len);

  if (valid_read) {
    // Use divmod to find page and offset within page
    uint32_t page_iter, entry_idx;
    mainloop_params.page_size.divmod(kv_idx_read, page_iter, entry_idx);
    IdType page_idx = kv_indices_ptr[page_iter];
    // Pre-compute: page_idx * page_stride + entry_idx * stride_n
    my_kv_offset[buf_idx] = page_idx * k_page_stride + entry_idx * k_stride_n;
  } else {
    my_kv_offset[buf_idx] = 0;
  }
};
Prefetch logic assumes K and V have identical strides.
The prefetch_kv_offset lambda computes my_kv_offset using only K strides (k_page_stride and k_stride_n on line 296), but this offset is later reused for both K and V loads in load_kv_with_gather. This hardcodes the assumption that K and V have identical memory layouts.
Compare with mainloop_sparse_load.cuh (lines 232-267), which correctly uses separate stride parameters in its load_kv_tile lambda, allowing K and V to have different layouts.
Consider refactoring to either:
- Option 1: Compute separate offsets for K and V if they can differ
- Option 2: Use a single set of stride parameters if layouts must be identical
🤖 Prompt for AI Agents
In include/flashinfer/attention/hopper/sparse_mainloop.cuh around lines 268-300,
the prefetch lambda computes my_kv_offset using only K strides but the same
offset is later used for both K and V loads, incorrectly assuming identical K/V
layouts; fix by computing distinct offsets for K and V (or enforce
identical-layout at compile/runtime). Update the lambda to accept/use separate
stride parameters (e.g., k_page_stride/k_stride_n and v_page_stride/v_stride_n)
and write into two rolling buffers (my_kv_offset_k[2] and my_kv_offset_v[2]) so
load_kv_with_gather can use the correct offset for each tensor, or alternatively
add a clear static_assert/runtime check and comment that K and V must share
strides and keep single offset.
auto load_kv_with_gather = [&](auto&& tXsX, auto&& tXcX, DTypeKV* base_ptr, int kv_tile_idx,
                               int stage_idx, bool use_predicate) {
  using Vec = AlignmentTypeKV;
  constexpr int VecSize = sizeof(Vec) / sizeof(DTypeKV);

  int kv_base_idx = kv_tile_idx * CTA_KV;
  int buf_idx = kv_tile_idx % 2;

  auto dst = recast<Vec>(flatten(tXsX(_, _, _, stage_idx)));
  auto c = flatten(tXcX(_, _, _, kv_tile_idx));

  constexpr unsigned FULL_MASK = 0xffffffff;

  // Load using FA3-style shuffle with pre-computed offsets
  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < size(dst); ++i) {
    auto coord = c(VecSize * i);
    int kv_offset = get<0>(coord);
    int d_idx = get<1>(coord);
    int kv_idx = kv_base_idx + kv_offset;
    bool guard = !use_predicate || kv_idx < kv_len;

    // Shuffle the pre-computed offset (page_idx * page_stride + entry_idx * stride_n)
    int src_thread = group_id * THREADS_PER_GROUP + kv_offset / KV_STRIDE;
    int64_t base_offset = __shfl_sync(FULL_MASK, my_kv_offset[buf_idx], src_thread);

    // Final address: base_ptr + base_offset + d_idx
    // where base_offset = page_idx * page_stride + entry_idx * stride_n
    Vec const* src_ptr = reinterpret_cast<Vec const*>(base_ptr + base_offset + d_idx);
    cutlass::arch::cp_async_zfill<sizeof(Vec), cutlass::arch::CacheOperation::Global>(
        &dst(i), src_ptr, guard);
  }
};
load_kv_with_gather reuses K offsets for V loads.
The load_kv_with_gather helper shuffles and reuses my_kv_offset (computed using K strides in prefetch_kv_offset) for both K and V loads:
- Line 341: load_kv_with_gather(..., K_ptr_base, ...)
- Line 367: load_kv_with_gather(..., V_ptr_base, ...)
This shuffle-based optimization is effective for performance but requires K and V to have identical page strides and per-token strides. If this constraint is enforced elsewhere, add an assertion or comment clarifying why separate v_page_stride parameters exist but are unused.
For reference, mainloop_sparse_load.cuh avoids this issue by passing stride parameters explicitly to its load_kv_tile helper.
🤖 Prompt for AI Agents
In include/flashinfer/attention/hopper/sparse_mainloop.cuh around lines 303 to
335, load_kv_with_gather reuses the K offsets (my_kv_offset) for V loads which
is only valid if K and V have identical page and per-token strides; update the
code to either (A) assert at runtime (or static_assert / debug check) that
v_page_stride == k_page_stride and per-token strides match and add a clear
comment explaining why v_page_stride parameter is unused, or (B) change the
caller/implementation so V uses its own computed offsets (compute a separate
my_v_offset in prefetch_v_offset and shuffle that for V loads) so K and V can
have different strides—pick one approach and apply consistently (add the
assertion/comment if you choose A; implement separate offset computation and use
it in the shuffle and cp_async_zfill calls if you choose B).
Actionable comments posted: 2
🧹 Nitpick comments (2)
include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh (1)
476-591: LGTM! The new ragged KV dispatch functions are correctly implemented:
- Uses TMA load for contiguous ragged memory (consistent with single prefill)
- Proper layout construction and scheduler setup
- Head dimension dispatch covers all supported values (64, 128, 256)
Minor: The comment on line 497 ("NOTE(Zihao): nnz was useless here, we can just pass 0") reads as a debug/TODO note. Consider removing or rephrasing if the implementation is finalized.
flashinfer/prefill.py (1)
416-472: Consider aligning FP8 detection with tensor dtype check.
The FP8 detection here uses scale_q is not None (line 421), while other places in the codebase use is_float8(q). This could lead to inconsistency if:
- FP8 input is provided without scale tensors
- Non-FP8 input is accidentally provided with scale tensors
Consider using is_float8(q) for consistency, or add a validation that ensures FP8 inputs always have scale tensors.
- # Check if FP8 by presence of scale tensors
- is_fp8 = scale_q is not None
+ # Check if FP8 by tensor dtype
+ is_fp8 = is_float8(q)
+ if is_fp8 and scale_q is None:
+     raise ValueError("FP8 inputs require scale_q, scale_k, scale_v tensors")
📜 Review details
📒 Files selected for processing (6)
- csrc/batch_prefill_fp8_ragged_sm90_kernel_inst.jinja (1 hunks)
- csrc/batch_prefill_fp8_sm90.cu (3 hunks)
- flashinfer/prefill.py (21 hunks)
- include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh (8 hunks)
- include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh (2 hunks)
- tests/attention/test_hopper_fp8_attention.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/prefill.py (2)
flashinfer/page.py (1)
get_seq_lens (176-199)
flashinfer/utils.py (3)
canonicalize_torch_dtype (240-248), check_shape_dtype_device (519-537), is_float8 (157-158)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (17)
tests/attention/test_hopper_fp8_attention.py (3)
186-280: LGTM! The test function is well-structured, following the established pattern for FP8 testing. It correctly:
- Creates variable-length sequences for batch prefill
- Generates FP16 reference output
- Quantizes inputs to FP8
- Compares MSE between FP16 and FP8 paths
283-403: LGTM! The paged KV cache test is correctly implemented:
- Proper page allocation and indptr/indices construction
- Appropriate reshape-quantize-reshape pattern for paged KV tensors
- Consistent with the ragged test structure
406-426: LGTM! The __main__ block updates provide convenient local test execution with a reasonable subset of parameters for quick validation.
include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh (1)
337-344: LGTM! The parameter changes correctly pass stride and page-size information directly to the sparse mainloop, aligning with the PR's objective of moving page offset calculation into the kernel.
csrc/batch_prefill_fp8_sm90.cu (2)
86-173: LGTM! The BatchPrefillWithRaggedKVCacheSM90Run implementation is well-structured:
- Proper plan info initialization and LSE validation
- Correct layout-aware stride handling for NHD/HND
- Appropriate static assertions for FP8 constraints
- Consistent error handling pattern
231-243: LGTM! The page stride handling is correct. Using stride(0) consistently retrieves the stride between pages regardless of the internal layout (NHD or HND), which is the intended behavior for sparse paged KV cache addressing.
flashinfer/prefill.py (5)
1566-1567: LGTM! The o_data_type parameter addition is well-implemented with proper canonicalization and caching for use in the run method.
2092-2102: LGTM! The output allocation correctly uses the cached output data type with a safe fallback to q.dtype for backward compatibility.
2950-2959: LGTM! The FP8 handling correctly bypasses the FP16 conversion for FA3 backend while maintaining backward compatibility with FA2 backend (which still shows a deprecation warning and converts to FP16).
3001-3003: LGTM! The FP8 scale tensor extension follows the established pattern from the paged path.
2170-2189: I'll verify the FP8 scale tensor extraction from *args by examining how callers pass FP8 scale tensors to the run() method and checking the documentation.
include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh (6)
102-111: LGTM! Clear page-based addressing structure. The new stride and page fields provide a clean interface for page-based K/V tensor addressing, replacing the previous layout-based approach.
118-127: LGTM! Efficient use of fast division. Using uint_fastdiv for page_size enables efficient divmod operations in the hot path.
134-137: LGTM! Parameter forwarding is correct. All new stride and page parameters are properly forwarded from Arguments to Params.
212-231: LGTM! Clean setup for page-based loading. The coordinate tensor partitioning and parameter extraction properly prepare for the manual K/V loading path.
273-372: LGTM! Proper tile loading sequence and synchronization. The tile loading pattern correctly applies predication only to the last tile while intermediate tiles load without bounds checking. Pipeline synchronization, V transpose coordination, and barrier usage are all properly structured.
232-266: I encountered a repository clone failure and cannot access the codebase to verify the bounds checking concern; the following is based on the code snippet provided.
Verify bounds check for page table access before accessing kv_indices_ptr.
The lambda correctly implements page-based addressing for K/V tiles, but the code at line 257 lacks validation that page_iter is within the bounds of kv_indices_ptr before array access. When page_iter is computed via divmod(kv_idx, page_size), the result could potentially exceed the allocated size of the page table if:
- The page table was sized based on an incorrect upper bound for KV entries
- Concurrent modifications affect the array size
- Off-by-one errors exist in the page table allocation logic
To resolve this:
- Add an assertion or bounds check: CUTE_ASSERT(page_iter < num_pages) before line 257, or
- Verify that the page table allocation logic guarantees sufficient capacity for all possible page_iter values derived from valid kv_idx values
- Document the invariant that kv_indices_ptr must accommodate ceil_div(kv_len, page_size) entries
Regarding the guard condition at line 252 (!use_predicate || kv_idx < kv_len): This appears correct for intermediate vs. final tiles, but validate that kv_len accurately reflects the maximum valid index.
  }
};

int valid_last_kv_tile_size = std::min<int>(kv_len - kv_tile_idx * CTA_KV, CTA_KV);
Remove or utilize the unused variable.
The variable valid_last_kv_tile_size is computed but never used in the subsequent code. If the intention is to optimize the guard condition in load_kv_tile (as suggested in the past review comment), this value should be passed to the lambda. Otherwise, this line should be removed.
Apply this diff if the variable is not needed:
- int valid_last_kv_tile_size = std::min<int>(kv_len - kv_tile_idx * CTA_KV, CTA_KV);
Or, if you want to use it for optimized bounds checking, update the load_kv_tile signature to accept it:
- auto load_kv_tile = [&](DTypeKV* base_ptr, int64_t stride_n, int64_t page_stride, auto& tXsX,
-                         int tile_idx, int pipe_idx, bool use_predicate) {
+ auto load_kv_tile = [&](DTypeKV* base_ptr, int64_t stride_n, int64_t page_stride, auto& tXsX,
+                         int tile_idx, int pipe_idx, int valid_tile_size) {
And update the guard condition accordingly.
🤖 Prompt for AI Agents
In include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh
around line 268, the local variable valid_last_kv_tile_size is computed but
never used; either remove this line if not needed, or update load_kv_tile to
accept an extra parameter for the valid tile size and pass
valid_last_kv_tile_size into the lambda so it can replace the existing generic
guard logic. If choosing the second option, change the load_kv_tile signature to
take an int valid_tile_size, update all call sites, and use that value inside
the lambda for optimized bounds checking; otherwise simply delete the unused
variable declaration.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh (1)
318-425: Critical: V loads reuse K's prefetch offsets with potentially different strides.
The V loading at lines 345, 363, and 390 reuses prefetch offsets computed for K (with k_stride_n and k_page_stride), but V should use v_stride_n and v_page_stride. This is evident from line 411, which explicitly prefetches V with v_stride_n and v_page_stride. If K and V have different strides or page strides, V will be loaded from incorrect addresses, causing data corruption.
The API explicitly provides separate stride parameters for K and V (Arguments and Params structs), suggesting they can differ. Either:
- Add prefetch calls for V before each V load (lines 345, 363, 390) using v_stride_n and v_page_stride, OR
- Document and assert that k_stride_n == v_stride_n and k_page_stride == v_page_stride must hold
Apply this pattern to fix the V loads:
  if (kv_tile_idx == swa_begin_kv_tile_idx) {
-   // first tile is the last tile, reuse kv_tile_idx prefetch for V
+   // first tile is the last tile, prefetch for V
+   prefetch_kv_offset(kv_tile_idx, v_stride_n, v_page_stride, true);
    pipeline_v.producer_acquire(smem_pipe_write);
    load_kv_with_prefetch(v_base_ptr, tVsV, kv_tile_idx, smem_pipe_write.index(), true);
  } else {
    // load second last k-tile and last v-tile
    // Prefetch for next K tile (kv_tile_idx - 1)
    prefetch_kv_offset(kv_tile_idx - 1, k_stride_n, k_page_stride, false);
-   // Load V using prefetch from last K load (kv_tile_idx)
+   // Prefetch and load V for kv_tile_idx
+   prefetch_kv_offset(kv_tile_idx, v_stride_n, v_page_stride, true);
    pipeline_v.producer_acquire(smem_pipe_write);
    load_kv_with_prefetch(v_base_ptr, tVsV, kv_tile_idx, smem_pipe_write.index(), true);

  for (; kv_tile_idx > swa_begin_kv_tile_idx; --kv_tile_idx) {
    // Prefetch for next K tile
    prefetch_kv_offset(kv_tile_idx - 1, k_stride_n, k_page_stride, false);
-   // Load V using prefetch from previous K prefetch
+   // Prefetch and load V for kv_tile_idx
+   prefetch_kv_offset(kv_tile_idx, v_stride_n, v_page_stride, false);
    pipeline_v.producer_acquire(smem_pipe_write);
    load_kv_with_prefetch(v_base_ptr, tVsV, kv_tile_idx, smem_pipe_write.index(), false);
♻️ Duplicate comments (1)
include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh (1)
316-316: Remove the unused variable.
As noted in the previous review, valid_last_kv_tile_size is computed but never used.
🧹 Nitpick comments (4)
benchmarks/bench_hopper_fp8_attention.py (2)
216-216: Document or validate page_size divisibility assumption.
Line 216 assumes seq_len is perfectly divisible by page_size. While the current test cases satisfy this (seq_len ∈ {1024, 2048, 4096, 8192} with page_size=16), the function might be called with other parameters in the future.
Consider adding a validation check:
+ assert seq_len % page_size == 0, f"seq_len ({seq_len}) must be divisible by page_size ({page_size})"
  num_pages = batch_size * seq_len // page_size
250-251: Consider making workspace buffer size configurable.
The 256MB workspace buffer is hardcoded for both FP16 and FP8 wrappers. While sufficient for current benchmark sizes, this might be inadequate for larger workloads or future test expansions.
Consider either:
- Making workspace size a parameter with a reasonable default
- Adding a comment documenting the size assumption
- Having the wrappers handle workspace allocation internally if supported
This is a minor point since the current sizes work for the benchmarks being run.
Also applies to: 268-269
include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh (2)
331-344: Paged KV mainloop param wiring looks consistent
The new argument list (k/v strides, page stride, kv_indices, page_size) lines up with a paged/sparse K/V mainloop and matches the scheduler/block_coord usage in this kernel. From this file's perspective the wiring looks correct; no blocking issues.
If Params::page_size is not already a 32-bit type, consider documenting or static-asserting the expected range to make the uint32_t cast here self-evident to future readers.
477-550: Ragged KV kernel-traits dispatch wiring looks correct; stale comment
The ragged-KV kernel-traits dispatch correctly switches to FP8CollectiveMainloop and reuses the BatchPrefill schedulers/arguments in the same way as the paged path, with Q/K/V layouts built via get_gmem_layout, so the host→device params plumbing looks coherent.
The comment on Line 499 saying "nnz was useless here, we can just pass 0" now contradicts the actual params.nnz_kv argument; consider updating or removing this note to avoid confusion about whether the first dimension of the K/V layout is semantically meaningful.
📜 Review details
📒 Files selected for processing (3)
- benchmarks/bench_hopper_fp8_attention.py (4 hunks)
- include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh (7 hunks)
- include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
benchmarks/bench_hopper_fp8_attention.py (3)
flashinfer/testing/utils.py (2)
bench_gpu_time (985-1046), attention_tflops_per_sec_with_actual_seq_lens (421-454)
benchmarks/bench_block_sparse_attention.py (1)
flops (125-134)
benchmarks/bench_hopper_attention.py (3)
flops (46-55), flops (107-116), flops (187-196)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (16)
benchmarks/bench_hopper_fp8_attention.py (7)
27-38: LGTM: Correct per-head symmetric quantization implementation.
The quantization logic correctly handles both FP8 formats with appropriate ranges, computes per-head scales by taking max over dimensions (0, 2), and includes defensive clamping to prevent division by zero.
41-108: LGTM: Well-structured FP8 single prefill benchmark.
The benchmark correctly creates FP16 baseline tensors, quantizes them to FP8 with per-head scales, measures both paths using median GPU time, and reports meaningful performance metrics with speedup calculations.
111-201: LGTM: Correct batch ragged prefill benchmark implementation.
The ragged batch benchmark properly constructs indptr arrays for batch boundaries, configures wrappers with appropriate data types, and correctly passes quantization scales to the FP8 execution path.
233-238: LGTM: Correct paged KV quantization strategy.
Flattening the paged KV cache for quantization and then reshaping back is the right approach to maintain per-head quantization semantics across all pages while preserving the paged memory layout.
240-247: LGTM: Correct indptr and page table setup.
The indptr arrays and page indices are correctly constructed (see the sketch after this list):
- qo_indptr marks query batch boundaries (every seq_len tokens)
- kv_indptr marks page batch boundaries (every seq_len // page_size pages)
- kv_indices provides sequential page mapping
- last_page_len assumes full pages, which is appropriate for uniform benchmark workloads
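A minimal sketch of that setup with illustrative values; it assumes every request has seq_len tokens and that seq_len is divisible by page_size, as in the benchmark.

```python
import torch

batch_size, seq_len, page_size = 4, 1024, 16
pages_per_req = seq_len // page_size

# Query batch boundaries in tokens; page batch boundaries in pages.
qo_indptr = torch.arange(0, (batch_size + 1) * seq_len, seq_len, dtype=torch.int32)
kv_indptr = torch.arange(0, (batch_size + 1) * pages_per_req, pages_per_req, dtype=torch.int32)
# Sequential page mapping and full last pages (uniform workload).
kv_indices = torch.arange(batch_size * pages_per_req, dtype=torch.int32)
last_page_len = torch.full((batch_size,), page_size, dtype=torch.int32)
```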
330-336: Clarify status of skipped single prefill benchmarks.
The single prefill benchmarks are commented out due to "compilation issues." Given the PR objectives mention fixing a failing Hopper unittest, is this related?
Please clarify:
- Are these compilation issues expected to be resolved in this PR or a follow-up?
- Should this be tracked with a TODO or issue reference?
- Is this related to the unittest fixes mentioned in the PR description?
342-356: LGTM: Comprehensive benchmark coverage.
The test configurations provide good coverage across different:
- Head dimensions (128, 256)
- Batch sizes (16-128)
- Sequence lengths (1024-8192)
- Both ragged and paged KV cache layouts
The parameter combinations maintain roughly constant total token counts, which is sensible for comparing performance across configurations.
include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh (7)
102-108: LGTM: Page-based KV cache parameters added.
The addition of separate stride and page_stride parameters for K and V tensors, along with page_size, correctly supports the refactored page-based KV loading scheme.
118-124: LGTM: Efficient fastdiv used for page_size.
Using uint_fastdiv for page_size enables efficient divmod operations in the kernel hot path.
134-137: LGTM: Parameter forwarding is correct.
All new page-based parameters are correctly forwarded from Arguments to Params.
212-220: LGTM: Manual K/V loading setup is complete.
All required parameters for page-based K/V loading are correctly extracted and prepared.
232-259: LGTM: Well-documented thread organization for FA3-style prefetch.
The rolling buffer prefetch scheme and detailed thread organization comments are helpful for understanding this complex optimization. The NUM_KV_PER_ITER calculations appear correct.
260-280: LGTM: Page-based offset prefetch is correctly implemented.
The divmod-based page addressing and rolling buffer management are correctly implemented. The offset computation properly combines page-level and entry-level addressing.
282-314: LGTM: Shuffle-based offset loading is correctly implemented.
The shuffle-based offset sharing and cp_async_zfill with guard correctly implement the FA3-style optimized loading pattern.
include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh (2)
462-476: CTA_KV=64 for HEAD_DIM=256 paged path seems reasonable; please benchmark
Reducing CTA_KV from 128→64 for the sparse paged path (with the accompanying comment about 64×64 FP8 transpose minimum) is a plausible trade-off to cut page-table lookups; launch shape and error handling remain consistent with other HEAD_DIM branches.
Please sanity-check perf/occupancy for HEAD_DIM=256 on Hopper (especially long-seq FA3 workloads) to ensure this smaller CTA_KV doesn't introduce regressions compared to the previous configuration.
552-592: New BatchFP8PrefillWithRaggedKVCacheDispatched entrypoint matches existing patterns
This wrapper mirrors the single-batch FP8 dispatch: HEAD_DIM specializations, USE_TMA_LOAD_KV=true for ragged K/V, and the same error-reporting pattern as the paged variant. The trait choices (CTA_Q/CTA_KV/NUM_STAGES) are consistent with the non-ragged FP8 paths.
Once the ragged-KV tests are in place, it'd be good to run them for all HEAD_DIM (64/128/256) with large nnz_qo/nnz_kv configurations comparable to issue #1647 to confirm this new batch entrypoint behaves as expected on Hopper.
/bot run
[CANCELING] Pipeline #39439259: canceled
enable_pdl = device_support_pdl(q.device)
k_cache, v_cache = _unpack_paged_kv_cache(paged_kv_cache, self._kv_layout)
_check_cached_qkv_data_type(
    q, k_cache, self._cached_q_data_type, self._cached_kv_data_type
We should add docs for q_scale
fp8_scale_q = None
fp8_scale_k = None
fp8_scale_v = None
if is_float8(q) and len(args) >= 3:
I am not a fan of making these implicit positional args. Instead, I think we should document it clearly and make them explicit.
For example, we can make k_scale: Optional[Union[torch.Tensor, float]] or have separate keyword args like k_scale_device. Also, we should document clearly that these scale tensors should be per-head. Otherwise, illegal mem access will occur.
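A hedged sketch of what that explicit signature could look like (illustrative only, not the current flashinfer API; parameter names and defaults are assumptions):

```python
from typing import Optional, Union
import torch

# Illustrative wrapper-method signature, shown as a free function for brevity.
def run(
    q: torch.Tensor,
    paged_kv_cache: tuple,
    *,
    q_scale: Optional[Union[torch.Tensor, float]] = None,  # per-head tensor or host scalar
    k_scale: Optional[Union[torch.Tensor, float]] = None,
    v_scale: Optional[Union[torch.Tensor, float]] = None,
) -> torch.Tensor:
    """Scale tensors, if given, are expected to be per-head (shape [num_heads]);
    anything else risks illegal memory access in the kernel."""
    ...
```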
None,  # scale_q, not supported yet
None,  # scale_k
None,  # scale_v
fp8_scale_q,
In the FA3 FP8 kernels, I suggest that we change scale_q/scale_k/scale_v to maybe_scale_q/maybe_scale_k/maybe_scale_v
See a8d9e6a as an example about how to do this.
    o_data_type=o_dtype,
    causal=causal,
)
o_fp8 = wrapper_fp8.run(q_fp8, k_fp8, v_fp8, s_q, s_k, s_v)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should have tests that cover both types of scaling factors (a sketch follows after this list):
- host-side per-tensor scales
- device-side per-head scale tensors
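A hedged sketch of what such a test could pass in each case. The names wrapper_fp8, q_fp8, k_fp8, and v_fp8 are assumed to come from the surrounding test's setup, and per-tensor scales are shown broadcast to per-head tensors before the call, mirroring what test_batch_prefill_paged_scale_types reportedly does.

```python
import torch

num_qo_heads, num_kv_heads = 32, 8

# Case 1: host-side per-tensor scales, broadcast to per-head form before the call.
s_q = torch.full((num_qo_heads,), 0.02, device="cuda", dtype=torch.float32)
s_k = torch.full((num_kv_heads,), 0.03, device="cuda", dtype=torch.float32)
s_v = torch.full((num_kv_heads,), 0.05, device="cuda", dtype=torch.float32)
o_per_tensor = wrapper_fp8.run(q_fp8, k_fp8, v_fp8, s_q, s_k, s_v)

# Case 2: device-side per-head scale tensors that genuinely vary per head.
s_q_h = 0.01 + 0.04 * torch.rand(num_qo_heads, device="cuda", dtype=torch.float32)
s_k_h = 0.01 + 0.04 * torch.rand(num_kv_heads, device="cuda", dtype=torch.float32)
s_v_h = 0.01 + 0.04 * torch.rand(num_kv_heads, device="cuda", dtype=torch.float32)
o_per_head = wrapper_fp8.run(q_fp8, k_fp8, v_fp8, s_q_h, s_k_h, s_v_h)
```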
Actionable comments posted: 1
♻️ Duplicate comments (2)
include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh (2)
227-234: Verify head offset calculation is correct for all KV cache layouts.
A past reviewer flagged the base pointer calculation as an accuracy bug. The current implementation uses kv_head_idx * k_stride_h for the head offset. Please confirm that k_stride_h correctly represents the stride between heads in elements for all supported KV cache layouts (NHD vs HND), and that this matches the stride values passed from the Python/dispatch layer.
Run the following script to verify how stride_h is computed and passed to this kernel:
#!/bin/bash
# Search for where k_stride_h is computed and passed to FP8SparseCollectiveMainloop
rg -n -C3 'k_stride_h|v_stride_h' --type=cpp --type=cu
316-316: Remove unused variable.
The variable valid_last_kv_tile_size is computed but never used in the subsequent code. Remove this line to eliminate dead code. Apply this diff:
- int valid_last_kv_tile_size = std::min<int>(kv_len - kv_tile_idx * CTA_KV, CTA_KV);
🧹 Nitpick comments (4)
flashinfer/prefill.py (2)
2100-2108: Inconsistent _cached_o_data_type access pattern.
Using getattr(self, "_cached_o_data_type", None) suggests _cached_o_data_type may not always be set, but plan() always sets it (lines 1711-1713, 2645-2647). Consider using self._cached_o_data_type directly for consistency with the rest of the codebase, or document why the defensive access is needed.
- # Use cached output data type if available (for FP8 attention with FP16 output)
- out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype
+ # Use cached output data type (for FP8 attention with FP16 output)
+ out_dtype = self._cached_o_data_type
  out = torch.empty(
      q.shape[:-1] + v_cache.shape[-1:], dtype=out_dtype, device=q.device
  )
  else:
- out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype
+ out_dtype = self._cached_o_data_type
2177-2196: FP8 scale tensor extraction from *args is fragile.
The code assumes args[0:3] are scale tensors when is_float8(q) is true, but there's no validation that these are actually tensor types or have the expected shapes. A mismatch could cause silent incorrect behavior or cryptic errors.
Consider adding explicit keyword arguments for FP8 scales in the paged run() method signature (similar to how q_scale, k_scale, v_scale are already present) instead of extracting from *args. This would make the API explicit and type-safe.
csrc/batch_prefill_fp8_sm90.cu (1)
103-104: Inconsistent device guard usage.
BatchPrefillWithRaggedKVCacheSM90Run uses cudaSetDevice() directly (line 103), while BatchPrefillWithPagedKVCacheSM90Run uses ffi::CUDADeviceGuard (line 206). The RAII pattern with CUDADeviceGuard is safer as it restores the device on scope exit.
- cudaSetDevice(float_workspace_buffer.device().device_id);
+ ffi::CUDADeviceGuard device_guard(float_workspace_buffer.device().device_id);
tests/attention/test_hopper_fp8_attention.py (1)
347-388: Remove unused head_dim parameter.
The head_dim parameter is not used in the function body since the dimension information is already contained in the shape tuple. Consider removing this parameter to simplify the function signature. Apply this diff to the function signature and update all call sites:
  def create_per_head_varying_kv(
      shape: Tuple[int, ...],
      num_heads: int,
-     head_dim: int,
      dtype: torch.dtype,
      device: str,
  ) -> torch.Tensor:
Then update call sites (lines 453, 460, 595, 602, 728, 735) by removing the head_dim argument.
📜 Review details
📒 Files selected for processing (10)
- csrc/batch_prefill_fp8_sm90.cu (3 hunks)
- csrc/batch_prefill_sm90.cu (1 hunks)
- csrc/page.cu (0 hunks)
- flashinfer/page.py (0 hunks)
- flashinfer/prefill.py (21 hunks)
- flashinfer/sparse.py (2 hunks)
- include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh (7 hunks)
- include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh (2 hunks)
- tests/attention/test_hopper.py (3 hunks)
- tests/attention/test_hopper_fp8_attention.py (3 hunks)
💤 Files with no reviewable changes (2)
- csrc/page.cu
- flashinfer/page.py
🧰 Additional context used
🧬 Code graph analysis (2)
flashinfer/prefill.py (2)
flashinfer/page.py (1)
get_seq_lens (178-201)
flashinfer/utils.py (2)
canonicalize_torch_dtype (241-249), is_float8 (158-159)
tests/attention/test_hopper_fp8_attention.py (2)
flashinfer/utils.py (1)
is_sm90a_supported (526-528)
flashinfer/prefill.py (9)
plan (1549-1937), plan (2507-2804), run (1968-1980), run (1983-1995), run (1998-2222), run (2834-2844), run (2847-2857), run (2860-3017), BatchPrefillWithPagedKVCacheWrapper (1260-2254)
🪛 Ruff (0.14.7)
tests/attention/test_hopper_fp8_attention.py
19-19: Avoid specifying long messages outside the exception class
(TRY003)
350-350: Unused function argument: head_dim
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (15)
flashinfer/prefill.py (2)
418-423: FP8 scale parameters added correctly.
The optional scale tensors for FP8 quantization are properly added with None defaults, maintaining backward compatibility.
2960-2969: Approve FA3 FP8 bypass of conversion.
Correctly skips the FP8→FP16 conversion when the FA3 backend handles FP8 natively.
tests/attention/test_hopper.py (3)
210-210: Good addition of page_size=16 test coverage.
This extends test coverage to non-unit page sizes, which is important given the page-based addressing changes in this PR.
270-270: Correct removal of padding from kv_indices.
The previous version added 256 extra indices as padding, which was unnecessary. The new version generates exactly batch_size * num_pages_per_request indices, matching the actual number of pages allocated.
437-437: Test data adjustment for multi-item scoring.
The padding distribution change (* 17 and * 5 instead of previous values) aligns with the test case's token_pos_in_items_len=97 and the actual data lengths.
csrc/batch_prefill_sm90.cu (1)
221-238: Page stride fields and validation correctly implemented.
The page stride extraction and runtime validation are well-implemented:
- Correctly extracts stride(0) for page-level stride in both NHD and HND layouts
- Runtime checks ensure K and V have matching strides, which is required for the sparse mainloop optimization
include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh (4)
337-346: Page stride parameters correctly wired to CollectiveMainloop.
The mainloop arguments now include k_page_stride, v_page_stride, kv_indices, and page_size, enabling sparse paged KV cache support.
464-477: CTA_KV optimization for HEAD_DIM=256 is well-documented.
The reduction from CTA_KV=128 to CTA_KV=64 for sparse paged loading with HEAD_DIM=256 is explained in the comment. The constraint that FP8 transpose requires minimum 64x64 blocks justifies not reducing further.
479-552: Ragged KV cache function correctly mirrors paged structure.
The new BatchFP8PrefillWithRaggedKVCacheKernelTraitsDispatched follows the same pattern as the paged version but uses FP8CollectiveMainloop instead of FP8SparseCollectiveMainloop, and uses standard gmem layouts instead of page-based addressing.
554-594: Ragged dispatch function uses TMA for all head dimensions.
The ragged path uses USE_TMA_LOAD_KV=true for all head dimensions (64, 128, 256), while the paged path uses USE_TMA_LOAD_KV=false. This is correct since ragged KV has contiguous memory that can benefit from TMA, whereas paged KV has non-contiguous pages.
csrc/batch_prefill_fp8_sm90.cu (2)
86-173: Ragged KV cache SM90 run implementation looks correct.
The function properly:
- Validates LSE tensor dimensions
- Populates RaggedParams with correct strides for both NHD and HND layouts
- Uses static assertions for head dimension and dtype consistency
- Dispatches to the appropriate kernel based on scheduler configuration
231-242: Page stride fields correctly added for FP8 paged path.
Consistent with the changes in batch_prefill_sm90.cu, the FP8 path now also stores page strides for sparse paged KV cache support.
include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh (1)
246-329: LGTM: Well-designed FA3-style prefetch optimization.
The prefetch offset calculation and load implementation with shuffle-based offset sharing is well-designed. The detailed comments explain the group organization, thread distribution, and double-buffering strategy clearly. The use of divmod for efficient page table lookup and shuffle for offset sharing across threads is appropriate for FA3.
tests/attention/test_hopper_fp8_attention.py (2)
250-811: Excellent test coverage for FP8 attention paths.
The new tests comprehensively cover FP8 quantization scenarios:
- Ragged KV cache (test_batch_prefill_ragged)
- Paged KV cache (test_batch_prefill_paged)
- GQA support (test_batch_prefill_paged_gqa) - addresses past review comment
- Per-tensor vs per-head scale types (test_batch_prefill_paged_scale_types) - addresses past review comment
The use of create_per_head_varying_kv to surface head-mapping bugs is a clever testing strategy.
12-19: FP8 min/max values are correct per NVIDIA specification.
The hard-coded values match the NVIDIA FP8 format specification: float8_e4m3fn max finite normal is ±448, and float8_e5m2 max finite normal is ±57,344. No changes needed.
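For reference, those limits can be cross-checked against PyTorch's own dtype metadata (a quick sketch, assuming a PyTorch build that exposes the FP8 dtypes):

```python
import torch

# torch.finfo reports the same max finite values the test hard-codes.
assert torch.finfo(torch.float8_e4m3fn).max == 448.0
assert torch.finfo(torch.float8_e5m2).max == 57344.0
```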
# For FP8, append scale tensors
if is_float8(q):
    run_args.extend(list(args))  # scale_q, scale_k, scale_v
FP8 scale tensors appended without validation.
When is_float8(q) is true, args is appended directly. There's no check that args actually contains exactly 3 scale tensors. If the user forgets to pass scales, this will silently pass incorrect data to the kernel.
Consider adding validation:
  # For FP8, append scale tensors
  if is_float8(q):
+     if len(args) < 3:
+         raise ValueError(
+             "FP8 attention requires scale_q, scale_k, scale_v tensors to be passed"
+         )
      run_args.extend(list(args))  # scale_q, scale_k, scale_v
🤖 Prompt for AI Agents
In flashinfer/prefill.py around lines 3011 to 3013, the code appends args when
is_float8(q) without validating that args contains exactly the three FP8 scale
tensors; add a guard that checks len(args) == 3 and that each element is a
tensor of expected dtype/shape (or at least a tensor-like object), and raise a
clear ValueError/TypeError if the check fails so the kernel never receives
missing/malformed scale arguments; update the error message to indicate which
scale(s) are missing or invalid.
/bot run
[CANCELING] Pipeline #39546502: canceled
Actionable comments posted: 0
♻️ Duplicate comments (2)
flashinfer/prefill.py (2)
1702-1704: Documentation improvement needed.
The docstring mentions "For FP8 inputs, this should typically be set to torch.float16" but bfloat16 is also a valid option. Consider updating to: "For FP8 inputs, this should typically be set to torch.float16 or torch.bfloat16."
This was also noted in a past review comment.
2222-2229: Add validation for FP8 scale tensors.
A past review flagged that when is_float8(q) is true, the code should validate that args contains exactly 3 scale tensors. Currently, if len(args) < 3, the scales remain None and are silently passed to the kernel, which may cause incorrect results or crashes. Consider adding validation as suggested:
  # Extract FP8 scale tensors from *args if q is FP8
  fp8_scale_q = None
  fp8_scale_k = None
  fp8_scale_v = None
  if is_float8(q) and len(args) >= 3:
+     if len(args) < 3:
+         raise ValueError(
+             "FP8 attention requires scale_q, scale_k, scale_v tensors to be passed"
+         )
      fp8_scale_q = args[0]
      fp8_scale_k = args[1]
      fp8_scale_v = args[2]
🧹 Nitpick comments (3)
include/flashinfer/attention/hopper/variants.cuh (1)
28-35: Consider consolidating SFINAE trait definitions.
Both this file and variant_helper.cuh define DEFINE_HAS_MEMBER traits. While this file adds traits for maybe_scale_v/q/k and scale_v/q/k_scalar, and variant_helper.cuh defines v_scale, having trait definitions split across files may lead to maintenance overhead.
Consider centralizing all SFINAE trait definitions in a single header (e.g., utils.cuh or a dedicated traits.cuh) to improve discoverability and reduce duplication risk.
tests/attention/test_hopper_fp8_attention.py (2)
347-370: Unused head_dim parameter.
The head_dim parameter is declared but never used in the function body. Consider removing it or using it for validation.
  def create_per_head_varying_kv(
      shape: Tuple[int, ...],
      num_heads: int,
-     head_dim: int,
      dtype: torch.dtype,
      device: str,
  ) -> torch.Tensor:
Note: This would require updating all call sites (lines 455-468, 462-468, 598-604, 605-611, 731-737, 738-744).
520-520: Remove debug print statement.
This print statement appears to be leftover debugging code that should be removed for cleaner test output.
  o_fp8 = wrapper_fp8.run(q_fp8, (paged_k_fp8, paged_v_fp8), s_q, s_k, s_v)
- print(o_ref, o_fp8)
📜 Review details
📒 Files selected for processing (7)
- csrc/batch_prefill_fp8_ragged_sm90_kernel_inst.jinja (1 hunks)
- flashinfer/jit/attention/modules.py (2 hunks)
- flashinfer/prefill.py (29 hunks)
- include/flashinfer/attention/hopper/mainloop_mma.cuh (1 hunks)
- include/flashinfer/attention/hopper/variants.cuh (3 hunks)
- include/flashinfer/attention/variant_helper.cuh (2 hunks)
- tests/attention/test_hopper_fp8_attention.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
flashinfer/prefill.py (2)
flashinfer/page.py (1)
get_seq_lens (178-201)
flashinfer/utils.py (2)
canonicalize_torch_dtype (241-249), is_float8 (158-159)
tests/attention/test_hopper_fp8_attention.py (1)
flashinfer/utils.py (1)
is_sm90a_supported(526-528)
🪛 Ruff (0.14.7)
tests/attention/test_hopper_fp8_attention.py
19-19: Avoid specifying long messages outside the exception class
(TRY003)
350-350: Unused function argument: head_dim
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (17)
include/flashinfer/attention/variant_helper.cuh (2)
23-28: LGTM! Clean SFINAE-based trait detection for optional v_scale. The pattern using `DEFINE_HAS_MEMBER` and the `get_v_scale` helper enables backward-compatible optional scaling.
85-99: LGTM! Output scaling logic is correct. The `get_v_scale` helper properly defaults to `1.0f` when `v_scale` is not present in `Params`, maintaining backward compatibility while enabling FP8 scaling paths.

flashinfer/jit/attention/modules.py (2)
531-554: LGTM! Parameter lists correctly expanded for FP8 scale handling. The non-FP8 path adds `maybe_scale_v`/`scale_v_scalar`, while the FP8 path adds all three scale pairs (`maybe_scale_q/k/v` and `scale_q/k/v_scalar`). The tensor/scalar counts and dtypes are correctly aligned.
1016-1048: LGTM! Batch prefill module parameters mirror single prefill updates. The FA3 batch prefill path correctly mirrors the single prefill parameter expansions for both non-FP8 and FP8 paths.
csrc/batch_prefill_fp8_ragged_sm90_kernel_inst.jinja (1)
1-15: LGTM! Template instantiation structure is correct. The template properly generates explicit instantiations for both `SAME_SCHEDULER_FOR_ALL_HEADS` variants. The namespace closing syntax has been corrected per the previous review.

include/flashinfer/attention/hopper/variants.cuh (3)
37-40: LGTM! Null-safe scale accessor. The `get_scale` helper correctly handles the case where `tensor_ptr` is `nullptr` by falling back to `scalar_val`.
72-91: LGTM! StandardAttention correctly initializes scale_pv. The structured binding extracts block coordinates, and `scale_pv` is initialized using the `get_v_scale` helper with proper fallback behavior.
117-131: LGTM! StandardFP8Attention scale computation is correct. The FP8 attention variant properly computes:

- `p_scale` from the FP8 type's max value
- `scale_pv = v_scale / p_scale` for PV dequantization
- `sm_scale_with_qk_log2` incorporating q_scale, k_scale, and sm_scale

This correctly handles the FP8 quantization/dequantization flow.
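To make that arithmetic concrete, here is a small Python sketch of the same scale bookkeeping; the e4m3 max of 448 and the log2(e) folding are assumptions made for illustration, not values read from the kernel source.

```python
import math

import torch


def fp8_attention_scales(q_scale: float, k_scale: float, v_scale: float,
                         sm_scale: float) -> tuple[float, float]:
    """Illustrative only: mirrors the scale bookkeeping described above."""
    # p_scale comes from the FP8 type's max representable value (e4m3 assumed).
    p_scale = float(torch.finfo(torch.float8_e4m3fn).max)  # 448.0
    # PV dequantization factor, applied when the online softmax is finalized.
    scale_pv = v_scale / p_scale
    # Q/K dequantization folded into the (log2-domain) softmax scale.
    sm_scale_with_qk_log2 = q_scale * k_scale * sm_scale * math.log2(math.e)
    return scale_pv, sm_scale_with_qk_log2
```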
include/flashinfer/attention/hopper/mainloop_mma.cuh (1)
319-319: Thread scale_pv into finalization for FP8 dequantization. This change passes `variant.scale_pv` to the `finalize` call, enabling fused PV dequantization as part of online softmax finalization. Ensure the `AttentionUpdater::finalize` method signature accepts this parameter.

tests/attention/test_hopper_fp8_attention.py (3)
250-344: LGTM! The ragged KV cache FP8 test is well-structured with clear setup, reference computation, and MSE validation. The variable length sequences provide good coverage.
528-545: LGTM! This test addresses the previously requested GQA coverage. The parameterization with different head ratios (32:8, 16:4, 8:2) provides good coverage for head mapping logic verification.
668-691: LGTM! This test addresses the previously requested coverage for both per-tensor and per-head scale types. The test correctly broadcasts per-tensor scales to per-head format before passing to the kernel.
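As a rough illustration of that broadcast (head counts and scale values here are made up; only the expand-to-per-head step is the point):

```python
import torch

num_qo_heads, num_kv_heads = 32, 8

# A per-tensor scale is a single value; the per-head code path expects one
# entry per head, so expand it before handing it to the wrapper.
s_q_per_tensor = torch.tensor(0.5)
s_q_per_head = s_q_per_tensor.expand(num_qo_heads).contiguous()  # shape [32]
s_k_per_head = torch.full((num_kv_heads,), 0.25)                 # shape [8]
s_v_per_head = torch.full((num_kv_heads,), 0.125)                # shape [8]
```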
flashinfer/prefill.py (5)
69-86: LGTM! The `_split_scale_param` helper cleanly handles the three cases (None, tensor, scalar) and provides a consistent interface for FP8 scale handling throughout the module.
310-349: LGTM! The FP8 scale handling correctly uses `_split_scale_param` to decompose scales into tensor and scalar components, supporting both per-head tensor scales and scalar scales.
446-451: Consider validating FP8 scale tensor consistency. The FP8 detection relies solely on `scale_q is not None`. If a user passes `scale_q` but omits `scale_k` or `scale_v`, `_split_scale_param` will return `(None, 1.0)`, which may silently produce incorrect results instead of raising an error. Consider adding validation:

```diff
 # Check if FP8 by presence of scale tensors
 is_fp8 = scale_q is not None
+if is_fp8 and (scale_k is None or scale_v is None):
+    raise ValueError(
+        "FP8 attention requires all scale tensors (scale_q, scale_k, scale_v)"
+    )
```
2144-2154: LGTM! The output dtype handling correctly uses `_cached_o_data_type` when available, with a safe fallback to `q.dtype` for backward compatibility.
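A minimal sketch of that fallback pattern (the free function and the `head_dim_vo` argument are placeholders for illustration, not the actual wrapper code):

```python
import torch


def allocate_out(wrapper, q: torch.Tensor, head_dim_vo: int) -> torch.Tensor:
    # Honor the dtype chosen at plan time; fall back to q.dtype for callers
    # that planned before o_data_type existed.
    o_dtype = getattr(wrapper, "_cached_o_data_type", q.dtype)
    return torch.empty((*q.shape[:-1], head_dim_vo),
                       dtype=o_dtype, device=q.device)
```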
3056-3058: Add validation for FP8 scale tensors in ragged run. Similar to the paged run path, this code appends `args` as FP8 scales without validating that the required 3 scale tensors are present. If `is_float8(q)` but `args` is empty or has fewer than 3 elements, the kernel may receive incorrect arguments. Consider adding validation:

```diff
 # For FP8, append scale tensors
 if is_float8(q):
+    if len(args) < 3:
+        raise ValueError(
+            "FP8 attention requires scale_q, scale_k, scale_v tensors to be passed"
+        )
     run_args.extend(list(args))  # scale_q, scale_k, scale_v
```
/bot run
[CANCELING] Pipeline #39670861: canceled
Actionable comments posted: 0
🧹 Nitpick comments (1)
include/flashinfer/attention/hopper/variants.cuh (1)
21-82: Scale helper and SFINAE layer is sound; be aware of the "all-or-nothing" detection. The `DEFINE_HAS_MEMBER` traits plus `get_scale`/`get_{v,q,k}_scale` helpers provide a clean way to support optional per-head scales while defaulting to 1.0f when scale fields are absent, which keeps non-FP8 / legacy paths unchanged. Note that each getter requires both `maybe_scale_*` and `scale_*_scalar` to exist; if a future `AdditionalParams` defines only one of these, it will silently fall back to 1.0f. That's reasonable, but worth keeping in mind when extending `AdditionalParams` types.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- `flashinfer/jit/attention/variants.py` (1 hunks)
- `include/flashinfer/attention/hopper/mainloop_mma.cuh` (2 hunks)
- `include/flashinfer/attention/hopper/variants.cuh` (3 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (3)
flashinfer/jit/attention/variants.py (1)
145-159: FA3 AttentionSink: scale_pv initialization looks consistent. Destructuring `block_coord` to expose `kv_head_idx` and initializing `scale_pv` via `get_v_scale(params.additional_params, kv_head_idx)` matches the new Hopper variant pattern and keeps behavior a no-op when v-scales are absent (fallback 1.0f). Within this snippet `scale_pv` is only stored; assuming downstream code reads it (e.g., via a helper similar to `get_variant_scale_pv`), this is a clean extension with no behavior change for setups without per-head v-scales.

include/flashinfer/attention/hopper/mainloop_mma.cuh (1)
15-16: Plumbing pv-scale into finalize is correct; ensure all updaters support the new signature. Including `"variants.cuh"` here and calling `attention_updater.finalize(tSrS, get_variant_scale_pv(variant))` correctly threads variant-specific `scale_pv` into the online softmax path while preserving old behavior for variants without a `scale_pv` field (fallback 1.0f). Please double-check that every `AttentionUpdater` used with this `mma_f16` now implements a compatible `finalize(Tensor0&, float)` (or has a defaulted second parameter) so no template instantiation breaks for non-sink / non-FP8 paths.

Also applies to: 321-321
include/flashinfer/attention/hopper/variants.cuh (1)
83-165: Variant scale_pv wiring aligns with FP8 and non-FP8 semantics; confirm no remaining ODequantize users. The additions of `scale_pv` to `StandardAttention` and `LogitsSoftCap` (initialized via `get_v_scale`) and to `StandardFP8Attention` (as `v_scale / p_scale`) line up with the new `get_variant_scale_pv(variant)` + `finalize(..., pv_scale)` path and keep behavior unchanged when no scales are provided (helpers return 1.0f). For `StandardFP8Attention`, using `q_scale * k_scale` to fold Q/K dequantization into `sm_scale_with_qk_log2` and moving PV dequantization into `finalize` via `scale_pv` matches the P-quantization setup.

Given `ODequantize` is now intentionally a no-op with a comment about fusing into finalize, please ensure there are no remaining call sites that still expect it to perform the output dequantization directly.
/bot run
Actionable comments posted: 1
♻️ Duplicate comments (2)
flashinfer/prefill.py (2)
2145-2154: Run path now respects cached output dtype; ensure FP8 scale args are clearly specified and `q_scale` documented.

- Using `_cached_o_data_type` (falling back to `q.dtype` via `getattr`) for both allocation and dtype checks of `out` makes the run phase honor the `o_data_type` chosen at plan time, which is important for FP8 with FP16 output.
- For non-JIT FA3 FP8 paged kernels, the convention `if is_float8(q) and len(args) >= 3: fp8_scale_q, fp8_scale_k, fp8_scale_v = args[:3]` is reasonable, but it's implicit. Consider documenting in the wrapper docstring that, in FP8 mode, the first three extra positional args are interpreted as `scale_q`, `scale_k`, `scale_v` and any additional args are ignored. Making them explicit keyword arguments on this API surface would be even clearer long-term.

Also, `q_scale` is now used to fold into `sm_scale` here but is still not described in the `run` docstring, which has been called out previously.

Also applies to: 2222-2242, 2255-2256
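For reference, a caller-side fragment following that positional convention might look like the following; the wrapper/tensor setup is omitted and the names are placeholders, so treat this as a sketch rather than the documented API.

```python
import torch

num_qo_heads, num_kv_heads = 32, 8

# Per-head dequantization scales for the FP8 tensors (values are dummies).
scale_q = torch.ones(num_qo_heads, dtype=torch.float32, device="cuda")
scale_k = torch.ones(num_kv_heads, dtype=torch.float32, device="cuda")
scale_v = torch.ones(num_kv_heads, dtype=torch.float32, device="cuda")

# `wrapper`, `q_fp8`, `paged_k_fp8`, `paged_v_fp8` come from the usual
# plan/setup steps (omitted).  In FP8 mode the first three extra positional
# args are read as (scale_q, scale_k, scale_v); anything after is ignored.
out = wrapper.run(q_fp8, (paged_k_fp8, paged_v_fp8), scale_q, scale_k, scale_v)
```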
2977-2980: Ragged `run`: output dtype wiring is good; still missing validation for FP8 scale tensors in `*args`.

- Using `_cached_o_data_type` for ragged outputs (both allocation and checks) is aligned with the new `plan` API and is necessary for FP8 inputs with FP16 outputs.
- The updated cast guard `if is_float8(q) and self._backend != "fa3": ...` correctly skips the FP8→FP16 conversion only for FA3, which now has native FP8 support.

The remaining issue is the FP8 scale handling at the bottom:

```python
# For FP8, append scale tensors
if is_float8(q):
    run_args.extend(list(args))  # scale_q, scale_k, scale_v
```

As previously noted, there is still no validation that:

- `args` contains exactly the three expected scale tensors, and
- each element has a reasonable type/shape for a per-head scale.

A malformed or incomplete `args` tuple will silently propagate garbage into the kernel arguments. Consider tightening this by, for the non-JIT path:

```diff
-# For FP8, append scale tensors
-if is_float8(q):
-    run_args.extend(list(args))  # scale_q, scale_k, scale_v
+# For FP8, append (scale_q, scale_k, scale_v) explicitly
+if is_float8(q):
+    if len(args) != 3:
+        raise ValueError(
+            f"FP8 ragged prefill expects 3 scale tensors "
+            f"(scale_q, scale_k, scale_v), got {len(args)}"
+        )
+    scale_q, scale_k, scale_v = args
+    for name, scale in (("scale_q", scale_q), ("scale_k", scale_k), ("scale_v", scale_v)):
+        if not isinstance(scale, torch.Tensor):
+            raise TypeError(f"{name} must be a torch.Tensor, got {type(scale)}")
+    run_args.extend([scale_q, scale_k, scale_v])
```

This keeps the interface the same for correct callers but fails fast and clearly for misconfigured FP8 runs.
Also applies to: 2983-2987, 3005-3015, 3056-3058
🧹 Nitpick comments (3)
flashinfer/prefill.py (3)
69-86: Helper `_split_scale_param` behavior is sound; consider tightening typing/validation. The tensor/scalar split logic looks correct and is reused consistently for FA3 FP16/FP8, but the helper currently accepts any non-tensor object and blindly calls `float(scale)`. You could make failures clearer and improve static checking by:

- Adding a type hint, e.g. `scale: Optional[Union[torch.Tensor, float, int]] -> Tuple[Optional[torch.Tensor], float]`.
- Raising a `TypeError` with a clear message if `scale` is of an unsupported type instead of relying on `float(scale)` to fail.
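A minimal sketch of what such a tightened helper could look like, assuming the `(tensor, scalar)` contract described above; this is not the shipped implementation.

```python
from typing import Optional, Tuple, Union

import torch


def split_scale_param(
    scale: Optional[Union[torch.Tensor, float, int]],
) -> Tuple[Optional[torch.Tensor], float]:
    # None -> no per-head tensor, neutral scalar factor.
    if scale is None:
        return None, 1.0
    # Tensor -> per-head scales go to the kernel; scalar factor stays neutral.
    if isinstance(scale, torch.Tensor):
        return scale, 1.0
    # Plain numbers -> scalar factor only.
    if isinstance(scale, (int, float)):
        return None, float(scale)
    raise TypeError(f"Unsupported scale type: {type(scale).__name__}")
```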
446-452: FA3 ragged FP8 path: clarify feature support and tighten FP8 gating. The new FA3 FP8 branch correctly reuses `_split_scale_param` and routes scale tensors/scalars separately, but two details are worth tightening:

**FP8 mode detection.** `is_fp8 = scale_q is not None` means merely passing a non-None `scale_q` flips to the FP8 kernel, even if `q` is not actually float8. It would be safer to gate on `is_float8(q)` (or both conditions) to avoid accidentally invoking the FP8 variant with FP16 inputs.

**Dropped mask / multi-item args in FP8 branch.** In the FP8 case you no longer thread `maybe_custom_mask`, `maybe_mask_indptr`, `maybe_alibi_slopes`, `maybe_prefix_len_ptr`, `maybe_token_pos_in_items_ptr`, `maybe_max_item_len_ptr`, `logits_soft_cap`, `rope_scale`, `rope_theta`, or `token_pos_in_items_len` to `ragged_run_func`. If FA3 FP8 truly doesn't support custom masks or multi-item scoring yet, it would be safer to:

- Either raise a `NotImplementedError` when those arguments are non-None in FP8 mode, or
- Document explicitly that these features are unsupported for FA3 FP8 ragged prefill and are ignored.

This avoids silent behavior differences between FP16 and FP8 runs.

Also applies to: 481-508, 510-513, 531-535
1617-1617: New `o_data_type` and `get_seq_lens` usage in paged `plan` are reasonable; watch kv_lens scope for TRTLLM.

- Introducing `o_data_type` (defaulting to `q_data_type` then canonicalized) and caching it as `_cached_o_data_type` is a good way to support FP8 inputs with FP16 outputs while keeping FP16/BF16 behavior unchanged.
- Switching the KV-length computation to `get_seq_lens(paged_kv_indptr_host, paged_kv_last_page_len_host, page_size)` and copying into `_kv_lens_buffer` is consistent with the existing formula and should help with robust sizing for large block counts. `get_module_args` now passes `o_data_type` into `get_batch_prefill_module`, so the compiled kernel sees the intended output dtype.

One small edge case: in the TRTLLM-GEN branch you rely on `kv_lens_arr_host` (for `blocks_per_seq`), which is only defined in the `max_sequence_kv is None` path. The docs say `max_sequence_kv` is required only for the cuDNN backend, but if a caller ever provides it together with a TRTLLM-GEN backend, `kv_lens_arr_host` will be undefined. Consider either:

- Guarding with an explicit error if `self._backend == "trtllm-gen" and max_sequence_kv is not None`, or
- Recomputing `kv_lens_arr_host` from `seq_lens` in that branch.

Also applies to: 1702-1705, 1709-1710, 1719-1720, 1756-1758, 1805-1814, 1889-1890, 1907-1908
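A sketch of the second option, recomputing the per-request KV lengths from the page table. The formula assumes the usual paged-KV layout in which each request owns `indptr[i+1] - indptr[i]` pages and only the last page is partially filled; it is intended to mirror what `get_seq_lens` computes under that assumption.

```python
import torch


def paged_kv_seq_lens(paged_kv_indptr: torch.Tensor,         # [batch_size + 1]
                      paged_kv_last_page_len: torch.Tensor,  # [batch_size]
                      page_size: int) -> torch.Tensor:
    # Every full page contributes page_size tokens; the last page contributes
    # only last_page_len tokens.  clamp() keeps empty requests at length 0.
    num_pages = paged_kv_indptr[1:] - paged_kv_indptr[:-1]
    return (num_pages - 1).clamp(min=0) * page_size + paged_kv_last_page_len
```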
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- `flashinfer/prefill.py` (32 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/prefill.py (2)
flashinfer/page.py (1)
- `get_seq_lens` (178-201)

flashinfer/utils.py (2)
- `canonicalize_torch_dtype` (241-249)
- `is_float8` (158-159)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (2)
flashinfer/prefill.py (2)
311-349: FA3 single-prefill scale splitting looks consistent with new kernel API. For FA3, splitting `scale_v` (FP16) and `scale_q`/`scale_k`/`scale_v` (FP8) into `(tensor, scalar)` pairs via `_split_scale_param` and passing them as separate arguments matches the intended "per-head tensor vs scalar factor" design and keeps non-FA3 backends untouched. I don't see correctness issues here.
692-747: Paged FA3 FP16/FP8 scale handling via `_split_scale_param` looks correct. For the paged backend, splitting `scale_v` for FP16 and `scale_q`/`scale_k`/`scale_v` for FP8 into tensor/scalar components and threading them into `paged_run_func` keeps the call signature aligned with the FA3 kernels while preserving existing behavior for FA2/TRTLLM/CuDNN. I don't see functional issues in this hunk.
```python
o_data_type: Optional[Union[str, torch.dtype]] = None,
non_blocking: bool = True,
```
Pass dtype_o parameter to fmha_varlen for CUTLASS backend to honor the requested output dtype.
Currently, fmha_varlen() is called without the dtype_o parameter (line 3049), so it defaults to v.dtype. This differs from the FA2/FA3 path (line 3055), which explicitly passes dtype_o=self._cached_o_data_type. Since out is allocated with _cached_o_data_type (line 3029), there's a potential dtype mismatch when o_data_type differs from v.dtype in the CUTLASS path. Either enforce o_data_type == v.dtype for CUTLASS, or pass dtype_o=self._cached_o_data_type to fmha_varlen to ensure consistency with the FA2/FA3 path and the buffer allocation.
🤖 Prompt for AI Agents
In flashinfer/prefill.py around lines 2572-2573 (see related calls around lines
3029, 3049, 3055), the CUTLASS fmha_varlen call omits the dtype_o argument
causing a potential dtype mismatch with the preallocated out buffer; modify the
CUTLASS path call to pass dtype_o=self._cached_o_data_type to fmha_varlen so the
output dtype matches the allocated out tensor (or alternatively add an explicit
assert that self._cached_o_data_type == v.dtype before calling fmha_varlen if
you intend to enforce identical dtypes).
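An illustrative fragment of the first option, assuming `fmha_varlen` accepts the `dtype_o` keyword in the same way the FA2/FA3 path uses it (unrelated arguments are elided):

```python
# Inside the CUTLASS branch of run(); q, k, v and the varlen metadata are the
# arguments already being passed today.
out = fmha_varlen(
    q, k, v,
    # ... other varlen arguments unchanged ...
    dtype_o=self._cached_o_data_type,  # match the preallocated `out` buffer
)

# Alternative: keep the current call but fail fast on mismatched dtypes.
# if self._cached_o_data_type != v.dtype:
#     raise ValueError("CUTLASS backend requires o_data_type == v.dtype")
```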
📌 Description
This PR refactors the outdated FA3 codebase. More specifically, for page_size > 1 the page offset calculation is now performed inside the kernel, without the need for a standalone call to block_sparse_indices_to_vector_sparse_offsets, and the kv_offset calculation is optimized with prefetching and shuffling.
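A host-side Python sketch of the addressing the kernel now performs internally; the tensor names and the `[num_pages, page_size, num_kv_heads, head_dim]` layout are assumptions for illustration, not the kernel's actual data structures.

```python
import torch


def gather_kv_rows(paged_k: torch.Tensor,       # [num_pages, page_size, num_kv_heads, head_dim]
                   kv_indices: torch.Tensor,    # page ids owned by this request
                   kv_positions: torch.Tensor,  # token positions within the request
                   page_size: int) -> torch.Tensor:
    # Derive the page id and in-page offset directly from the token position,
    # so no separate block->vector-sparse offset conversion pass is needed.
    page_slot = torch.div(kv_positions, page_size, rounding_mode="floor")
    in_page = kv_positions % page_size
    return paged_k[kv_indices[page_slot], in_page]  # [num_tokens, num_kv_heads, head_dim]
```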
This PR also fixes the failing unittest on Hopper.
However, the FA3 structure in our codebase is still badly outdated and missing important features such as `IntraWGOverlap` and `RescaleOBeforeGemm`; we will follow up in a later PR.
🔍 Related Issues
This PR should fix #1647
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed pre-commit by running `pip install pre-commit` (or used your preferred method).
- Installed the hooks with `pre-commit install`.
- Ran `pre-commit run --all-files` and fixed any reported issues.
🧪 Tests
- Tests are passing (unittest, etc.).
Reviewer Notes
Summary by CodeRabbit
Bug Fixes
New Features
Refactor
Tests