From a1346c5894d3ca9232898fa1f9295cee11914085 Mon Sep 17 00:00:00 2001 From: "Chen, Xi2" Date: Mon, 3 Nov 2025 09:21:43 +0000 Subject: [PATCH 1/5] add CausalMask support with new flash attention api Signed-off-by: Chen, Xi2 --- .../collective/xe_fmha_fwd_mainloop.hpp | 44 ++++++++++++++----- .../kernel/xe_fhma_fwd_kernel.hpp | 20 ++++++--- .../06_bmg_flash_attention/06_xe_fmha_fwd.cpp | 6 ++- .../xe_fmha_fwd_runner.hpp | 3 +- 4 files changed, 55 insertions(+), 18 deletions(-) diff --git a/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp b/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp index b2c802da4b..dd93e30814 100644 --- a/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp +++ b/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp @@ -62,7 +62,8 @@ template // Optional TiledCopy for loading V + class TiledCopyV_ = void, // Optional TiledCopy for loading V + class SubgroupLayoutQK_ = void> // Optional SubgroupLayout for QK struct FMHAFwdMainloop { static_assert(cutlass::detail::dependent_false, "Could not find a mainloop specialization."); }; @@ -73,11 +74,13 @@ template + class TiledCopyQ_, class TiledCopyK_, class TiledCopyV_, + class SubgroupLayoutQK_> struct FMHAFwdMainloop, CausalMask_, TiledMMAQK_, TiledMMAPV_, VTiles_, TensorQ_, TensorK_, TensorV_, - TiledCopyQ_, TiledCopyK_, TiledCopyV_> { + TiledCopyQ_, TiledCopyK_, TiledCopyV_, + SubgroupLayoutQK_> { // // Type Aliases // @@ -86,7 +89,8 @@ struct FMHAFwdMainloop, CausalMask_, using TileShapeQK = decltype(TiledMMAQK{}.tile_mnk()); using TileShapePV = decltype(TiledMMAPV{}.tile_mnk()); static constexpr int VTiles = VTiles_; - + using MmaAtomShapeQK = typename TiledMMAQK::AtomShape_MNK; + using SubgroupLayoutQK = SubgroupLayoutQK_; using SGPerWG = decltype(product(take<1,4>(shape(typename TiledMMAQK::ThrLayoutVMNK{})))); using TensorQ = TensorQ_; @@ -171,8 +175,11 @@ struct FMHAFwdMainloop, CausalMask_, QVCoord blk_qv, // WG tile indices: (Q,V) int blk_k0, // K block range: [K0,K1) int blk_k1, - int thr_id) { // Work-item ID - + int thr_id, + int seq_len, + int seq_coord, + int full_tile_offset, + int discard_seq_coord) { using namespace sycl::ext::oneapi::this_work_item; // Short dimension names: @@ -266,7 +273,7 @@ struct FMHAFwdMainloop, CausalMask_, } /* Check if */ - bool check_remainder_k = (shape<0>(K_2D) % get<1>(TileShapeQK{}) != 0); + bool check_remainder_k = (seq_len % get<1>(TileShapeQK{}) != 0); /* Main loop, blocked in k. */ for (int K = blk_k0; K < blk_k1; K++) { @@ -288,13 +295,31 @@ struct FMHAFwdMainloop, CausalMask_, /* V prefetch for GEMM 2 */ prefetch(prefetch_v, pVgV(_,_,_,K)); + /* Causal masking */ + if constexpr (CausalMask) { + if (K == blk_k1 - 1) { + int item_id = get_sub_group().get_local_id()[0]; + int base_col = item_id + K * get<1>(TileShapeQK{}); + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < shape<2>(tSrS.shape()); ++n) { + int col_idx = base_col + n * get<1>(MmaAtomShapeQK()); + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < shape<0>(tSrS.shape()); ++m) { + int row_idx = seq_coord + m; + if (col_idx - full_tile_offset > row_idx - discard_seq_coord) { + tSrS(m, 0, n) = ElementS(-INFINITY); + } + } + } + } + } /* k masking for remainder tiles */ if (check_remainder_k && K == blk_k1 - 1) { FragSRow k_rem_mask; int k = get<0>(tKgK(0,0,0,K,0)) + get_sub_group().get_local_id()[0]; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < k_rem_mask.size(); i++, k += intel::sg_size) { - k_rem_mask(i) = (k < shape<0>(K_2D)) ? 
ElementS(sycl::nan(0u)) : ElementS(-INFINITY); + k_rem_mask(i) = (k < seq_len) ? ElementS(sycl::nan(0u)) : ElementS(-INFINITY); } CUTLASS_PRAGMA_UNROLL for (int i = 0; i < tSrS.size(); i++) { @@ -302,9 +327,6 @@ struct FMHAFwdMainloop, CausalMask_, } } - /* TODO: causal masking */ - static_assert(!CausalMask, "Causal mask unimplemented"); - /* Apply softmax and scaling */ softmax(K == 0, tSrS, tA_max, tA_sum, tArA); #if 0 diff --git a/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp b/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp index fced70ee84..e5f3d84a3e 100644 --- a/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp +++ b/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp @@ -72,7 +72,7 @@ class XeFMHAFwdKernel { using TiledMMAPV = typename CollectiveMainloop::TiledMMAPV; using TileShapeQK = typename CollectiveMainloop::TileShapeQK; using TileShapePV = typename CollectiveMainloop::TileShapePV; - + using SubgroupLayoutQK = typename CollectiveMainloop::SubgroupLayoutQK; using ElementQ = typename CollectiveMainloop::TensorQ::element_type; using ElementK = typename CollectiveMainloop::TensorK::element_type; using ElementV = typename CollectiveMainloop::TensorV::element_type; @@ -181,7 +181,8 @@ class XeFMHAFwdKernel { int head_group_q = s.num_heads_q / s.num_heads_kv; int thr_id = int(ThreadIdxX()); - + int sub_group_id = thr_id / intel::sg_size; + int SGTileQ = get<0>(shape_div(TileShapeQK{}, shape(SubgroupLayoutQK{}))); TileScheduler tile_scheduler{params.scheduler}; CUTLASS_PRAGMA_NO_UNROLL @@ -190,7 +191,16 @@ class XeFMHAFwdKernel { auto blk_qv = make_coord(blk_q, blk_v); int head = head_q / head_group_q; - const int k_blocks = cute::ceil_div(s.seq_len_kv, get<1>(TileShapeQK{})); + auto offset = cute::min(s.seq_len_qo, s.seq_len_kv); + auto discard_seq_coord = s.seq_len_qo - offset; + auto full_tile_offset = s.seq_len_kv - offset; + + int seq_coord = cute::min(s.seq_len_qo, (blk_q * get<0>(TileShapeQK{}) + (sub_group_id / get<1>(shape(SubgroupLayoutQK{}))) * SGTileQ)); + + if (CollectiveMainloop::CausalMask && seq_coord < discard_seq_coord) continue; + + const int seq_len = CollectiveMainloop::CausalMask ? full_tile_offset + cute::min(s.seq_len_kv, seq_coord - discard_seq_coord) + SGTileQ : s.seq_len_kv; + const int k_blocks = cute::ceil_div(seq_len, get<1>(TileShapeQK{})); auto shape_Q = make_shape(s.seq_len_qo, s.head_size_qk, s.num_heads_q, s.batch); auto shape_K = make_shape(s.seq_len_kv, s.head_size_qk, s.num_heads_kv, s.batch); @@ -217,8 +227,8 @@ class XeFMHAFwdKernel { V(_,_,head,idx_b), tArA, tA_max, tA_sum, blk_qv, 0, k_blocks, - thr_id); - + thr_id, seq_len, + seq_coord, full_tile_offset, discard_seq_coord); if constexpr (!is_empty_v && !is_empty_v) { sycl::group_barrier(get_work_group<3>()); } diff --git a/examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp b/examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp index d21beffaf4..c50717ea1d 100644 --- a/examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp +++ b/examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp @@ -147,6 +147,7 @@ int main(int argc, const char **argv) { #else constexpr int PipelineStages = 2; #endif + #ifdef IS_FLOAT_E5M2 using ElementQ = cutlass::float_e5m2_t; using ElementK = cutlass::float_e5m2_t; @@ -160,5 +161,8 @@ int main(int argc, const char **argv) { using ElementK = bfloat16_t; using ElementV = bfloat16_t; #endif - return FMHAConfig::run(options); + + return options.is_causal ? 
FMHAConfig::run(options) + : FMHAConfig::run(options); + } diff --git a/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp b/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp index 64c37ca4c7..25e147c954 100644 --- a/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp +++ b/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp @@ -576,7 +576,8 @@ struct FMHAConfig { MainloopDispatchPolicy, Causal, TiledMMAQK, TiledMMAPV, VTiles, TensorQ, TensorK, TensorV, - GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV + GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, + SubgroupLayoutQK >; // Epilogue From f8a05144b8a109617c24f554aa07bbebe81475d5 Mon Sep 17 00:00:00 2001 From: "Chen, Xi2" Date: Mon, 10 Nov 2025 08:07:54 +0000 Subject: [PATCH 2/5] refine causal mask in new fa Signed-off-by: Chen, Xi2 --- .../collective/xe_fmha_fwd_mainloop.hpp | 29 ++++++++----------- .../kernel/xe_fhma_fwd_kernel.hpp | 2 +- .../xe_fmha_fwd_runner.hpp | 4 +-- include/cute/atom/mma_atom.hpp | 9 +++++- 4 files changed, 22 insertions(+), 22 deletions(-) diff --git a/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp b/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp index dd93e30814..68aec350e4 100644 --- a/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp +++ b/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp @@ -74,13 +74,11 @@ template + class TiledCopyQ_, class TiledCopyK_, class TiledCopyV_> struct FMHAFwdMainloop, CausalMask_, TiledMMAQK_, TiledMMAPV_, VTiles_, TensorQ_, TensorK_, TensorV_, - TiledCopyQ_, TiledCopyK_, TiledCopyV_, - SubgroupLayoutQK_> { + TiledCopyQ_, TiledCopyK_, TiledCopyV_> { // // Type Aliases // @@ -90,7 +88,7 @@ struct FMHAFwdMainloop, CausalMask_, using TileShapePV = decltype(TiledMMAPV{}.tile_mnk()); static constexpr int VTiles = VTiles_; using MmaAtomShapeQK = typename TiledMMAQK::AtomShape_MNK; - using SubgroupLayoutQK = SubgroupLayoutQK_; + using SubgroupLayoutQK = decltype(TiledMMAQK{}.get_atom_layout_mnk()); using SGPerWG = decltype(product(take<1,4>(shape(typename TiledMMAQK::ThrLayoutVMNK{})))); using TensorQ = TensorQ_; @@ -177,7 +175,6 @@ struct FMHAFwdMainloop, CausalMask_, int blk_k1, int thr_id, int seq_len, - int seq_coord, int full_tile_offset, int discard_seq_coord) { using namespace sycl::ext::oneapi::this_work_item; @@ -198,7 +195,7 @@ struct FMHAFwdMainloop, CausalMask_, Tensor cK = make_identity_tensor(K_2D.shape()); // (k,d) Tensor cV = make_identity_tensor(V_2D.shape()); // (v,k) Tensor cP = make_identity_tensor(take<0,2>(TileShapeQK{})); // (q,k) - + /* Partition global tensors into workgroup tiles */ Tensor gQ = local_tile(cQ, TileShapeQK{}, append(blk_qv,_), Step<_1,X,_1>{}); // (q,d,D) Tensor gK = local_tile(cK, TileShapeQK{}, make_coord(_,_,_), Step{}); // (k,d,K,D) @@ -226,6 +223,8 @@ struct FMHAFwdMainloop, CausalMask_, auto tKgK = thr_copy_k.partition_S(gK); // (atom_val,k',d',K,D) auto tVgV = thr_copy_v.partition_S(gV_split); // (atom_val,v',k',VV,K) + auto cS_thread = thr_mma_qk.partition_C(cP); + /* Create register fragments for MMA and copies */ auto tQrQ = thr_copy_q.partition_sg_fragment_D(gQ(_,_,0)); auto tSrQ = thr_mma_qk.partition_sg_fragment_A(gQ(_,_,0)); @@ -298,17 +297,13 @@ struct FMHAFwdMainloop, CausalMask_, /* Causal masking */ if constexpr (CausalMask) { if (K == blk_k1 - 1) { - int item_id = get_sub_group().get_local_id()[0]; - int base_col = item_id + K * get<1>(TileShapeQK{}); CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < shape<2>(tSrS.shape()); ++n) { - int 
col_idx = base_col + n * get<1>(MmaAtomShapeQK()); - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < shape<0>(tSrS.shape()); ++m) { - int row_idx = seq_coord + m; - if (col_idx - full_tile_offset > row_idx - discard_seq_coord) { - tSrS(m, 0, n) = ElementS(-INFINITY); - } + for (int i = 0; i < tSrS.size(); ++i) { + // Need to get global col and row indices to mask the elements + int row_idx = get<0>(cS_thread(i)) + get<0>(blk_qv) * get<0>(TileShapeQK{}); + int col_idx = get<1>(cS_thread(i)) + K * get<1>(TileShapeQK{}); + if (col_idx - full_tile_offset > row_idx - discard_seq_coord) { + tSrS(i) = ElementS(-INFINITY); } } } diff --git a/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp b/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp index e5f3d84a3e..56ad28ab9f 100644 --- a/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp +++ b/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp @@ -228,7 +228,7 @@ class XeFMHAFwdKernel { tArA, tA_max, tA_sum, blk_qv, 0, k_blocks, thr_id, seq_len, - seq_coord, full_tile_offset, discard_seq_coord); + full_tile_offset, discard_seq_coord); if constexpr (!is_empty_v && !is_empty_v) { sycl::group_barrier(get_work_group<3>()); } diff --git a/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp b/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp index 25e147c954..11c39a1d78 100644 --- a/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp +++ b/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp @@ -576,9 +576,7 @@ struct FMHAConfig { MainloopDispatchPolicy, Causal, TiledMMAQK, TiledMMAPV, VTiles, TensorQ, TensorK, TensorV, - GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, - SubgroupLayoutQK - >; + GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV>; // Epilogue using CollectiveEpilogue = cutlass::fmha::collective::FMHAFwdEpilogue< diff --git a/include/cute/atom/mma_atom.hpp b/include/cute/atom/mma_atom.hpp index 52044a3d20..a8de3fb701 100644 --- a/include/cute/atom/mma_atom.hpp +++ b/include/cute/atom/mma_atom.hpp @@ -225,17 +225,24 @@ struct TiledMMA : MMA_Atom using ThrLayoutVMNK = decltype(tiled_product(AtomThrID{}, AtomLayoutMNK{})); ThrLayoutVMNK thr_layout_vmnk_; + AtomLayoutMNK atom_layout_mnk_; CUTE_HOST_DEVICE constexpr TiledMMA(MMA_Atom const& mma_atom = {}, AtomLayoutMNK const& thr_layout_mnk = {}) : MMA_Atom(mma_atom), - thr_layout_vmnk_(tiled_product(AtomThrID{}, thr_layout_mnk)) {} + thr_layout_vmnk_(tiled_product(AtomThrID{}, thr_layout_mnk)), + atom_layout_mnk_(thr_layout_mnk) {} CUTE_HOST_DEVICE constexpr auto get_thr_layout_vmnk() const { return thr_layout_vmnk_; } + CUTE_HOST_DEVICE constexpr auto + get_atom_layout_mnk() const { + return atom_layout_mnk_; + } + // Tile a tensor or a layout from shape // (M,N,...) 
// to shape From 21a1bceb06b4f5e327b0efc948e9e8ca3df34f06 Mon Sep 17 00:00:00 2001 From: "Chen, Xi2" Date: Mon, 10 Nov 2025 08:11:42 +0000 Subject: [PATCH 3/5] fix the template args Signed-off-by: Chen, Xi2 --- .../flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp | 5 ++--- examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp | 1 - examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp | 3 ++- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp b/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp index 68aec350e4..85e99b64d0 100644 --- a/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp +++ b/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp @@ -62,8 +62,7 @@ template // Optional SubgroupLayout for QK + class TiledCopyV_ = void> // Optional TiledCopy for loading V struct FMHAFwdMainloop { static_assert(cutlass::detail::dependent_false, "Could not find a mainloop specialization."); }; @@ -195,7 +194,7 @@ struct FMHAFwdMainloop, CausalMask_, Tensor cK = make_identity_tensor(K_2D.shape()); // (k,d) Tensor cV = make_identity_tensor(V_2D.shape()); // (v,k) Tensor cP = make_identity_tensor(take<0,2>(TileShapeQK{})); // (q,k) - + /* Partition global tensors into workgroup tiles */ Tensor gQ = local_tile(cQ, TileShapeQK{}, append(blk_qv,_), Step<_1,X,_1>{}); // (q,d,D) Tensor gK = local_tile(cK, TileShapeQK{}, make_coord(_,_,_), Step{}); // (k,d,K,D) diff --git a/examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp b/examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp index c50717ea1d..e5f36e1589 100644 --- a/examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp +++ b/examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp @@ -147,7 +147,6 @@ int main(int argc, const char **argv) { #else constexpr int PipelineStages = 2; #endif - #ifdef IS_FLOAT_E5M2 using ElementQ = cutlass::float_e5m2_t; using ElementK = cutlass::float_e5m2_t; diff --git a/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp b/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp index 11c39a1d78..63d2642fc3 100644 --- a/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp +++ b/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp @@ -576,7 +576,8 @@ struct FMHAConfig { MainloopDispatchPolicy, Causal, TiledMMAQK, TiledMMAPV, VTiles, TensorQ, TensorK, TensorV, - GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV>; + GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV + >; // Epilogue using CollectiveEpilogue = cutlass::fmha::collective::FMHAFwdEpilogue< From 76cd076030b55f12bb1b62872e000755e1fc26ba Mon Sep 17 00:00:00 2001 From: "Chen, Xi2" Date: Tue, 11 Nov 2025 08:45:11 +0000 Subject: [PATCH 4/5] refine the code Signed-off-by: Chen, Xi2 --- .../collective/xe_fmha_fwd_mainloop.hpp | 13 +++++++------ .../kernel/xe_fhma_fwd_kernel.hpp | 12 +++++++++--- .../06_bmg_flash_attention/xe_fmha_fwd_runner.hpp | 2 +- include/cute/atom/mma_atom.hpp | 6 ++---- 4 files changed, 19 insertions(+), 14 deletions(-) diff --git a/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp b/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp index 85e99b64d0..aab4c93696 100644 --- a/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp +++ b/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp @@ -86,7 +86,6 @@ struct FMHAFwdMainloop, CausalMask_, using TileShapeQK = decltype(TiledMMAQK{}.tile_mnk()); using TileShapePV = decltype(TiledMMAPV{}.tile_mnk()); static constexpr int 
VTiles = VTiles_; - using MmaAtomShapeQK = typename TiledMMAQK::AtomShape_MNK; using SubgroupLayoutQK = decltype(TiledMMAQK{}.get_atom_layout_mnk()); using SGPerWG = decltype(product(take<1,4>(shape(typename TiledMMAQK::ThrLayoutVMNK{})))); @@ -195,6 +194,7 @@ struct FMHAFwdMainloop, CausalMask_, Tensor cV = make_identity_tensor(V_2D.shape()); // (v,k) Tensor cP = make_identity_tensor(take<0,2>(TileShapeQK{})); // (q,k) + /* Partition global tensors into workgroup tiles */ Tensor gQ = local_tile(cQ, TileShapeQK{}, append(blk_qv,_), Step<_1,X,_1>{}); // (q,d,D) Tensor gK = local_tile(cK, TileShapeQK{}, make_coord(_,_,_), Step{}); // (k,d,K,D) @@ -222,8 +222,6 @@ struct FMHAFwdMainloop, CausalMask_, auto tKgK = thr_copy_k.partition_S(gK); // (atom_val,k',d',K,D) auto tVgV = thr_copy_v.partition_S(gV_split); // (atom_val,v',k',VV,K) - auto cS_thread = thr_mma_qk.partition_C(cP); - /* Create register fragments for MMA and copies */ auto tQrQ = thr_copy_q.partition_sg_fragment_D(gQ(_,_,0)); auto tSrQ = thr_mma_qk.partition_sg_fragment_A(gQ(_,_,0)); @@ -296,11 +294,14 @@ struct FMHAFwdMainloop, CausalMask_, /* Causal masking */ if constexpr (CausalMask) { if (K == blk_k1 - 1) { + // Need to get global col and row indices to mask the elements + Tensor cPgP = make_identity_tensor(make_shape(seq_len, seq_len)); + Tensor gP = local_tile(cPgP, take<0,2>(TileShapeQK{}), make_coord(get<0>(blk_qv), K)); + auto cS_thread = thr_mma_qk.partition_C(gP); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < tSrS.size(); ++i) { - // Need to get global col and row indices to mask the elements - int row_idx = get<0>(cS_thread(i)) + get<0>(blk_qv) * get<0>(TileShapeQK{}); - int col_idx = get<1>(cS_thread(i)) + K * get<1>(TileShapeQK{}); + int row_idx = get<0>(cS_thread(i)); + int col_idx = get<1>(cS_thread(i)); if (col_idx - full_tile_offset > row_idx - discard_seq_coord) { tSrS(i) = ElementS(-INFINITY); } diff --git a/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp b/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp index 56ad28ab9f..d91df21dd9 100644 --- a/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp +++ b/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp @@ -182,7 +182,13 @@ class XeFMHAFwdKernel { int thr_id = int(ThreadIdxX()); int sub_group_id = thr_id / intel::sg_size; - int SGTileQ = get<0>(shape_div(TileShapeQK{}, shape(SubgroupLayoutQK{}))); + int q_sg_tile = get<0>(shape_div(TileShapeQK{}, shape(SubgroupLayoutQK{}))); + + auto cS = make_identity_tensor(take<0,2>(TiledMMAQK{}.tile_mnk())); + auto tScS = TiledMMAQK{}.get_slice(thr_id).partition_C(cS); + auto q_offset_wi = get<0>(tScS(0)); + auto q_offset_sg = group_broadcast(sycl::ext::oneapi::this_work_item::get_sub_group(), q_offset_wi, 0); + TileScheduler tile_scheduler{params.scheduler}; CUTLASS_PRAGMA_NO_UNROLL @@ -195,11 +201,11 @@ class XeFMHAFwdKernel { auto discard_seq_coord = s.seq_len_qo - offset; auto full_tile_offset = s.seq_len_kv - offset; - int seq_coord = cute::min(s.seq_len_qo, (blk_q * get<0>(TileShapeQK{}) + (sub_group_id / get<1>(shape(SubgroupLayoutQK{}))) * SGTileQ)); + int seq_coord = cute::min(s.seq_len_qo, (blk_q * get<0>(TileShapeQK{}) + q_offset_sg)); if (CollectiveMainloop::CausalMask && seq_coord < discard_seq_coord) continue; - const int seq_len = CollectiveMainloop::CausalMask ? full_tile_offset + cute::min(s.seq_len_kv, seq_coord - discard_seq_coord) + SGTileQ : s.seq_len_kv; + const int seq_len = CollectiveMainloop::CausalMask ? 
full_tile_offset + cute::min(s.seq_len_kv, seq_coord - discard_seq_coord) + q_sg_tile : s.seq_len_kv; const int k_blocks = cute::ceil_div(seq_len, get<1>(TileShapeQK{})); auto shape_Q = make_shape(s.seq_len_qo, s.head_size_qk, s.num_heads_q, s.batch); diff --git a/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp b/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp index 63d2642fc3..64c37ca4c7 100644 --- a/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp +++ b/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp @@ -577,7 +577,7 @@ struct FMHAConfig { TiledMMAQK, TiledMMAPV, VTiles, TensorQ, TensorK, TensorV, GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV - >; + >; // Epilogue using CollectiveEpilogue = cutlass::fmha::collective::FMHAFwdEpilogue< diff --git a/include/cute/atom/mma_atom.hpp b/include/cute/atom/mma_atom.hpp index a8de3fb701..4909aaae98 100644 --- a/include/cute/atom/mma_atom.hpp +++ b/include/cute/atom/mma_atom.hpp @@ -225,13 +225,11 @@ struct TiledMMA : MMA_Atom using ThrLayoutVMNK = decltype(tiled_product(AtomThrID{}, AtomLayoutMNK{})); ThrLayoutVMNK thr_layout_vmnk_; - AtomLayoutMNK atom_layout_mnk_; CUTE_HOST_DEVICE constexpr TiledMMA(MMA_Atom const& mma_atom = {}, AtomLayoutMNK const& thr_layout_mnk = {}) : MMA_Atom(mma_atom), - thr_layout_vmnk_(tiled_product(AtomThrID{}, thr_layout_mnk)), - atom_layout_mnk_(thr_layout_mnk) {} + thr_layout_vmnk_(tiled_product(AtomThrID{}, thr_layout_mnk)) {} CUTE_HOST_DEVICE constexpr auto get_thr_layout_vmnk() const { @@ -240,7 +238,7 @@ struct TiledMMA : MMA_Atom CUTE_HOST_DEVICE constexpr auto get_atom_layout_mnk() const { - return atom_layout_mnk_; + return AtomLayoutMNK{}; } // Tile a tensor or a layout from shape From 92bd4cb247b6cca455771b7dfccf3eec52a14dbc Mon Sep 17 00:00:00 2001 From: "Chen, Xi2" Date: Tue, 11 Nov 2025 08:46:50 +0000 Subject: [PATCH 5/5] fix misc Signed-off-by: Chen, Xi2 --- .../flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp b/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp index aab4c93696..eec7e3525c 100644 --- a/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp +++ b/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp @@ -194,7 +194,6 @@ struct FMHAFwdMainloop, CausalMask_, Tensor cV = make_identity_tensor(V_2D.shape()); // (v,k) Tensor cP = make_identity_tensor(take<0,2>(TileShapeQK{})); // (q,k) - /* Partition global tensors into workgroup tiles */ Tensor gQ = local_tile(cQ, TileShapeQK{}, append(blk_qv,_), Step<_1,X,_1>{}); // (q,d,D) Tensor gK = local_tile(cK, TileShapeQK{}, make_coord(_,_,_), Step{}); // (k,d,K,D)
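
The masking rule applied in the final revision of xe_fmha_fwd_mainloop.hpp is a bottom-right-aligned causal mask. For reference, a minimal standalone sketch of that predicate follows (not part of the patches); the helper name is_causally_masked and the use of plain ints are illustrative assumptions, and only the comparison against full_tile_offset and discard_seq_coord mirrors the patch.

// Sketch: bottom-right-aligned causal mask, scalar reference.
#include <algorithm>
#include <cstdio>

// True if the score at (row, col) must be replaced with -infinity.
// row indexes the query sequence [0, seq_len_qo), col the key sequence [0, seq_len_kv).
bool is_causally_masked(int row, int col, int seq_len_qo, int seq_len_kv) {
  int offset            = std::min(seq_len_qo, seq_len_kv);
  int discard_seq_coord = seq_len_qo - offset;  // leading query rows that see no keys
  int full_tile_offset  = seq_len_kv - offset;  // key prefix visible to every unmasked query row
  return (col - full_tile_offset) > (row - discard_seq_coord);
}

int main() {
  int seq_len_qo = 4, seq_len_kv = 6;
  for (int row = 0; row < seq_len_qo; ++row) {
    for (int col = 0; col < seq_len_kv; ++col)
      std::printf("%c ", is_causally_masked(row, col, seq_len_qo, seq_len_kv) ? '.' : 'x');
    std::printf("\n");
  }
  return 0;
}

With seq_len_qo = 4 and seq_len_kv = 6 the first query row attends to the first three keys and the last row attends to all six, i.e. the causal diagonal ends in the bottom-right corner of the score matrix, which is the same alignment the in-kernel predicate produces.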
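
The kernel-side changes in xe_fhma_fwd_kernel.hpp pair with that mask by trimming the k-block loop per subgroup: key blocks that lie entirely past the causal diagonal are never visited, and work-group tiles whose query rows are all masked are skipped before any loads. Below is a host-side sketch of that computation, assuming the same quantities the kernel uses (blk_q, the TileShapeQK extents, q_offset_sg, q_sg_tile); the helper name trimmed_k_blocks is illustrative.

// Sketch: per-subgroup k-block trimming under a causal mask.
#include <algorithm>

static int ceil_div(int a, int b) { return (a + b - 1) / b; }

// Number of key blocks a subgroup visits; 0 means the kernel `continue`s past the tile.
int trimmed_k_blocks(bool causal, int seq_len_qo, int seq_len_kv,
                     int blk_q,        // work-group tile index along the query dimension
                     int tile_q,       // get<0>(TileShapeQK{})
                     int tile_k,       // get<1>(TileShapeQK{})
                     int q_offset_sg,  // first query row owned by this subgroup within the tile
                     int q_sg_tile) {  // query rows per subgroup
  if (!causal) return ceil_div(seq_len_kv, tile_k);

  int offset            = std::min(seq_len_qo, seq_len_kv);
  int discard_seq_coord = seq_len_qo - offset;
  int full_tile_offset  = seq_len_kv - offset;

  int seq_coord = std::min(seq_len_qo, blk_q * tile_q + q_offset_sg);
  if (seq_coord < discard_seq_coord) return 0;  // every row of this subgroup is masked

  // Last key column any row of this subgroup may attend to, rounded up to whole blocks.
  int seq_len = full_tile_offset
              + std::min(seq_len_kv, seq_coord - discard_seq_coord)
              + q_sg_tile;
  return ceil_div(seq_len, tile_k);
}

The element-wise mask in the mainloop is applied only to the last visited block (K == blk_k1 - 1), the one that straddles the causal diagonal for the subgroup's query rows.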