Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
using TileShapeQK = decltype(TiledMMAQK{}.tile_mnk());
using TileShapePV = decltype(TiledMMAPV{}.tile_mnk());
static constexpr int VTiles = VTiles_;

using SubgroupLayoutQK = decltype(TiledMMAQK{}.get_atom_layout_mnk());
using SGPerWG = decltype(product(take<1,4>(shape(typename TiledMMAQK::ThrLayoutVMNK{}))));

using TensorQ = TensorQ_;
Expand Down Expand Up @@ -171,8 +171,10 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
QVCoord blk_qv, // WG tile indices: (Q,V)
int blk_k0, // K block range: [K0,K1)
int blk_k1,
int thr_id) { // Work-item ID

int thr_id,
int seq_len,
int full_tile_offset,
int discard_seq_coord) {
using namespace sycl::ext::oneapi::this_work_item;

// Short dimension names:
Expand Down Expand Up @@ -266,7 +268,7 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
}

/* Check if the K sequence length is not a multiple of the K tile size, i.e. the last tile needs remainder masking */
bool check_remainder_k = (shape<0>(K_2D) % get<1>(TileShapeQK{}) != 0);
bool check_remainder_k = (seq_len % get<1>(TileShapeQK{}) != 0);

/* Main loop, blocked in k. */
for (int K = blk_k0; K < blk_k1; K++) {
Expand All @@ -288,23 +290,37 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
/* V prefetch for GEMM 2 */
prefetch(prefetch_v, pVgV(_,_,_,K));

/* Causal masking */
if constexpr (CausalMask) {
if (K == blk_k1 - 1) {
// Need to get global col and row indices to mask the elements
Tensor cPgP = make_identity_tensor(make_shape(seq_len, seq_len));
Tensor gP = local_tile(cPgP, take<0,2>(TileShapeQK{}), make_coord(get<0>(blk_qv), K));
auto cS_thread = thr_mma_qk.partition_C(gP);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < tSrS.size(); ++i) {
int row_idx = get<0>(cS_thread(i));
int col_idx = get<1>(cS_thread(i));
if (col_idx - full_tile_offset > row_idx - discard_seq_coord) {
tSrS(i) = ElementS(-INFINITY);
}
}
}
}

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ClarkChin08 Thanks for updating this! By the way, we can make the code even cleaner by including the block offset in cS_thread itself. Something like this should do it:

Tensor gP = local_tile(cP, TileShapeQK{}, blk_qv);
auto cS_thread = thr_mma_qk.partition_C(gP);

Then you don't need to do the blocking calculations here; instead row_idx = get<0>(cS_thread(i)), col_idx = get<1>(cS_thread(i)).

Copy link
Author

@ClarkChin08 ClarkChin08 Nov 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @petercad, I changed the code to use local_tile to obtain the global row and column indices.

/* k masking for remainder tiles */
if (check_remainder_k && K == blk_k1 - 1) {
FragSRow k_rem_mask;
int k = get<0>(tKgK(0,0,0,K,0)) + get_sub_group().get_local_id()[0];
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < k_rem_mask.size(); i++, k += intel::sg_size) {
k_rem_mask(i) = (k < shape<0>(K_2D)) ? ElementS(sycl::nan(0u)) : ElementS(-INFINITY);
k_rem_mask(i) = (k < seq_len) ? ElementS(sycl::nan(0u)) : ElementS(-INFINITY);
}
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < tSrS.size(); i++) {
tSrS(i) = sycl::fmin(tSrS(i), broadcast<1>(k_rem_mask, tSrS, i));
}
}

/* TODO: causal masking */
static_assert(!CausalMask, "Causal mask unimplemented");

/* Apply softmax and scaling */
softmax(K == 0, tSrS, tA_max, tA_sum, tArA);
#if 0
Expand Down
24 changes: 20 additions & 4 deletions applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class XeFMHAFwdKernel {
using TiledMMAPV = typename CollectiveMainloop::TiledMMAPV;
using TileShapeQK = typename CollectiveMainloop::TileShapeQK;
using TileShapePV = typename CollectiveMainloop::TileShapePV;

using SubgroupLayoutQK = typename CollectiveMainloop::SubgroupLayoutQK;
using ElementQ = typename CollectiveMainloop::TensorQ::element_type;
using ElementK = typename CollectiveMainloop::TensorK::element_type;
using ElementV = typename CollectiveMainloop::TensorV::element_type;
Expand Down Expand Up @@ -181,6 +181,13 @@ class XeFMHAFwdKernel {
int head_group_q = s.num_heads_q / s.num_heads_kv;

int thr_id = int(ThreadIdxX());
int sub_group_id = thr_id / intel::sg_size;
int q_sg_tile = get<0>(shape_div(TileShapeQK{}, shape(SubgroupLayoutQK{})));

auto cS = make_identity_tensor(take<0,2>(TiledMMAQK{}.tile_mnk()));
auto tScS = TiledMMAQK{}.get_slice(thr_id).partition_C(cS);
auto q_offset_wi = get<0>(tScS(0));
auto q_offset_sg = group_broadcast(sycl::ext::oneapi::this_work_item::get_sub_group(), q_offset_wi, 0);

TileScheduler tile_scheduler{params.scheduler};

Expand All @@ -190,7 +197,16 @@ class XeFMHAFwdKernel {
auto blk_qv = make_coord(blk_q, blk_v);
int head = head_q / head_group_q;

const int k_blocks = cute::ceil_div(s.seq_len_kv, get<1>(TileShapeQK{}));
auto offset = cute::min(s.seq_len_qo, s.seq_len_kv);
auto discard_seq_coord = s.seq_len_qo - offset;
auto full_tile_offset = s.seq_len_kv - offset;

int seq_coord = cute::min(s.seq_len_qo, (blk_q * get<0>(TileShapeQK{}) + q_offset_sg));

if (CollectiveMainloop::CausalMask && seq_coord < discard_seq_coord) continue;

const int seq_len = CollectiveMainloop::CausalMask ? full_tile_offset + cute::min(s.seq_len_kv, seq_coord - discard_seq_coord) + q_sg_tile : s.seq_len_kv;
const int k_blocks = cute::ceil_div(seq_len, get<1>(TileShapeQK{}));

auto shape_Q = make_shape(s.seq_len_qo, s.head_size_qk, s.num_heads_q, s.batch);
auto shape_K = make_shape(s.seq_len_kv, s.head_size_qk, s.num_heads_kv, s.batch);
Expand All @@ -217,8 +233,8 @@ class XeFMHAFwdKernel {
V(_,_,head,idx_b),
tArA, tA_max, tA_sum,
blk_qv, 0, k_blocks,
thr_id);

thr_id, seq_len,
full_tile_offset, discard_seq_coord);
if constexpr (!is_empty_v<MainloopSharedStorage> && !is_empty_v<EpilogueSharedStorage>) {
sycl::group_barrier(get_work_group<3>());
}
Expand Down
5 changes: 4 additions & 1 deletion examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,5 +160,8 @@ int main(int argc, const char **argv) {
using ElementK = bfloat16_t;
using ElementV = bfloat16_t;
#endif
return FMHAConfig<false, ShapeQK, ShapePV, ShapeOut, SubgroupLayoutQK, void, PipelineStages, ElementQ, ElementK, ElementV>::run(options);

return options.is_causal ? FMHAConfig<true, ShapeQK, ShapePV, ShapeOut, SubgroupLayoutQK, void, PipelineStages, ElementQ, ElementK, ElementV>::run(options)
: FMHAConfig<false, ShapeQK, ShapePV, ShapeOut, SubgroupLayoutQK, void, PipelineStages, ElementQ, ElementK, ElementV>::run(options);

}
5 changes: 5 additions & 0 deletions include/cute/atom/mma_atom.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,11 @@ struct TiledMMA : MMA_Atom
return thr_layout_vmnk_;
}

// Returns the layout of MMA atoms across the (M,N,K) tile of this TiledMMA.
// The result is a default-constructed AtomLayoutMNK, i.e. purely a static
// (compile-time) layout; no per-instance state is read.
// NOTE(review): assumes AtomLayoutMNK is a stateless CuTe layout type whose
// value-initialized object fully describes the atom tiling — confirm against
// the TiledMMA template parameters declared above this hunk.
CUTE_HOST_DEVICE constexpr auto
get_atom_layout_mnk() const {
return AtomLayoutMNK{};
}

// Tile a tensor or a layout from shape
// (M,N,...)
// to shape
Expand Down