Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
using TileShapeQK = decltype(TiledMMAQK{}.tile_mnk());
using TileShapePV = decltype(TiledMMAPV{}.tile_mnk());
static constexpr int VTiles = VTiles_;

using MmaAtomShapeQK = typename TiledMMAQK::AtomShape_MNK;
using SubgroupLayoutQK = decltype(TiledMMAQK{}.get_atom_layout_mnk());
using SGPerWG = decltype(product(take<1,4>(shape(typename TiledMMAQK::ThrLayoutVMNK{}))));

using TensorQ = TensorQ_;
Expand Down Expand Up @@ -171,8 +172,10 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
QVCoord blk_qv, // WG tile indices: (Q,V)
int blk_k0, // K block range: [K0,K1)
int blk_k1,
int thr_id) { // Work-item ID

int thr_id,
int seq_len,
int full_tile_offset,
int discard_seq_coord) {
using namespace sycl::ext::oneapi::this_work_item;

// Short dimension names:
Expand Down Expand Up @@ -219,6 +222,8 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
auto tKgK = thr_copy_k.partition_S(gK); // (atom_val,k',d',K,D)
auto tVgV = thr_copy_v.partition_S(gV_split); // (atom_val,v',k',VV,K)

auto cS_thread = thr_mma_qk.partition_C(cP);

/* Create register fragments for MMA and copies */
auto tQrQ = thr_copy_q.partition_sg_fragment_D(gQ(_,_,0));
auto tSrQ = thr_mma_qk.partition_sg_fragment_A(gQ(_,_,0));
Expand Down Expand Up @@ -266,7 +271,7 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
}

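/* Check whether the K extent is not a multiple of the QK tile's K dimension, i.e. the last K block is a partial tile that needs remainder masking */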
bool check_remainder_k = (shape<0>(K_2D) % get<1>(TileShapeQK{}) != 0);
bool check_remainder_k = (seq_len % get<1>(TileShapeQK{}) != 0);

/* Main loop, blocked in k. */
for (int K = blk_k0; K < blk_k1; K++) {
Expand All @@ -288,23 +293,34 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
/* V prefetch for GEMM 2 */
prefetch(prefetch_v, pVgV(_,_,_,K));

/* Causal masking */
if constexpr (CausalMask) {
if (K == blk_k1 - 1) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < tSrS.size(); ++i) {
// Need to get global col and row indices to mask the elements
int row_idx = get<0>(cS_thread(i)) + get<0>(blk_qv) * get<0>(TileShapeQK{});
int col_idx = get<1>(cS_thread(i)) + K * get<1>(TileShapeQK{});
if (col_idx - full_tile_offset > row_idx - discard_seq_coord) {
tSrS(i) = ElementS(-INFINITY);
}
}
}
}
/* k masking for remainder tiles */
if (check_remainder_k && K == blk_k1 - 1) {
FragSRow k_rem_mask;
int k = get<0>(tKgK(0,0,0,K,0)) + get_sub_group().get_local_id()[0];
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < k_rem_mask.size(); i++, k += intel::sg_size) {
k_rem_mask(i) = (k < shape<0>(K_2D)) ? ElementS(sycl::nan(0u)) : ElementS(-INFINITY);
k_rem_mask(i) = (k < seq_len) ? ElementS(sycl::nan(0u)) : ElementS(-INFINITY);
}
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < tSrS.size(); i++) {
tSrS(i) = sycl::fmin(tSrS(i), broadcast<1>(k_rem_mask, tSrS, i));
}
}

/* TODO: causal masking */
static_assert(!CausalMask, "Causal mask unimplemented");

/* Apply softmax and scaling */
softmax(K == 0, tSrS, tA_max, tA_sum, tArA);
#if 0
Expand Down
20 changes: 15 additions & 5 deletions applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class XeFMHAFwdKernel {
using TiledMMAPV = typename CollectiveMainloop::TiledMMAPV;
using TileShapeQK = typename CollectiveMainloop::TileShapeQK;
using TileShapePV = typename CollectiveMainloop::TileShapePV;

using SubgroupLayoutQK = typename CollectiveMainloop::SubgroupLayoutQK;
using ElementQ = typename CollectiveMainloop::TensorQ::element_type;
using ElementK = typename CollectiveMainloop::TensorK::element_type;
using ElementV = typename CollectiveMainloop::TensorV::element_type;
Expand Down Expand Up @@ -181,7 +181,8 @@ class XeFMHAFwdKernel {
int head_group_q = s.num_heads_q / s.num_heads_kv;

int thr_id = int(ThreadIdxX());

int sub_group_id = thr_id / intel::sg_size;
int SGTileQ = get<0>(shape_div(TileShapeQK{}, shape(SubgroupLayoutQK{})));
TileScheduler tile_scheduler{params.scheduler};

CUTLASS_PRAGMA_NO_UNROLL
Expand All @@ -190,7 +191,16 @@ class XeFMHAFwdKernel {
auto blk_qv = make_coord(blk_q, blk_v);
int head = head_q / head_group_q;

const int k_blocks = cute::ceil_div(s.seq_len_kv, get<1>(TileShapeQK{}));
auto offset = cute::min(s.seq_len_qo, s.seq_len_kv);
auto discard_seq_coord = s.seq_len_qo - offset;
auto full_tile_offset = s.seq_len_kv - offset;

int seq_coord = cute::min(s.seq_len_qo, (blk_q * get<0>(TileShapeQK{}) + (sub_group_id / get<1>(shape(SubgroupLayoutQK{}))) * SGTileQ));

if (CollectiveMainloop::CausalMask && seq_coord < discard_seq_coord) continue;

const int seq_len = CollectiveMainloop::CausalMask ? full_tile_offset + cute::min(s.seq_len_kv, seq_coord - discard_seq_coord) + SGTileQ : s.seq_len_kv;
const int k_blocks = cute::ceil_div(seq_len, get<1>(TileShapeQK{}));

auto shape_Q = make_shape(s.seq_len_qo, s.head_size_qk, s.num_heads_q, s.batch);
auto shape_K = make_shape(s.seq_len_kv, s.head_size_qk, s.num_heads_kv, s.batch);
Expand All @@ -217,8 +227,8 @@ class XeFMHAFwdKernel {
V(_,_,head,idx_b),
tArA, tA_max, tA_sum,
blk_qv, 0, k_blocks,
thr_id);

thr_id, seq_len,
full_tile_offset, discard_seq_coord);
if constexpr (!is_empty_v<MainloopSharedStorage> && !is_empty_v<EpilogueSharedStorage>) {
sycl::group_barrier(get_work_group<3>());
}
Expand Down
5 changes: 4 additions & 1 deletion examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,5 +160,8 @@ int main(int argc, const char **argv) {
using ElementK = bfloat16_t;
using ElementV = bfloat16_t;
#endif
return FMHAConfig<false, ShapeQK, ShapePV, ShapeOut, SubgroupLayoutQK, void, PipelineStages, ElementQ, ElementK, ElementV>::run(options);

return options.is_causal ? FMHAConfig<true, ShapeQK, ShapePV, ShapeOut, SubgroupLayoutQK, void, PipelineStages, ElementQ, ElementK, ElementV>::run(options)
: FMHAConfig<false, ShapeQK, ShapePV, ShapeOut, SubgroupLayoutQK, void, PipelineStages, ElementQ, ElementK, ElementV>::run(options);

}
2 changes: 1 addition & 1 deletion examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,7 @@ struct FMHAConfig {
TiledMMAQK, TiledMMAPV, VTiles,
TensorQ, TensorK, TensorV,
GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV
>;
>;

// Epilogue
using CollectiveEpilogue = cutlass::fmha::collective::FMHAFwdEpilogue<
Expand Down
9 changes: 8 additions & 1 deletion include/cute/atom/mma_atom.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,17 +225,24 @@ struct TiledMMA : MMA_Atom

using ThrLayoutVMNK = decltype(tiled_product(AtomThrID{}, AtomLayoutMNK{}));
ThrLayoutVMNK thr_layout_vmnk_;
AtomLayoutMNK atom_layout_mnk_;

CUTE_HOST_DEVICE constexpr
TiledMMA(MMA_Atom const& mma_atom = {}, AtomLayoutMNK const& thr_layout_mnk = {})
: MMA_Atom(mma_atom),
thr_layout_vmnk_(tiled_product(AtomThrID{}, thr_layout_mnk)) {}
thr_layout_vmnk_(tiled_product(AtomThrID{}, thr_layout_mnk)),
atom_layout_mnk_(thr_layout_mnk) {}

// Accessor for the thread layout held in thr_layout_vmnk_, which the
// constructor builds as tiled_product(AtomThrID{}, thr_layout_mnk).
// Returns the layout by value (deduced `auto` return type).
CUTE_HOST_DEVICE constexpr auto
get_thr_layout_vmnk() const {
return thr_layout_vmnk_;
}

// Accessor for the atom layout held in atom_layout_mnk_, i.e. the
// AtomLayoutMNK value that was passed to the constructor (before it is
// combined with AtomThrID into thr_layout_vmnk_).
// Returns the layout by value (deduced `auto` return type).
CUTE_HOST_DEVICE constexpr auto
get_atom_layout_mnk() const {
return atom_layout_mnk_;
}

// Tile a tensor or a layout from shape
// (M,N,...)
// to shape
Expand Down