From 845a32773aa15ab4b5b7e06482e0ab61bb623878 Mon Sep 17 00:00:00 2001 From: minhthuc Date: Mon, 8 Apr 2024 11:31:00 +0200 Subject: [PATCH] optimize compilation --- CMakeLists.txt | 5 - .../ops/flash-attention/flash_fwd_kernel.h | 10 +- .../flash_fwd_launch_template.h | 246 ++++++++---------- .../ops/flash-attention/kernel_traits.h | 184 ------------- 4 files changed, 113 insertions(+), 332 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 86ccf8e5a..bd37cc510 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -490,11 +490,6 @@ if (WITH_CUDA) cuda_include_directories(${THRUST_INCLUDE_DIRS}) list(APPEND PRIVATE_INCLUDE_DIRECTORIES ${THRUST_INCLUDE_DIRS}) - #if(NOT DEFINED ${CMAKE_CUDA_ARCHITECTURES}) - # set(CMAKE_CUDA_ARCHITECTURES 70) - #endif() - #set(CMAKE_CUDA_COMPILER ${CUDA_HOST_COMPILER}) - #add_subdirectory(third_party/cutlass) set(CUTLASS_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/third_party/cutlass/include ) diff --git a/include/ctranslate2/ops/flash-attention/flash_fwd_kernel.h b/include/ctranslate2/ops/flash-attention/flash_fwd_kernel.h index 12d4ede47..4bff64f07 100644 --- a/include/ctranslate2/ops/flash-attention/flash_fwd_kernel.h +++ b/include/ctranslate2/ops/flash-attention/flash_fwd_kernel.h @@ -23,7 +23,7 @@ namespace flash { //////////////////////////////////////////////////////////////////////////////////////////////////// - template + template inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) { using Element = typename Kernel_traits::Element; @@ -375,7 +375,7 @@ namespace flash { // Epilogue - Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax, params.rp_dropout); + Tensor lse = softmax.template normalize_softmax_lse<>(acc_o, params.scale_softmax, params.rp_dropout); // Convert acc_o from fp32 to fp16/bf16 Tensor rO = flash::convert_type(acc_o); @@ -931,7 +931,7 @@ namespace flash { // Epilogue - Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax); + Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax); // if (cute::thread0()) { print(lse); } Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) @@ -1009,7 +1009,7 @@ namespace flash { //////////////////////////////////////////////////////////////////////////////////////////////////// - template + template inline __device__ void compute_attn(const Params ¶ms) { const int m_block = blockIdx.x; // The block index for the batch. @@ -1025,7 +1025,7 @@ namespace flash { // the attention matrix. This way, as long as we have the batch, head, and the location of // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. 
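// A minimal sketch of why dropping the Is_dropout template parameter shrinks
// compile time: each boolean dispatched at compile time doubles the number of
// kernel instantiations nvcc has to build per head dimension, and an
// inference-only engine never enables dropout. toy_kernel below is a
// hypothetical stand-in for the real flash_fwd_kernel instantiations.
#include <cstdio>

template <int HeadDim, bool Is_causal, bool Is_dropout = false>
void toy_kernel() {
  std::printf("instantiated: hdim=%d causal=%d dropout=%d\n",
              HeadDim, int(Is_causal), int(Is_dropout));
}

int main() {
  // Before: 2 (causal) x 2 (dropout) = 4 instantiations per head dimension.
  toy_kernel<64, false, false>(); toy_kernel<64, false, true>();
  toy_kernel<64, true,  false>(); toy_kernel<64, true,  true>();
  // After removing the dropout axis: only 2 instantiations remain.
  toy_kernel<64, false>(); toy_kernel<64, true>();
  return 0;
}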
- flash::compute_attn_1rowblock(params, bidb, bidh, m_block); + flash::compute_attn_1rowblock(params, bidb, bidh, m_block); } //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/ctranslate2/ops/flash-attention/flash_fwd_launch_template.h b/include/ctranslate2/ops/flash-attention/flash_fwd_launch_template.h index 8ec0024ac..6f14be1bd 100644 --- a/include/ctranslate2/ops/flash-attention/flash_fwd_launch_template.h +++ b/include/ctranslate2/ops/flash-attention/flash_fwd_launch_template.h @@ -10,36 +10,40 @@ #include "ctranslate2/devices.h" #include "cuda/utils.h" -template -#if __CUDA_ARCH__ >= 700 +template +#if __CUDA_ARCH__ >= 800 __global__ void flash_fwd_kernel(__grid_constant__ const Flash_fwd_params params) { + static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false + flash::compute_attn(params); +} #else __global__ void flash_fwd_kernel(const Flash_fwd_params params) { -#endif - static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false - flash::compute_attn(params); } +#endif + template -#if __CUDA_ARCH__ >= 700 +#if __CUDA_ARCH__ >= 800 __global__ void flash_fwd_splitkv_kernel(__grid_constant__ const Flash_fwd_params params) { + flash::compute_attn_splitkv(params); +} #else __global__ void flash_fwd_splitkv_kernel(const Flash_fwd_params params) { -#endif - flash::compute_attn_splitkv(params); } +#endif template -#if __CUDA_ARCH__ >= 700 +#if __CUDA_ARCH__ >= 800 __global__ void flash_fwd_splitkv_combine_kernel(__grid_constant__ const Flash_fwd_params params) { -#else -__global__ void flash_fwd_splitkv_combine_kernel(const Flash_fwd_params params) { -#endif static_assert(Log_max_splits >= 1); flash::combine_attn_seqk_parallel(params); } +#else +__global__ void flash_fwd_splitkv_combine_kernel(const Flash_fwd_params params) { +} +#endif -template +template void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr size_t smem_size = Kernel_traits::kSmemSize; // printf("smem_size = %d\n", smem_size); @@ -63,7 +67,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { // If return_softmax, set IsEvenMNConst to false to reduce number of templates // If head dim > 128, set IsEvenMNConst to false to reduce number of templates // If Is_local, set Is_causal to false - auto kernel = &flash_fwd_kernel; + auto kernel = &flash_fwd_kernel; // auto kernel = &flash_fwd_kernel; // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout)); // auto kernel = &flash_fwd_kernel; @@ -159,32 +163,21 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream) template void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 32; - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - BOOL_SWITCH(params.is_causal, Is_causal, [&] { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - }); + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + run_flash_fwd, Is_causal>(params, stream); }); } template void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 64; - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - BOOL_SWITCH(params.is_causal, Is_causal, [&] { - if constexpr(!Is_dropout) { - // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower - 
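// BOOL_SWITCH above turns the runtime flag params.is_causal into a
// compile-time constant by expanding the enclosed lambda once per value. A
// minimal sketch of the idea (the real macro lives elsewhere in this
// flash-attention port; example_dispatch is a hypothetical stand-in for
// run_flash_fwd):
#define BOOL_SWITCH_SKETCH(COND, CONST_NAME, ...)   \
  [&] {                                             \
    if (COND) {                                     \
      constexpr static bool CONST_NAME = true;      \
      return __VA_ARGS__();                         \
    } else {                                        \
      constexpr static bool CONST_NAME = false;     \
      return __VA_ARGS__();                         \
    }                                               \
  }()

template <bool Is_causal>
void example_dispatch() { /* one instantiation per boolean value */ }

void example_caller(bool is_causal) {
  // With the outer DROPOUT_SWITCH removed, only the causal axis is expanded,
  // so half as many example_dispatch instantiations are compiled.
  BOOL_SWITCH_SKETCH(is_causal, Is_causal, [&] { example_dispatch<Is_causal>(); });
}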
// Using block size (64 x 256) is 27% slower for seqlen=2k - // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - }); + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower + // Using block size (64 x 256) is 27% slower for seqlen=2k + // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling + run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); }); } @@ -194,24 +187,22 @@ void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) { int device_id = ctranslate2::get_device_index(ctranslate2::Device::CUDA); auto dprops = ctranslate2::cuda::get_device_properties(device_id); bool is_sm8x = dprops.major == 8 && dprops.minor > 0; - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - BOOL_SWITCH(params.is_causal, Is_causal, [&] { - // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), - if (is_sm8x) { - if constexpr(!Is_causal) { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + if (is_sm8x) { + if constexpr(!Is_causal) { + run_flash_fwd, Is_causal>(params, stream); } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); + run_flash_fwd, Is_causal>(params, stream); } - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // These two are always slower - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - }); + } else { + run_flash_fwd, Is_causal>(params, stream); + } + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // These two are always slower + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); }); } @@ -221,35 +212,26 @@ void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { int device_id = ctranslate2::get_device_index(ctranslate2::Device::CUDA); auto dprops = ctranslate2::cuda::get_device_properties(device_id); bool is_sm8x = dprops.major == 8 && dprops.minor > 0; - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - BOOL_SWITCH(params.is_causal, Is_causal, [&] { - if constexpr(!Is_dropout) { - // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), - // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM. 
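// A minimal sketch of the is_sm8x dispatch used by these launchers, with the
// plain CUDA runtime standing in for the ctranslate2::cuda helpers and with
// illustrative tile sizes (the comments above report the square 64 x 64 tile
// as fastest for the causal case on sm86/sm89, which expose less shared
// memory per SM than sm80):
#include <cuda_runtime.h>

inline void pick_block_shape(int device, bool is_causal, int& block_m, int& block_n) {
  cudaDeviceProp props;
  cudaGetDeviceProperties(&props, device);                  // error checking omitted
  const bool is_sm8x = props.major == 8 && props.minor > 0; // sm86 / sm89
  if (is_sm8x && is_causal) {
    block_m = 64;  block_n = 64;   // square tile for causal attention on sm86/sm89
  } else {
    block_m = 128; block_n = 64;   // larger default tile for sm80 / sm90
  }
}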
- if (is_sm8x) { - if constexpr(!Is_causal) { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // 1st ones are good for H100, A100 - // 2nd one is good for A6000 bc we get slightly better occupancy + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM. + if (is_sm8x) { + if constexpr(!Is_causal) { + run_flash_fwd, Is_causal>(params, stream); } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + run_flash_fwd, Is_causal>(params, stream); } - }); + } else { + run_flash_fwd, Is_causal>(params, stream); + } + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // 1st ones are good for H100, A100 + // 2nd one is good for A6000 bc we get slightly better occupancy }); } @@ -259,47 +241,39 @@ void run_mha_fwd_hdim160(Flash_fwd_params ¶ms, cudaStream_t stream) { int device_id = ctranslate2::get_device_index(ctranslate2::Device::CUDA); auto dprops = ctranslate2::cuda::get_device_properties(device_id); bool is_sm8x = dprops.major == 8 && dprops.minor > 0; - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - BOOL_SWITCH(params.is_causal, Is_causal, [&] { - // For A100, H100, 128 x 32 is the fastest. - // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), - // and 128 x 64 with 8 warps is the fastest for non-causal. - if (is_sm8x) { - if constexpr(!Is_causal) { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // For A100, H100, 128 x 32 is the fastest. + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + // and 128 x 64 with 8 warps is the fastest for non-causal. 
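// The `if constexpr(!Is_causal)` above is resolved during template
// instantiation, so for each Is_causal value only the taken branch's
// run_flash_fwd instantiation is compiled; the discarded branch adds no
// kernel. A minimal sketch with a hypothetical launch_tile template standing
// in for run_flash_fwd:
template <int kBlockM, int kBlockN>
void launch_tile() { /* stands in for run_flash_fwd with one tile shape */ }

template <bool Is_causal>
void example_launcher() {
  if constexpr (!Is_causal) {
    launch_tile<128, 64>();   // instantiated only when Is_causal == false
  } else {
    launch_tile<64, 64>();    // instantiated only when Is_causal == true
  }
}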
+ if (is_sm8x) { + if constexpr(!Is_causal) { + run_flash_fwd, Is_causal>(params, stream); } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); + run_flash_fwd, Is_causal>(params, stream); } - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - }); + } else { + run_flash_fwd, Is_causal>(params, stream); + } + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); }); } template void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 192; - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - BOOL_SWITCH(params.is_causal, Is_causal, [&] { - if constexpr(!Is_dropout) { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - }); + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); }); } @@ -315,20 +289,18 @@ void run_mha_fwd_hdim224(Flash_fwd_params ¶ms, cudaStream_t stream) { CUDA_CHECK(status_); } // printf("max_smem_per_block = %d\n", max_smem_per_block); - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - BOOL_SWITCH(params.is_causal, Is_causal, [&] { - if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // We can't do 128 x 32 with 8 warps because with headdim 224, kBlockKSmem = 32. - // If we have N = 32, there are only 1024 elements to load at once, where each load - // is 8 elements. This means we can only use 128 threads and not 256 threads. - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - }); + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB + run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal>(params, stream); + } + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // We can't do 128 x 32 with 8 warps because with headdim 224, kBlockKSmem = 32. + // If we have N = 32, there are only 1024 elements to load at once, where each load + // is 8 elements. This means we can only use 128 threads and not 256 threads. 
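// A minimal sketch of the shared-memory probe used by the hdim224/hdim256
// launchers: query how much dynamic shared memory one block may opt into,
// then request it for the chosen kernel before launching. toy_kernel and the
// 48 KB fallback are placeholders; the 112 KB figure comes from the comment
// above (error checking omitted).
#include <cuda_runtime.h>

__global__ void toy_kernel() { extern __shared__ char smem[]; (void)smem; }

void launch_with_big_smem(int device, cudaStream_t stream) {
  int max_smem_per_block = 0;
  cudaDeviceGetAttribute(&max_smem_per_block,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
  // Use the large 128 x 64 tile only if its ~112 KB of smem fits per block.
  int smem_size = max_smem_per_block >= 112 * 1024 ? 112 * 1024 : 48 * 1024;
  if (smem_size > 48 * 1024) {
    // Anything past the default 48 KB per block requires an explicit opt-in.
    cudaFuncSetAttribute(toy_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                         smem_size);
  }
  toy_kernel<<<1, 128, smem_size, stream>>>();
}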
+ // run_flash_fwd, Is_dropout, Is_causal>(params, stream); }); } @@ -346,19 +318,17 @@ void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { CUDA_CHECK(status_); } // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block); - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - BOOL_SWITCH(params.is_causal, Is_causal, [&] { - // For A100, we want to run with 128 x 64 (128KB smem). - // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM. - if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - // 64 KB - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // 96 KB - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - }); + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // For A100, we want to run with 128 x 64 (128KB smem). + // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM. + if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) { + run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal>(params, stream); + } + // 64 KB + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // 96 KB + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); }); } diff --git a/include/ctranslate2/ops/flash-attention/kernel_traits.h b/include/ctranslate2/ops/flash-attention/kernel_traits.h index a7a5cf1ed..00e772d0b 100644 --- a/include/ctranslate2/ops/flash-attention/kernel_traits.h +++ b/include/ctranslate2/ops/flash-attention/kernel_traits.h @@ -157,188 +157,4 @@ struct Flash_fwd_kernel_traits : public Base { GmemLayoutAtomRotcossin{}, Layout>{})); // Val layout, 8 vals per load }; - -// Is_V_in_regs is an option to reduce smem usage, but will increase register pressue. -// No_double_buffer is another option to reduce smem usage, but will slow things down. -template > -struct Flash_bwd_kernel_traits : public Base { - using Element = typename Base::Element; - using ElementAccum = typename Base::ElementAccum; - using index_t = typename Base::index_t; - static constexpr bool Has_cp_async = Base::Has_cp_async; - using SmemCopyAtom = typename Base::SmemCopyAtom; - using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; - - static constexpr bool Is_V_in_regs = Is_V_in_regs_; - static constexpr bool No_double_buffer = No_double_buffer_; - - // The number of threads. - static constexpr int kNWarps = kNWarps_; - static constexpr int kNThreads = kNWarps * 32; - - static constexpr int kBlockM = kBlockM_; - static constexpr int kBlockN = kBlockN_; - static constexpr int kHeadDim = kHeadDim_; - static_assert(kHeadDim % 32 == 0); - static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; - static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); - static constexpr int kSwizzle = kBlockKSmem == 32 ? 
2 : 3; - - static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_; - static_assert(kNWarps % AtomLayoutMSdP == 0); - static_assert(kNWarps % AtomLayoutNdKV == 0); - static_assert(kNWarps % AtomLayoutMdQ == 0); - - using TiledMmaSdP = TiledMMA< - typename Base::MMA_Atom_Arch, - Layout, Int, _1>>, - Tile, Int<16 * kNWarps / AtomLayoutMSdP>, _16>>; - - using TiledMmadKV = TiledMMA< - typename Base::MMA_Atom_Arch, - Layout, Int, _1>>, - Tile, Int<16 * kNWarps / AtomLayoutNdKV>, _16>>; - - using TiledMmadQ = TiledMMA< - typename Base::MMA_Atom_Arch, - Layout, Int, _1>>, // 2x4x1 or 4x2x1 thread group - Tile, Int<16 * kNWarps / AtomLayoutMdQ>, _16>>; - - using SmemLayoutAtomQdO = decltype( - composition(Swizzle{}, - Layout>, - Stride, _1>>{})); - using SmemLayoutQdO = decltype(tile_to_shape( - SmemLayoutAtomQdO{}, - make_shape(Int{}, Int{}))); - - using SmemLayoutAtomKV = decltype( - composition(Swizzle{}, - Layout, Int>, - Stride, _1>>{})); - using SmemLayoutKV = decltype(tile_to_shape( - // SmemLayoutAtomQdO{}, - SmemLayoutAtomKV{}, - make_shape(Int{}, Int{}))); - - using SmemLayoutKtransposed = decltype( - composition(SmemLayoutKV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); - using SmemLayoutKtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutKtransposed{})); - - // TODO: generalize to other values of kBlockN - // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2 - // static constexpr int kPBlockN = kBlockN; - // Temporarily disabling this for hdim 256 on sm86 and sm89 - // static_assert(kBlockN >= 64); - static_assert(kBlockN >= 32); - // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest. - static constexpr int kPBlockN = kBlockN >= 64 ? 64 : 32; - static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64); - // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3); - static constexpr int kSwizzlePdS = 3; - using SmemLayoutAtomPdS = decltype( - composition(Swizzle{}, - Layout, Int>, - Stride, _1>>{})); - using SmemLayoutPdS = decltype(tile_to_shape( - SmemLayoutAtomPdS{}, - make_shape(Int{}, Int{}))); - using SmemLayoutPdStransposed = decltype( - composition(SmemLayoutPdS{}, make_layout(Shape, Int>{}, GenRowMajor{}))); - using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{})); - - using SmemCopyAtomPdS = Copy_Atom; - - using SmemLayoutQdOtransposed = decltype( - composition(SmemLayoutQdO{}, make_layout(Shape, Int>{}, GenRowMajor{}))); - using SmemLayoutQdOtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutQdOtransposed{})); - - using SmemLayoutAtomdKV = decltype( - composition(Swizzle{}, - Layout>, - Stride, _1>>{})); - using SmemLayoutdKV = decltype(tile_to_shape( - SmemLayoutAtomdKV{}, - make_shape(Int{}, Int{}))); - using SmemCopyAtomdKV = Copy_Atom; - - using SmemLayoutAtomdQ = decltype( - composition(Swizzle{}, - Layout>, - Stride, _1>>{})); - using SmemLayoutdQ = decltype(tile_to_shape( - SmemLayoutAtomdQ{}, - make_shape(Int{}, Int{}))); - using SmemCopyAtomdQ = Copy_Atom; - - // Double buffer for sQ - static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 
2 : 3) * sizeof(Element); - static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); - static constexpr int kSmemdSSize = size(SmemLayoutPdS{}) * sizeof(Element); - static constexpr int kSmemPSize = size(SmemLayoutPdS{}) * sizeof(Element); - static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element); - static constexpr int kSmemSize = kSmemQdOSize - + (!Is_V_in_regs - ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize) - : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize))); - static constexpr int kSmemSize1colblock = kSmemQdOSize - + (!Is_V_in_regs - ? kSmemKVSize + kSmemdSSize + kSmemPSize - : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize)); - - static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); - static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); - // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem - // to affect speed in practice. - static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; - static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); - using GmemLayoutAtom = Layout, Int>, - Stride, _1>>; - - // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading - // from the same address by the same threadblock. This is slightly faster. - using Gmem_copy_struct = std::conditional_t< - Has_cp_async, - SM80_CP_ASYNC_CACHEGLOBAL, - DefaultCopy - >; - using GmemTiledCopyQKV = decltype( - make_tiled_copy(Copy_Atom{}, - GmemLayoutAtom{}, - Layout>{})); // Val layout, 8 vals per read - using GmemTiledCopydO = decltype( - make_tiled_copy(Copy_Atom{}, - GmemLayoutAtom{}, - Layout>{})); // Val layout, 8 vals per store - using GmemTiledCopydKV = decltype( - make_tiled_copy(Copy_Atom{}, - GmemLayoutAtom{}, - Layout>{})); // Val layout, 8 vals per store - using GmemTiledCopydQ = decltype( - make_tiled_copy(Copy_Atom{}, - GmemLayoutAtom{}, - Layout>{})); // Val layout, 8 vals per store - using GmemLayoutAtomdQaccum = std::conditional_t< - kBlockKSmem == 32, - Layout, // Thread layout, 8 threads per row - Stride< _8, _1>>, - Layout, // Thread layout, 16 threads per row - Stride< _16, _1>> - >; - using GmemTiledCopydQaccum = decltype( - make_tiled_copy(Copy_Atom{}, - GmemLayoutAtomdQaccum{}, - Layout>{})); // Val layout, 4 vals per store - - using GmemTiledCopydQaccumAtomicAdd = decltype( - make_tiled_copy(Copy_Atom{}, - Layout, // Thread layout, 8 threads per row - Stride<_32, _1>>{}, - Layout>{})); // Val layout, 1 val per store - -}; - ////////////////////////////////////////////////////////////////////////////////////////////////////
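// The hdim256 launcher earlier in this patch picks its tile from shared-memory
// arithmetic that is easy to spell out: with 2-byte fp16/bf16 elements, a
// kBlockM x kHeadDim Q tile plus kBlockN x kHeadDim K and V tiles cost roughly
// sizeof(Element) * kHeadDim * (kBlockM + 2 * kBlockN) bytes of smem. A
// minimal compile-time check of the two choices its comments mention (the
// byte counts below restate those comments, they are not new measurements):
constexpr int smem_bytes(int head_dim, int block_m, int block_n) {
  return 2 /* sizeof(half) */ * head_dim * (block_m + 2 * block_n);
}
static_assert(smem_bytes(256, 128, 64) == 128 * 1024, "128 x 64 tile -> 128 KB, A100 choice");
static_assert(smem_bytes(256,  64, 64) ==  96 * 1024, "64 x 64 tile -> 96 KB, H100 choice");
// Two 64 x 64 CTAs per SM need 192 KB, which fits on H100 but not on A100,
// hence the max_smem_per_sm < 4 * Headdim * (64 + 2 * 64) test in that launcher.
static_assert(2 * smem_bytes(256, 64, 64) == 192 * 1024, "2 CTAs per SM -> 192 KB");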