refactor: update fa3 codebase [part 2] #2192
base: main
Conversation
Walkthrough

The PR introduces Hopper kernel optimizations for attention mechanisms, including FP8 transposed V loading infrastructure, L2-aware scheduling with longest-processing-time (LPT) support, and a new barrier-optimized TMA pipeline class. It adds a benchmarking script for FlashInfer/FA3 performance comparison, and modifies mainloop masking patterns, epilogue operations, and tile scheduling to support improved causal attention and FP8 quantization paths.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
Pre-merge checks and finishing touches
❌ Failed checks (1 inconclusive)
✅ Passed checks (2 passed)
Summary of Changes

Hello @yzh119, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request continues the refactoring of the FlashAttention-3 (FA3) codebase within FlashInfer, with a primary focus on enhancing the efficiency and performance of FP8 operations. It integrates advanced memory access patterns and scheduling strategies to better leverage GPU architecture, aiming to close the performance gap with native FA3 implementations. The changes are validated through a new, detailed benchmarking suite.
Code Review
This pull request introduces significant performance optimizations to the FA3 codebase, particularly for FP8 data types. The changes include a new comprehensive benchmark script, L2-aware scheduling, a workaround for a Cutlass performance regression, Longest-Processing-Time-First (LPT) scheduling, and FA3-style same-shape transpose for FP8. The changes are well-commented and demonstrate a deep understanding of CUDA and Cutlass for performance tuning. My review focuses on the new benchmark script, where I've identified a couple of inefficiencies in the test setup code. The CUDA-level changes appear to be solid optimizations.
```python
kv_indptr = torch.tensor(
    [0]
    + [
        sum((kv_lens[i] + page_size - 1) // page_size for i in range(j + 1))
        for j in range(batch_size)
    ],
    dtype=torch.int32,
    device=device,
)
```
The creation of kv_indptr has a quadratic time complexity (O(N^2)) with respect to batch_size due to the nested loop structure (sum inside a list comprehension). For a large batch_size, this can be very slow. This can be computed in linear time. This pattern is repeated on lines 571-579.
```python
num_pages_per_seq = torch.tensor(
    [(kv_len + page_size - 1) // page_size for kv_len in kv_lens],
    dtype=torch.int32,
    device=device,
)
kv_indptr = torch.nn.functional.pad(torch.cumsum(num_pages_per_seq, 0), (1, 0))
```

```python
cu_seqlens = torch.tensor(
    [0] + list(torch.cumsum(torch.tensor(qo_lens), 0).numpy()),
    dtype=torch.int32,
    device=device,
)
```
The creation of cu_seqlens is inefficient. It involves creating a tensor on the CPU, performing a cumulative sum, converting it to a NumPy array, then to a list, and finally creating a new tensor on the GPU. This can be performed more efficiently and directly on the GPU to avoid unnecessary CPU-GPU data transfers and conversions. This pattern is repeated in other parts of the script (e.g., lines 199-203, 291-295, etc.).
```python
cu_seqlens = torch.nn.functional.pad(
    torch.cumsum(torch.tensor(qo_lens, device=device, dtype=torch.int32), dim=0), (1, 0)
)
```
Actionable comments posted: 1
🧹 Nitpick comments (5)
include/flashinfer/attention/hopper/sm90_pipeline_no_cluster.cuh (2)
18-24: Well-documented workaround for Cutlass regression. The comment clearly explains the rationale for this custom pipeline class. However, consider adding the Cutlass version check or a link to the relevant Cutlass issue/commit for future maintainability.
Consider adding a reference to the specific Cutlass commit or issue that introduced this regression:
```diff
 // As of Cutlass v3.6.0, if size(ClusterShape) == 1, PipelineTmaAsync has all threads
 // signaling the barrier during consumer_release. This causes a perf regression in FA3
 // forward pass (especially hdim 128 causal). We instead reimplement the version of
 // PipelineTmaAsync before v3.6.0 where only 1 out of 128 threads signals the barrier.
+// TODO: Remove this workaround once upstream Cutlass provides a fix.
 //
 // Assumption: params.num_consumers % NumThreadsPerWarpGroup == 0
```
36-53: Consider adding a static_assert or runtime check for num_consumers alignment. The documented assumption (params.num_consumers % NumThreadsPerWarpGroup == 0) is critical for correctness but not enforced. The calculation at lines 43-44 uses ceiling division, which could mask misalignment issues.

```diff
 static CUTLASS_DEVICE void init_barriers(SharedStorage& storage, Params params) {
+  // Enforce documented assumption
+  assert(params.num_consumers % cutlass::NumThreadsPerWarpGroup == 0 &&
+         "num_consumers must be a multiple of NumThreadsPerWarpGroup");
   int warp_idx = cutlass::canonical_warp_idx_sync();
   bool is_initializing_warp = (warp_idx == 0);
```

benchmarks/bench_fa3_comparison.py (3)
78-83: Consider capturing loop variables explicitly in lambdas. While the current code works correctly (variables are used immediately within the same iteration), Ruff flags B023 warnings because lambdas reference loop variables. This pattern repeats throughout the file and violates linting best practices.
Apply this pattern to explicitly capture variables (example for this section, apply similarly to lines 147, 218, 319, 333, 456, 497, 635, 687):
```diff
-        fi_time = bench_fn(lambda: wrapper.run(q, k, v))
-        fa3_time = bench_fn(
-            lambda: fa3_varlen_func(
-                q, k, v, cu_seqlens, cu_seqlens, seq_len, seq_len, causal=True
-            )
-        )
+        fi_time = bench_fn(lambda w=wrapper, q_=q, k_=k, v_=v: w.run(q_, k_, v_))
+        fa3_time = bench_fn(
+            lambda q_=q, k_=k, v_=v, cu=cu_seqlens, sl=seq_len: fa3_varlen_func(
+                q_, k_, v_, cu, cu, sl, sl, causal=True
+            )
+        )
```

This makes the variable capture explicit and silences the linter warnings.
128-132: Consider using iterable unpacking for cleaner code. Line 129 (and similar patterns at lines 200, 292, 428, 587) uses list concatenation, which can be simplified with iterable unpacking.
Apply this pattern:
```diff
-    cu_seqlens = torch.tensor(
-        [0] + list(torch.cumsum(torch.tensor(qo_lens), 0).numpy()),
-        dtype=torch.int32,
-        device=device,
-    )
+    cu_seqlens = torch.tensor(
+        [0, *torch.cumsum(torch.tensor(qo_lens), 0).tolist()],
+        dtype=torch.int32,
+        device=device,
+    )
```

This is more idiomatic Python and avoids the intermediate list conversion.
302-349: Consider logging exceptions for debugging. The broad exception handling is appropriate for a benchmarking script to ensure graceful degradation. However, silently swallowing exceptions (lines 323, 346) makes debugging difficult when implementations fail.
Add minimal logging:
```diff
         except Exception:
+            # Silently skip if FP8 not supported
             fi_time = None
             fi_tflops = None
```

Or for better debugging:

```diff
-        except Exception:
+        except Exception as e:
+            print(f"  (FlashInfer FP8 skipped: {type(e).__name__})")
             fi_time = None
             fi_tflops = None
```

This pattern applies to similar exception handlers at lines 346, 460, 508, 639, 701.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (15)
- benchmarks/bench_fa3_comparison.py (1 hunks)
- include/flashinfer/attention/hopper/kernel_traits.cuh (2 hunks)
- include/flashinfer/attention/hopper/mainloop.cuh (2 hunks)
- include/flashinfer/attention/hopper/mainloop_mma.cuh (4 hunks)
- include/flashinfer/attention/hopper/prefill_sm90.cuh (1 hunks)
- include/flashinfer/attention/hopper/quantization/epilogue.cuh (2 hunks)
- include/flashinfer/attention/hopper/quantization/kernel_traits.cuh (4 hunks)
- include/flashinfer/attention/hopper/quantization/mainloop_load.cuh (8 hunks)
- include/flashinfer/attention/hopper/quantization/mainloop_mma.cuh (2 hunks)
- include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh (2 hunks)
- include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh (2 hunks)
- include/flashinfer/attention/hopper/sm90_pipeline_no_cluster.cuh (1 hunks)
- include/flashinfer/attention/hopper/tile_scheduler.cuh (3 hunks)
- include/flashinfer/attention/hopper/utils.cuh (1 hunks)
- include/flashinfer/attention/scheduler.cuh (3 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
- include/flashinfer/attention/hopper/kernel_traits.cuh
- include/flashinfer/attention/hopper/quantization/mainloop_mma.cuh
- include/flashinfer/attention/hopper/quantization/epilogue.cuh
🧬 Code graph analysis (1)
benchmarks/bench_fa3_comparison.py (2)
flashinfer/testing/utils.py (1)
- bench_gpu_time_with_cupti (646-878)

flashinfer/prefill.py (2)
- BatchPrefillWithRaggedKVCacheWrapper (2316-3093)
- BatchPrefillWithPagedKVCacheWrapper (1305-2299)
🪛 Ruff (0.14.8)
benchmarks/bench_fa3_comparison.py
78-78: Function definition does not bind loop variable wrapper
(B023)
78-78: Function definition does not bind loop variable q
(B023)
78-78: Function definition does not bind loop variable k
(B023)
78-78: Function definition does not bind loop variable v
(B023)
81-81: Function definition does not bind loop variable q
(B023)
81-81: Function definition does not bind loop variable k
(B023)
81-81: Function definition does not bind loop variable v
(B023)
81-81: Function definition does not bind loop variable cu_seqlens
(B023)
81-81: Function definition does not bind loop variable cu_seqlens
(B023)
81-81: Function definition does not bind loop variable seq_len
(B023)
81-81: Function definition does not bind loop variable seq_len
(B023)
129-129: Consider iterable unpacking instead of concatenation
Replace with iterable unpacking
(RUF005)
147-147: Function definition does not bind loop variable wrapper
(B023)
147-147: Function definition does not bind loop variable q
(B023)
147-147: Function definition does not bind loop variable k
(B023)
147-147: Function definition does not bind loop variable v
(B023)
150-150: Function definition does not bind loop variable q
(B023)
150-150: Function definition does not bind loop variable k
(B023)
150-150: Function definition does not bind loop variable v
(B023)
150-150: Function definition does not bind loop variable cu_seqlens
(B023)
150-150: Function definition does not bind loop variable cu_seqlens
(B023)
150-150: Function definition does not bind loop variable seq_len
(B023)
150-150: Function definition does not bind loop variable seq_len
(B023)
192-192: Standard pseudo-random generators are not suitable for cryptographic purposes
(S311)
200-200: Consider iterable unpacking instead of concatenation
Replace with iterable unpacking
(RUF005)
218-218: Function definition does not bind loop variable wrapper
(B023)
218-218: Function definition does not bind loop variable q
(B023)
218-218: Function definition does not bind loop variable k
(B023)
218-218: Function definition does not bind loop variable v
(B023)
221-221: Function definition does not bind loop variable q
(B023)
221-221: Function definition does not bind loop variable k
(B023)
221-221: Function definition does not bind loop variable v
(B023)
221-221: Function definition does not bind loop variable cu_seqlens
(B023)
221-221: Function definition does not bind loop variable cu_seqlens
(B023)
221-221: Function definition does not bind loop variable max_seqlen
(B023)
221-221: Function definition does not bind loop variable max_seqlen
(B023)
292-292: Consider iterable unpacking instead of concatenation
Replace with iterable unpacking
(RUF005)
319-319: Function definition does not bind loop variable wrapper
(B023)
319-319: Function definition does not bind loop variable q_fp8
(B023)
319-319: Function definition does not bind loop variable k_fp8
(B023)
319-319: Function definition does not bind loop variable v_fp8
(B023)
319-319: Function definition does not bind loop variable s_q
(B023)
319-319: Function definition does not bind loop variable s_k
(B023)
319-319: Function definition does not bind loop variable s_v
(B023)
323-323: Do not catch blind exception: Exception
(BLE001)
333-333: Function definition does not bind loop variable q_fp8
(B023)
334-334: Function definition does not bind loop variable k_fp8
(B023)
335-335: Function definition does not bind loop variable v_fp8
(B023)
336-336: Function definition does not bind loop variable cu_seqlens
(B023)
337-337: Function definition does not bind loop variable cu_seqlens
(B023)
338-338: Function definition does not bind loop variable seq_len
(B023)
339-339: Function definition does not bind loop variable seq_len
(B023)
346-346: Do not catch blind exception: Exception
(BLE001)
428-428: Consider iterable unpacking instead of concatenation
Replace with iterable unpacking
(RUF005)
456-456: Function definition does not bind loop variable wrapper
(B023)
456-456: Function definition does not bind loop variable q_fp16
(B023)
456-456: Function definition does not bind loop variable kv_data_fp16
(B023)
460-460: Do not catch blind exception: Exception
(BLE001)
497-497: Function definition does not bind loop variable q_fa3
(B023)
498-498: Function definition does not bind loop variable k_cache_fa3
(B023)
499-499: Function definition does not bind loop variable v_cache_fa3
(B023)
500-500: Function definition does not bind loop variable cache_seqlens
(B023)
501-501: Function definition does not bind loop variable page_table
(B023)
508-508: Do not catch blind exception: Exception
(BLE001)
587-587: Consider iterable unpacking instead of concatenation
Replace with iterable unpacking
(RUF005)
635-635: Function definition does not bind loop variable wrapper
(B023)
635-635: Function definition does not bind loop variable q_fp8
(B023)
635-635: Function definition does not bind loop variable kv_data_fp8
(B023)
635-635: Function definition does not bind loop variable s_q
(B023)
635-635: Function definition does not bind loop variable s_k
(B023)
635-635: Function definition does not bind loop variable s_v
(B023)
639-639: Do not catch blind exception: Exception
(BLE001)
687-687: Function definition does not bind loop variable q_fa3
(B023)
688-688: Function definition does not bind loop variable k_cache_fa3
(B023)
689-689: Function definition does not bind loop variable v_cache_fa3
(B023)
690-690: Function definition does not bind loop variable cache_seqlens
(B023)
691-691: Function definition does not bind loop variable page_table
(B023)
692-692: Function definition does not bind loop variable q_descale_fa3
(B023)
693-693: Function definition does not bind loop variable k_descale_fa3
(B023)
694-694: Function definition does not bind loop variable v_descale_fa3
(B023)
701-701: Do not catch blind exception: Exception
(BLE001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (47)
include/flashinfer/attention/hopper/utils.cuh (1)
144-166: LGTM - FP8 output permutation utility is well-structured. The function correctly:
- Validates fragment layout with comprehensive static_asserts
- Uses group_modes to flatten the tensor for easier iteration
- Swaps adjacent elements to reverse the FP8 MMA register permutation before smem writes

The swap pattern (1, j, 2*i) <-> (0, j, 2*i+1) correctly undoes the byte permutation applied during FP8 MMA.
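As a rough host-side illustration of that swap (plain NumPy, not the CUDA utility; the array shape (2, J, 2*I) and the function name are assumptions for the example):

```python
import numpy as np

def undo_fp8_mma_permutation(frag: np.ndarray) -> np.ndarray:
    """Swap frag[1, j, 2*i] <-> frag[0, j, 2*i+1] for all i, j."""
    out = frag.copy()
    _, num_j, num_k = frag.shape  # expects shape (2, J, 2*I)
    for j in range(num_j):
        for i in range(num_k // 2):
            out[1, j, 2 * i] = frag[0, j, 2 * i + 1]
            out[0, j, 2 * i + 1] = frag[1, j, 2 * i]
    return out
```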
include/flashinfer/attention/hopper/mainloop.cuh (1)

204-206: LGTM - TMA cache hints for L2 optimization. The EVICT_LAST cache hints are correctly applied to all K and V TMA loads. This hint marks loaded data as least-recently-used in L2, which is appropriate for attention K/V tiles that are typically consumed once per Q tile iteration. This complements the L2-aware scheduling changes in scheduler.cuh.

Also applies to: 233-242, 249-251
include/flashinfer/attention/scheduler.cuh (3)
799-799: Verify the cost function coefficient change. The cost function changed from 2 * qo_len + kv_len to 0.05 * qo_len + kv_len. This dramatically reduces the weight of qo_len relative to kv_len (a 40x reduction). Please confirm this is intentional and aligned with FA3's cost model. The change makes the scheduler heavily prioritize KV length in load-balancing decisions.
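To make the shift concrete, a small standalone comparison (not part of the PR; the numbers are only a worked example):

```python
def old_cost(qo_len: int, kv_len: int) -> float:
    return 2 * qo_len + kv_len

def new_cost(qo_len: int, kv_len: int) -> float:
    return 0.05 * qo_len + kv_len

# For qo_len = kv_len = 4096, the qo term drops from two thirds of the total
# cost (8192 / 12288) to under 5% (204.8 / ~4300.8), so the LPT ordering is
# now driven almost entirely by kv_len.
print(old_cost(4096, 4096))  # 12288
print(new_cost(4096, 4096))  # ~4300.8
```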
932-997: L2-aware scheduling implementation looks correct. The implementation follows FA3's approach:
- Uses a conservative 8MB L2 estimate (appropriate for cache sharing scenarios)
- Stepped nheads_in_l2 values (16, 8, 4, 2, 1) match FA3's heuristics
- Swizzle-based head grouping enables K/V reuse within L2 cache
- Inner loop order (q_tiles in LPT order, then heads) maximizes L2 hit rate

The size_l2 constant at line 935 is hardcoded. Consider making this configurable for different GPU architectures if needed in the future. A sketch of the traversal order this strategy produces is shown below.
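A simplified host-side sketch of that traversal order (not the actual scheduler code; the function name, the per-head K/V byte estimate, and the 8MB default are assumptions for illustration):

```python
def l2_aware_order(num_heads, num_q_tiles, kv_bytes_per_head, size_l2=8 * 1024 * 1024):
    """Group heads so their K/V fits in the L2 budget, then emit all q-tiles for a
    head group (largest-work tiles first, i.e. LPT under causal masking) before
    moving on, so the group's K/V stays resident in L2."""
    nheads_in_l2 = 16
    while nheads_in_l2 > 1 and nheads_in_l2 * kv_bytes_per_head > size_l2:
        nheads_in_l2 //= 2  # step through 16, 8, 4, 2, 1

    order = []
    for group_start in range(0, num_heads, nheads_in_l2):
        group = range(group_start, min(group_start + nheads_in_l2, num_heads))
        for q_tile in reversed(range(num_q_tiles)):  # LPT: most KV work first
            for head in group:
                order.append((head, q_tile))
    return order
```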
904-912: LGTM - KV tile size heuristics for L2 estimation. The cta_tile_kv values are appropriately chosen based on head dimensions:
- 192 for non-causal HD128 (larger tiles possible without causal masking overhead)
- 64 for HD > 128 (smaller tiles to fit larger head dimensions)
These are used for L2 cache capacity estimation, aligning with the actual kernel tile sizes.
include/flashinfer/attention/hopper/mainloop_mma.cuh (4)
59-61: LGTM - Identity tensor hoisted outside loops. Creating the identity tensor and its partition once outside the loops is a good optimization, avoiding redundant tensor allocations in each iteration.
150-237: LGTM - FA3-style compile-time specialized mask functions. The refactoring introduces clean, specialized mask functions:
- causal_mask_with_seqlen_fn: First iteration with boundary checks
- causal_mask_fn: Masking iterations without redundant seqlen checks
- no_mask_fn: Main loop (no causal masking)
- swa_left_mask_fn: Sliding window left boundary

This compile-time specialization enables the compiler to generate optimized code for each masking scenario without runtime branching overhead.
250-310: LGTM - Pipelined forward step with compile-time mask dispatch. The fwd_step lambda correctly implements the pipelined attention computation:
- Wait for K and start QK gemm
- Rescale O (if not first iteration)
- Wait for V and start PV gemm
- Apply compile-time specialized mask
- Update attention state
- Release pipelines and advance
The loop passes kv_tile_idx_decrement(kv_tile_idx) to fwd_step because the function loads K for the next iteration while processing V from the current iteration; this correctly implements the software pipeline.
282-283: LGTM - Masking steps calculation. The n_masking_steps correctly accounts for the causal diagonal:
- For causal: ceil_div(CTA_Q, CTA_KV) iterations cover the diagonal region
- MULTIITEMSCORING adds +1 for additional boundary handling

A tiny worked example of this count follows.
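The arithmetic is plain ceiling division; for instance, assuming a 128-row Q tile and 64-column KV tiles (illustrative sizes, not necessarily the kernel's defaults):

```python
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

# A 128-row Q tile against 64-column KV tiles needs 2 masked iterations to
# cover the causal diagonal (one more if MULTIITEMSCORING adds +1).
assert ceil_div(128, 64) == 2
assert ceil_div(128, 96) == 2  # non-divisible case still rounds up
```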
include/flashinfer/attention/hopper/quantization/prefill_sm90.cuh (2)
236-242: LGTM - Explicit SmemLayoutVt template parameter. Passing CollectiveMainloop::SmemLayoutVt explicitly to mma_fp8 ensures the correct V transpose layout is used, aligning with the FA3-style layout infrastructure introduced across the kernel traits and mainloop components.
263-264: LGTM - LPT scheduling enabled for causal attention. Using SingleTileScheduler</*LPT=*/CAUSAL> enables Longest Processing Time (LPT) scheduling when causal masking is active. This improves load balancing because causal tiles near the diagonal have significantly less work (fewer valid KV positions), and LPT scheduling helps distribute this uneven workload more evenly across SMs.

include/flashinfer/attention/hopper/prefill_sm90.cuh (1)
301-302: LGTM! LPT scheduling for causal attention is a sensible optimization. The change enables Longest-Processing-Time-First scheduling when CAUSAL=true. This improves load balancing since causal attention tiles have triangular workloads: later tiles process more KV tokens, so processing them first helps even out GPU utilization.
20-20: LGTM! Include for no-cluster pipeline variant.
114-118: VerifyPipelineTmaAsyncNoClusterclass exists and is properly implemented.
PipelineTmaAsyncNoClusteris not part of official CUTLASS and must be a custom FlashInfer class. Confirm:
- The class definition exists in the codebase with correct
NUM_STAGEStemplate parameter- The barrier signaling optimization (1 thread vs. all threads) is actually implemented
- The class properly integrates with Cutlass 3.6+ without regressions
include/flashinfer/attention/hopper/quantization/epilogue.cuh (2)
17-17: LGTM! Required include for epilogue detail utilities.
43-49: LGTM! Dynamic copy operation selection for FP8 accumulators. Using sm90_get_smem_store_op_for_accumulator is the correct approach: it automatically selects the appropriate SMEM store operation based on the output stride and data type, rather than hardcoding a specific copy operation. This improves correctness for FP8 paths where the optimal copy operation may differ from F16/BF16.

include/flashinfer/attention/hopper/tile_scheduler.cuh (4)
16-19: LGTM! Clear documentation for the LPT scheduling feature. The comments effectively explain the purpose: reversing block indices so that tiles with more KV tokens (higher indices in causal attention) are processed first for better load balancing.
30-30: LGTM! Added num_qo_tiles to Params for LPT reversal.
36-36: LGTM! Propagating num_qo_tiles from Arguments to Params.
66-71: LGTM! LPT reversal implementation is correct. The reversal formula params.num_qo_tiles - 1 - q_tile_idx correctly maps:
- blockIdx.x = 0 → highest tile index (most work in causal attention)
- blockIdx.x = num_qo_tiles - 1 → tile index 0 (least work)

This ensures tiles with more KV tokens are processed first, improving GPU utilization through better load balancing.
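A minimal host-side illustration of the same mapping (the kernel computes this from blockIdx.x; the function name here is invented for the example):

```python
def lpt_q_tile_index(block_idx: int, num_qo_tiles: int, lpt: bool = True) -> int:
    # With LPT enabled, block 0 picks the last (largest-work) q tile under causal
    # masking; without it, blocks map to tiles in natural order.
    return num_qo_tiles - 1 - block_idx if lpt else block_idx

assert [lpt_q_tile_index(b, 4) for b in range(4)] == [3, 2, 1, 0]
```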
include/flashinfer/attention/hopper/quantization/mainloop_load.cuh (8)
51-54: LGTM! Transposed stride definition for FA3-style V loading. The stride (D, N, H) with (_1, int64_t, int64_t) correctly represents loading V with dimension D as the innermost (contiguous) dimension, enabling direct loading into MN-major shared memory.
71-79: LGTM! TMA_V updated for transposed V loading path. The tile shape change from select<1,2> to select<2,1> correctly reflects the transposed (D, N) tile shape for V, aligning with the MN-major SmemLayoutV.
188-199: LGTM! In-kernel V tensor and transpose source updated for transposed layout. The changes correctly:
- Create mV with transposed shape (D, N, H) for the TMA tensor
- Use sVt_src (SmemLayoutV, MN-major) as the TMA destination and transpose source
- Keep sVt_tgt (SmemLayoutVt, K-major) as the transpose target for MMA consumption
210-212: LGTM! gV tile uses transposed shape select<2, 1>. Consistent with the TMA_V definition: the tile shape is now (HEAD_DIM, CTA_KV) instead of (CTA_KV, HEAD_DIM).
272-272: LGTM! Transpose source updated to sVt_src.
297-297: LGTM! Consistent transpose source usage in pipeline loop.
319-319: LGTM! Consistent transpose source in final loop iteration.
132-143: Verify stride construction correctness for the transposed V tensor. The transposed layout construction looks correct conceptually: swapping shape and stride elements to create a (D, N, H) layout from the original (N, D, H). However, this relies on layout_V having the expected stride pattern where stride_D is 1 (contiguous head dimension). Verify that the input layout_V always has stride_D = 1 (row-major within each head) as assumed by this transpose:

```bash
#!/bin/bash
# Search for how layout_V is constructed/passed to verify stride assumptions
rg -n "layout_V" --type cpp -B2 -A2 | head -60
# Check get_gmem_layout to see what strides are typically used
ast-grep --pattern $'get_gmem_layout($_, $_, $_, $_, $_)'
```

include/flashinfer/attention/hopper/quantization/mainloop_mma.cuh (2)
17-22: LGTM! Clean template parameterization for layout flexibility. The introduction of SmemLayoutVt_ as a template parameter with clear documentation enables the mainloop to support both TMA (FA3-style) and sparse paths with their respective layouts.
133-136: Verify causal masking step count calculation. The n_masking_steps calculation assumes ceil_div(CTA_Q, CTA_KV) is correct for causal masking. With different tile sizes, ensure this formula correctly determines the number of tiles requiring masking.
23-40: Verify all call sites are updated for the new template parameter. The function signature includes a new template parameter SmemLayoutVt_. Ensure all callers of mma_fp8 are updated to provide the correct template argument.

include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh (4)
64-67: LGTM! Consistent sparse-path layout naming. The layout aliases are correctly updated to use SmemLayoutVSparse and SmemLayoutVtSparse, maintaining consistency with the new naming conventions in kernel_traits.cuh.
194-202: LGTM! Correct sparse transpose implementation. The transpose source/target layouts and transposer are correctly updated to use sparse-specific variants that handle the (CTA_KV, HEAD_DIM, STAGES) layout from cp.async loading.
408-437: Verify parity toggle logic in the main loop. The parity double-buffering pattern is used to overlap prefetch and load operations. The toggle sequence appears correct (parity ^= 1 before prefetch, parity ^= 1 before load from previous prefetch, parity ^= 1 before next load), but verify the invariant is maintained across iterations.
280-298: Guard against out-of-bounds page table access at line 292. Validating kv_idx_read < kv_len does not guarantee that page_iter will be within bounds of kv_indices_ptr. When kv_idx_read is valid but page_iter exceeds the page table size, accessing kv_indices_ptr[page_iter] reads garbage data. Add bounds checking on page_iter, or ensure the kv_indices_ptr size is guaranteed to accommodate all possible page_iter values derived from the page size divmod.

include/flashinfer/attention/hopper/sm90_pipeline_no_cluster.cuh (1)
95-100: LGTM! Efficient single-thread barrier signaling. The consumer_release implementation correctly selects only threads at warp-group boundaries (threadIdx.x % NumThreadsPerWarpGroup == 0) to signal the barrier, avoiding the all-threads signaling that caused the regression in Cutlass v3.6.0.

include/flashinfer/attention/hopper/quantization/kernel_traits.cuh (6)
50-60: Excellent documentation of the FA3-style transpose approach. The comment clearly explains the key insight about TMA loading V with transposed gmem strides and the difference between TMA and sparse paths. This will be valuable for future maintainers.
97-99: Verify the assertion covers all valid configurations. The static_assert requires either kHeadDim or kBlockN to be a multiple of 64, which aligns with the 64x64 transpose atom. Confirm this constraint is documented for users configuring custom tile shapes.
189-236: LGTM! Clean FA3-style transpose implementation. The do_transpose method correctly:
- Uses LDSM.T for loading from MN-major smem
- Applies byte permutation for FP8 element reordering
- Uses STSM for storing to K-major smem
- Includes ILP optimization for better instruction-level parallelism
397-403: Good documentation explaining the SharedStorage layout choice. The comment clearly explains why SmemLayoutVSparse is used for SharedStorage (to support the sparse/paged KV path) while noting that both layouts have the same cosize for memory allocation.
305-308: Verify loop bounds match layout dimensions. The nested loops iterate over shape<2> and shape<1> of SmemLayoutVSparseTransposeSrc. Confirm these dimensions correspond to the correct tile subdivisions and that the loop ordering (outer loop over shape<2>, inner loop over shape<1>) is appropriate for the transpose operation.
390-403: Verify whether PipelineTmaAsyncNoCluster exists as a custom implementation in FlashInfer. The recommended replacement PipelineTmaAsyncNoCluster<NUM_STAGES> does not appear in standard CUTLASS sources or documentation. If this is a custom wrapper or variant specific to FlashInfer (e.g., defined in sm90_pipeline_no_cluster.cuh), verify that both the custom class and the header file exist in the codebase before applying this change. Additionally, confirm that the Cutlass v3.6.0 regression affecting PipelineTmaAsync is accurately characterized.

benchmarks/bench_fa3_comparison.py (5)
19-44: LGTM! The helper functions are well-structured:
- calc_tflops correctly computes TFLOPS using the standard attention formula (4 × batch × seq² × heads × dim), with appropriate scaling for causal masks.
- calc_tflops_varlen properly handles variable-length sequences by summing per-sequence FLOPs.
- bench_fn returns the median time, which is a robust choice for benchmarking.

A rough sketch of the TFLOPS formula is shown below.
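This is a self-contained sketch of such a helper (the name calc_tflops_sketch and its parameters are illustrative; the benchmark's actual signature may differ):

```python
def calc_tflops_sketch(batch_size, seq_len, num_heads, head_dim, time_ms, causal=False):
    # Two GEMMs (QK^T and PV), each ~2 * seq^2 * head_dim MACs per head and batch.
    flops = 4 * batch_size * seq_len * seq_len * num_heads * head_dim
    if causal:
        flops /= 2  # only the lower-triangular half of the score matrix is computed
    return flops / (time_ms * 1e-3) / 1e12
```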
250-261: LGTM! The per-head symmetric quantization implementation is correct:
- Properly handles both FP8 formats (e4m3fn and e5m2) with appropriate ranges
- Uses per-head scaling (reducing over dims 0 and 2)
- Includes safeguards against division by zero with min=1e-6 clamping
- Correctly broadcasts scale factors for element-wise quantization

A minimal sketch of this quantization scheme follows.
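For reference, a minimal sketch of this kind of per-head symmetric FP8 quantization, assuming a (seq_len, num_heads, head_dim) layout; the helper name and exact details in the benchmark may differ:

```python
import torch

def quantize_per_head_fp8(x: torch.Tensor, dtype=torch.float8_e4m3fn):
    # x: (seq_len, num_heads, head_dim); one scale per head (reduce over dims 0 and 2).
    fp8_max = torch.finfo(dtype).max
    amax = x.abs().amax(dim=(0, 2)).clamp(min=1e-6)
    scale = amax / fp8_max
    x_fp8 = (x / scale.view(1, -1, 1)).to(dtype)  # broadcast scale over seq and dim
    return x_fp8, scale  # dequantize with x_fp8.float() * scale.view(1, -1, 1)
```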
1-718: Well-structured benchmarking script. This is a comprehensive benchmarking suite that effectively compares FlashInfer and FA3 across multiple scenarios:
- Various batch sizes and sequence lengths
- Variable-length sequences
- FP8 quantization paths
- Paged KV cache with different page sizes
- Graceful handling of implementation differences
The progressive test configurations and clear output formatting make this a valuable benchmarking tool.
673-683: The q_descale computation for GQA is correct and matches FA3's API expectations. FA3's flash_attn_with_kvcache expects descale parameters with shape (num_sequences, num_kv_heads). For GQA, this means one descale value per KV head group. The code correctly reshapes s_q from [num_qo_heads] to [num_kv_heads, num_qo_heads // num_kv_heads], takes the mean across query heads within each KV group, and expands to (batch_size, num_kv_heads). This aligns with FA3's requirement that descales be per-KV-head for GQA models, where the CUDA kernel internally expands them to match per-query entries.
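A small illustrative sketch of that reshape/mean/expand step (shapes and names are taken from the description above, not copied from the benchmark):

```python
import torch

batch_size, num_qo_heads, num_kv_heads = 2, 32, 8
s_q = torch.rand(num_qo_heads)  # per-query-head scales

q_descale = (
    s_q.view(num_kv_heads, num_qo_heads // num_kv_heads)  # group query heads by KV head
    .mean(dim=1)                                           # one value per KV-head group
    .expand(batch_size, num_kv_heads)                      # (num_sequences, num_kv_heads)
)
assert q_descale.shape == (batch_size, num_kv_heads)
```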
```cpp
template <typename TileShape_QKD, typename Element, int NUM_STAGES>
struct TranposeTraits_64x64 {
  using TransposeShapeAtom_ = Shape<_64, _64>;
  using TransElement = Element;
  static_assert(cutlass::sizeof_bits_v<TransElement> == 8);

  using SmemShapeLDSM = Shape<Shape<_8, _8>, Shape<_16, _4>>;
  using SmemShapeSTSM = Shape<Shape<_16, _4>, Shape<_16, _4>>;
  static constexpr int kHeadDim = get<2>(TileShape_QKD{});
  static constexpr int kBlockN = get<1>(TileShape_QKD{});

  // MN-major for TMA loading (V is loaded with transposed gmem strides)
  static constexpr cute::GMMA::Major TmaMajorV = GMMA::Major::MN;
  // K-major for MMA consumption (required for FP8)
  static constexpr cute::GMMA::Major MmaMajorV = GMMA::Major::K;

  // ==================== TMA Path Layouts (FA3-style same-shape) ====================
  // SmemLayoutVtTma: MN-major layout for TMA load, shape (HEAD_DIM, CTA_KV, STAGES)
  using SmemLayoutAtomVtTma =
      decltype(cutlass::gemm::collective::detail::ss_smem_selector<TmaMajorV, Element,
                                                                   Int<kHeadDim>, Int<kBlockN>>());
  using SmemLayoutVtTma = decltype(tile_to_shape(
      SmemLayoutAtomVtTma{}, make_shape(Int<kHeadDim>{}, Int<kBlockN>{}, Int<NUM_STAGES>{}),
      cute::Step<_2, _1, _3>{}));  // MN-major ordering

  // SmemLayoutVtMma: K-major layout for MMA, same shape (HEAD_DIM, CTA_KV, STAGES)
  using SmemLayoutAtomVtMma =
      decltype(cutlass::gemm::collective::detail::ss_smem_selector<MmaMajorV, Element,
                                                                   Int<kHeadDim>, Int<kBlockN>>());
  using SmemLayoutVtMma = decltype(tile_to_shape(
      SmemLayoutAtomVtMma{}, make_shape(Int<kHeadDim>{}, Int<kBlockN>{}, Int<NUM_STAGES>{}),
      cute::Step<_1, _2, _3>{}));  // K-major ordering

  // For TMA path: SmemLayoutV = SmemLayoutVtTma (MN-major, for TMA load)
  using SmemLayoutV = SmemLayoutVtTma;
  using SmemLayoutVt = SmemLayoutVtMma;

  // FA3-style LDSM/STSM tiled copies for TMA path transpose
  static constexpr bool kHeadDimMultiple64 = kHeadDim % 64 == 0;
  static_assert(kHeadDimMultiple64 || kBlockN % 64 == 0,
                "Either kHeadDim or kBlockN must be multiple of 64");
```
Typo in struct name: TranposeTraits should be TransposeTraits.
The struct name has a typo (missing 's' in "Transpose").
```diff
-template <typename TileShape_QKD, typename Element, int NUM_STAGES>
-struct TranposeTraits_64x64 {
+template <typename TileShape_QKD, typename Element, int NUM_STAGES>
+struct TransposeTraits_64x64 {
```

Also update the alias at line 373:

```diff
-  using VTranposeTraits = TranposeTraits_64x64<TileShape_QKD, DTypeKV, NUM_STAGES>;
+  using VTransposeTraits = TransposeTraits_64x64<TileShape_QKD, DTypeKV, NUM_STAGES>;
```

Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In include/flashinfer/attention/hopper/quantization/kernel_traits.cuh around
lines 61 to 99, rename the struct from TranposeTraits_64x64 to
TransposeTraits_64x64 (fix the missing 's') and update any corresponding
internal uses/typedefs to match the corrected name; also update the alias
referenced at line 373 to use TransposeTraits_64x64 instead of the old
misspelled name so all references remain consistent.
```diff
 }

-inline float cost_function(int qo_len, int kv_len) { return 2 * float(qo_len) + kv_len; }
+inline float cost_function(int qo_len, int kv_len) { return 0.05 * float(qo_len) + kv_len; }
```
Is this based on profiling? This will also change the cost for persistent FA2?
(I tried fitting a quadratic cost function over a bunch of seq lens for persistent FA2 + overlap, didn't work well)
I think FA4 just strides by num CTAs without any load balancing? I plan to try opportunistically fetching from the two work queues next in the persistent FA2, like in POD
Yes it's based on profiling.
But as you mention, none of them work well (the cost model can not model the interference of L2, etc).
FA3's dynamic LPT scheduler (it might not be necessary to still stride by num_ctas)
https://github.com/Dao-AILab/flash-attention/blob/c783ab2f7e05ba1cd79ecfe0e6e109a4e3f6e542/hopper/tile_scheduler.hpp#L338
/bot run

[SUCCESS] Pipeline #39915642: 11/20 passed
📌 Description
Follow-up of #2111, adding more optimizations in FA3 (mostly around FP8).
🔍 Related Issues
#2111
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed pre-commit by running pip install pre-commit (or used your preferred method).
- Installed the hooks with pre-commit install.
- Ran pre-commit run --all-files and fixed any reported issues.

🧪 Tests
- All tests are passing (unittest, etc.).

Reviewer Notes
cc @Edenzzzz
Summary by CodeRabbit
Release Notes
New Features
Performance Improvements