@@ -138,6 +138,8 @@ flash_attn_mma_stages_split_q_tiling_qk_kernel(half* Q,
   int load_gmem_Q_Br = Q_tile_id * Br + load_smem_Q_Br;
   if (load_gmem_Q_Br >= QKV_seqlen) return;
   constexpr bool kIsVCanLoadIn128b = (kHeadDim / (kNumThreads / kMmaAtomK)) % 8 == 0;
+  constexpr bool kIsVCanLoadIn64b  = (kHeadDim / (kNumThreads / kMmaAtomK)) % 4 == 0;
+  static_assert(kIsVCanLoadIn128b || kIsVCanLoadIn64b, "V can't load in 128b or 64b."); // 32,64,128,192,256,...
 
   // Shared memory for Q,K,V; we do not need additional smem for O
   // collective store, which is performed via register reuse and warp shuffle.
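
The new 64-bit fallback widens the set of supported head dims (32, 64, 128, 192, 256, ...). Below is a minimal standalone sketch, not repo code, of the per-thread V load-width arithmetic that the static_assert guards; it assumes (this is an assumption, matching the launcher constants further down) that kNumThreads threads stage the V tile kMmaAtomK rows at a time, so kNumThreads / kMmaAtomK threads share one kHeadDim-wide row, and v_halfs_per_thread is a hypothetical helper name.

// Standalone sketch (hypothetical helper): reproduces only the load-width check.
#include <cstdio>

template <int kHeadDim, int kNumThreads, int kMmaAtomK = 16>
constexpr int v_halfs_per_thread() {
  // Per-thread share of one kHeadDim-wide row of V, in half elements.
  return kHeadDim / (kNumThreads / kMmaAtomK);
}

int main() {
  // d=64 with 128 threads: 64 / (128/16) = 8 halfs = 128 bits -> kIsVCanLoadIn128b holds.
  printf("d=64 : %d halfs/thread\n", v_halfs_per_thread<64, 128>());
  // d=32 with 128 threads: 32 / (128/16) = 4 halfs = 64 bits -> needs the new kIsVCanLoadIn64b path.
  printf("d=32 : %d halfs/thread\n", v_halfs_per_thread<32, 128>());
  return 0;
}
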
@@ -763,17 +765,17 @@ flash_attn_mma_stages_split_q_tiling_qk_kernel(half* Q,
 template <const int kHeadDim, const int kStage>
 void launch_flash_attn_mma_stages_split_q_tiling_qk(
   torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor O) {
-  // Now: fixed tile BrxBc=128x128
+  // Now: fixed tile BrxBc=128x128 for d>=128, 64x64 for d<128.
   // TODO: dynamic tile size for Br, Bc according to kHeadDim and shared memory size.
   constexpr int kMmaAtomM = 16;
   constexpr int kMmaAtomN = 8;
   constexpr int kMmaAtomK = 16;
-  constexpr int kMmaTileSeqLenQ = 8;
+  constexpr int kMmaTileSeqLenQ = (kHeadDim < 128) ? 4 : 8;
   constexpr int kMmaTileSeqLenK = 1;
-  constexpr int kMmaTileSeqLenP = 8;
+  constexpr int kMmaTileSeqLenP = (kHeadDim < 128) ? 4 : 8;
   constexpr int kMmaTileHeadDimV = 1;
   constexpr int kWarpTileSeqLenQ = 1;
-  constexpr int kWarpTileSeqLenK = 16;
+  constexpr int kWarpTileSeqLenK = (kHeadDim < 128) ? 8 : 16;
   constexpr int kWarpTileSeqLenP = 1;
   constexpr int kWarpTileHeadDimV = (kHeadDim / (kMmaAtomN * kMmaTileHeadDimV)); // (d=64)8,(d=128)16,32,....
   constexpr int Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ; // 16*4*1=64
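
For reference, here is a standalone sketch (not part of the patch; TileShape is a hypothetical name) that mirrors the constexpr arithmetic above and confirms the tile sizes named in the updated comment, 64x64 for d<128 and 128x128 for d>=128:

// Standalone sketch: recomputes Br = kMmaAtomM*kMmaTileSeqLenQ*kWarpTileSeqLenQ and
// Bc = kMmaAtomN*kMmaTileSeqLenK*kWarpTileSeqLenK with the same head-dim switch.
#include <cstdio>

template <int kHeadDim>
struct TileShape {
  static constexpr int kMmaAtomM = 16, kMmaAtomN = 8;
  static constexpr int kMmaTileSeqLenQ = (kHeadDim < 128) ? 4 : 8;
  static constexpr int kMmaTileSeqLenK = 1;
  static constexpr int kWarpTileSeqLenQ = 1;
  static constexpr int kWarpTileSeqLenK = (kHeadDim < 128) ? 8 : 16;
  static constexpr int Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ;
  static constexpr int Bc = kMmaAtomN * kMmaTileSeqLenK * kWarpTileSeqLenK;
};

int main() {
  // d=64  -> Br x Bc = 64 x 64   (16*4*1 x 8*1*8)
  // d=128 -> Br x Bc = 128 x 128 (16*8*1 x 8*1*16)
  printf("d=64 : %dx%d\n", TileShape<64>::Br, TileShape<64>::Bc);
  printf("d=128: %dx%d\n", TileShape<128>::Br, TileShape<128>::Bc);
  return 0;
}
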