Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
718 changes: 718 additions & 0 deletions benchmarks/bench_fa3_comparison.py

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion include/flashinfer/attention/hopper/kernel_traits.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "cutlass/layout/layout.h"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"
#include "sm90_pipeline_no_cluster.cuh"

namespace flashinfer {

Expand Down Expand Up @@ -110,8 +111,10 @@ struct AttentionKernelTraits {
GMMA::Major::K, DTypeO, decltype(cute::get<0>(TileShape_PDV{})),
decltype(cute::get<1>(TileShape_PDV{}))>());
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 1>(TileShape_PDV{})));
// Use PipelineTmaAsyncNoCluster for TMA loads to avoid perf regression in Cutlass 3.6+
// Only 1 out of 128 threads signals the barrier (instead of all threads)
using MainloopPipeline =
std::conditional_t<USE_TMA_LOAD_KV, typename cutlass::PipelineTmaAsync<NUM_STAGES>,
std::conditional_t<USE_TMA_LOAD_KV, PipelineTmaAsyncNoCluster<NUM_STAGES>,
typename cutlass::PipelineAsync<NUM_STAGES>>;
using PipelineState = typename cutlass::PipelineState<NUM_STAGES>;

Expand Down
18 changes: 10 additions & 8 deletions include/flashinfer/attention/hopper/mainloop.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ struct CollectiveMainloop {
if (lane_predicate) {
pipeline_k.producer_acquire(smem_pipe_write_k);
copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k),
/*mcast_mask=*/0),
/*mcast_mask=*/0, cute::TMA::CacheHintSm90::EVICT_LAST),
tKgK(_, kv_tile_idx), tKsK(_, smem_pipe_write_k.index()));
++smem_pipe_write_k;
}
Expand Down Expand Up @@ -230,22 +230,24 @@ struct CollectiveMainloop {
#pragma unroll 2
for (; kv_tile_idx > swa_begin_kv_tile_idx; --kv_tile_idx) {
pipeline_k.producer_acquire(smem_pipe_write_k);
copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k),
/*mcast_mask=*/0),
tKgK(_, kv_tile_idx - 1), tKsK(_, smem_pipe_write_k.index()));
copy(
mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k),
/*mcast_mask=*/0, cute::TMA::CacheHintSm90::EVICT_LAST),
tKgK(_, kv_tile_idx - 1), tKsK(_, smem_pipe_write_k.index()));
++smem_pipe_write_k;
pipeline_v.producer_acquire(smem_pipe_write_v);
copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v),
/*mcast_mask=*/0),
tVgV(_, kv_tile_idx), tVsV(_, smem_pipe_write_v.index()));
copy(
mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v),
/*mcast_mask=*/0, cute::TMA::CacheHintSm90::EVICT_LAST),
tVgV(_, kv_tile_idx), tVsV(_, smem_pipe_write_v.index()));
++smem_pipe_write_v;
}
}
scheduler.prefetch_next_work(scheduler_params, work_tile_info);
if (lane_predicate) {
pipeline_v.producer_acquire(smem_pipe_write_v);
copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v),
/*mcast_mask=*/0),
/*mcast_mask=*/0, cute::TMA::CacheHintSm90::EVICT_LAST),
tVgV(_, kv_tile_idx), tVsV(_, smem_pipe_write_v.index()));
++smem_pipe_write_v;
}
Expand Down
276 changes: 138 additions & 138 deletions include/flashinfer/attention/hopper/mainloop_mma.cuh

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion include/flashinfer/attention/hopper/prefill_sm90.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,8 @@ cudaError_t SinglePrefillWithKVCacheKernelTraitsDispatched(Params& params, cudaS
using CollectiveMainloop =
CollectiveMainloop<typename Params::AdditionalParams, KernelTraits, CAUSAL>;
using CollectiveEpilogue = CollectiveEpilogue<KernelTraits>;
using Scheduler = SingleTileScheduler;
// Use LPT scheduling for causal attention for better load balancing
using Scheduler = SingleTileScheduler</*LPT=*/CAUSAL>;
typename CollectiveMainloop::Params mainloop_params = CollectiveMainloop::to_underlying_arguments(
{params.q_ptr,
get_gmem_layout(params.qo_len, params.num_qo_heads, KernelTraits::HEAD_DIM_QK,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "../named_barrier.cuh"
#include "../utils.cuh"
#include "cute/tensor.hpp"
#include "cutlass/epilogue/collective/detail.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"

namespace flashinfer {
Expand All @@ -39,7 +40,13 @@ struct FP8CollectiveEpilogue {
decltype(cute::get<2>(TileShape_QKD{}))>());
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_QKD{})));

using SmemCopyAtomO = Copy_Atom<cute::SM90_U32x4_STSM_N, DTypeO>;
using StrideO = cute::Shape<int64_t, _1, int64_t>;
using EpilogueTile_MN = decltype(select<0, 2>(TileShape_QKD{}));
// Use sm90_get_smem_store_op_for_accumulator to get the correct copy op for FP8 accumulators
using CopyOpR2S =
decltype(cutlass::epilogue::collective::detail::sm90_get_smem_store_op_for_accumulator<
StrideO, DTypeO, EpilogueTile_MN>());
using SmemCopyAtomO = Copy_Atom<CopyOpR2S, DTypeO>;
using SharedStorage = cute::array_aligned<DTypeO, cute::cosize_v<SmemLayoutO>>;

using ShapeT = cute::Shape<int32_t, int32_t, int32_t>;
Expand Down
Loading