Skip to content

Commit fb8c97c

Browse files
authored
Add CausalMask support with new flash attention api (#604)
Signed-off-by: Chen, Xi2 <[email protected]>
1 parent e2fde37 commit fb8c97c

File tree

4 files changed

+53
-13
lines changed

4 files changed

+53
-13
lines changed

applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp

Lines changed: 24 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -86,7 +86,7 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
8686
using TileShapeQK = decltype(TiledMMAQK{}.tile_mnk());
8787
using TileShapePV = decltype(TiledMMAPV{}.tile_mnk());
8888
static constexpr int VTiles = VTiles_;
89-
89+
using SubgroupLayoutQK = decltype(TiledMMAQK{}.get_atom_layout_mnk());
9090
using SGPerWG = decltype(product(take<1,4>(shape(typename TiledMMAQK::ThrLayoutVMNK{}))));
9191

9292
using TensorQ = TensorQ_;
@@ -171,8 +171,10 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
171171
QVCoord blk_qv, // WG tile indices: (Q,V)
172172
int blk_k0, // K block range: [K0,K1)
173173
int blk_k1,
174-
int thr_id) { // Work-item ID
175-
174+
int thr_id,
175+
int seq_len,
176+
int full_tile_offset,
177+
int discard_seq_coord) {
176178
using namespace sycl::ext::oneapi::this_work_item;
177179

178180
// Short dimension names:
@@ -266,7 +268,7 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
266268
}
267269

268270
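/* Check if the K length is not a multiple of the QK tile's N extent, i.e. the last K tile is a partial (remainder) tile */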
269-
bool check_remainder_k = (shape<0>(K_2D) % get<1>(TileShapeQK{}) != 0);
271+
bool check_remainder_k = (seq_len % get<1>(TileShapeQK{}) != 0);
270272

271273
/* Main loop, blocked in k. */
272274
for (int K = blk_k0; K < blk_k1; K++) {
@@ -288,23 +290,37 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
288290
/* V prefetch for GEMM 2 */
289291
prefetch(prefetch_v, pVgV(_,_,_,K));
290292

293+
/* Causal masking */
294+
if constexpr (CausalMask) {
295+
if (K == blk_k1 - 1) {
296+
// Need to get global col and row indices to mask the elements
297+
Tensor cPgP = make_identity_tensor(make_shape(seq_len, seq_len));
298+
Tensor gP = local_tile(cPgP, take<0,2>(TileShapeQK{}), make_coord(get<0>(blk_qv), K));
299+
auto cS_thread = thr_mma_qk.partition_C(gP);
300+
CUTLASS_PRAGMA_UNROLL
301+
for (int i = 0; i < tSrS.size(); ++i) {
302+
int row_idx = get<0>(cS_thread(i));
303+
int col_idx = get<1>(cS_thread(i));
304+
if (col_idx - full_tile_offset > row_idx - discard_seq_coord) {
305+
tSrS(i) = ElementS(-INFINITY);
306+
}
307+
}
308+
}
309+
}
291310
/* k masking for remainder tiles */
292311
if (check_remainder_k && K == blk_k1 - 1) {
293312
FragSRow k_rem_mask;
294313
int k = get<0>(tKgK(0,0,0,K,0)) + get_sub_group().get_local_id()[0];
295314
CUTLASS_PRAGMA_UNROLL
296315
for (int i = 0; i < k_rem_mask.size(); i++, k += intel::sg_size) {
297-
k_rem_mask(i) = (k < shape<0>(K_2D)) ? ElementS(sycl::nan(0u)) : ElementS(-INFINITY);
316+
k_rem_mask(i) = (k < seq_len) ? ElementS(sycl::nan(0u)) : ElementS(-INFINITY);
298317
}
299318
CUTLASS_PRAGMA_UNROLL
300319
for (int i = 0; i < tSrS.size(); i++) {
301320
tSrS(i) = sycl::fmin(tSrS(i), broadcast<1>(k_rem_mask, tSrS, i));
302321
}
303322
}
304323

305-
/* TODO: causal masking */
306-
static_assert(!CausalMask, "Causal mask unimplemented");
307-
308324
/* Apply softmax and scaling */
309325
softmax(K == 0, tSrS, tA_max, tA_sum, tArA);
310326
reorder(tSrS, tArP);

applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp

Lines changed: 20 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -72,7 +72,7 @@ class XeFMHAFwdKernel {
7272
using TiledMMAPV = typename CollectiveMainloop::TiledMMAPV;
7373
using TileShapeQK = typename CollectiveMainloop::TileShapeQK;
7474
using TileShapePV = typename CollectiveMainloop::TileShapePV;
75-
75+
using SubgroupLayoutQK = typename CollectiveMainloop::SubgroupLayoutQK;
7676
using ElementQ = typename CollectiveMainloop::TensorQ::element_type;
7777
using ElementK = typename CollectiveMainloop::TensorK::element_type;
7878
using ElementV = typename CollectiveMainloop::TensorV::element_type;
@@ -181,6 +181,13 @@ class XeFMHAFwdKernel {
181181
int head_group_q = s.num_heads_q / s.num_heads_kv;
182182

183183
int thr_id = int(ThreadIdxX());
184+
int sub_group_id = thr_id / intel::sg_size;
185+
int q_sg_tile = get<0>(shape_div(TileShapeQK{}, shape(SubgroupLayoutQK{})));
186+
187+
auto cS = make_identity_tensor(take<0,2>(TiledMMAQK{}.tile_mnk()));
188+
auto tScS = TiledMMAQK{}.get_slice(thr_id).partition_C(cS);
189+
auto q_offset_wi = get<0>(tScS(0));
190+
auto q_offset_sg = group_broadcast(sycl::ext::oneapi::this_work_item::get_sub_group(), q_offset_wi, 0);
184191

185192
TileScheduler tile_scheduler{params.scheduler};
186193

@@ -190,7 +197,16 @@ class XeFMHAFwdKernel {
190197
auto blk_qv = make_coord(blk_q, blk_v);
191198
int head = head_q / head_group_q;
192199

193-
const int k_blocks = cute::ceil_div(s.seq_len_kv, get<1>(TileShapeQK{}));
200+
auto offset = cute::min(s.seq_len_qo, s.seq_len_kv);
201+
auto discard_seq_coord = s.seq_len_qo - offset;
202+
auto full_tile_offset = s.seq_len_kv - offset;
203+
204+
int seq_coord = cute::min(s.seq_len_qo, (blk_q * get<0>(TileShapeQK{}) + q_offset_sg));
205+
206+
if (CollectiveMainloop::CausalMask && seq_coord < discard_seq_coord) continue;
207+
208+
const int seq_len = CollectiveMainloop::CausalMask ? full_tile_offset + cute::min(s.seq_len_kv, seq_coord - discard_seq_coord) + q_sg_tile : s.seq_len_kv;
209+
const int k_blocks = cute::ceil_div(seq_len, get<1>(TileShapeQK{}));
194210

195211
auto shape_Q = make_shape(s.seq_len_qo, s.head_size_qk, s.num_heads_q, s.batch);
196212
auto shape_K = make_shape(s.seq_len_kv, s.head_size_qk, s.num_heads_kv, s.batch);
@@ -217,8 +233,8 @@ class XeFMHAFwdKernel {
217233
V(_,_,head,idx_b),
218234
tArA, tA_max, tA_sum,
219235
blk_qv, 0, k_blocks,
220-
thr_id);
221-
236+
thr_id, seq_len,
237+
full_tile_offset, discard_seq_coord);
222238
if constexpr (!is_empty_v<MainloopSharedStorage> && !is_empty_v<EpilogueSharedStorage>) {
223239
sycl::group_barrier(get_work_group<3>());
224240
}

examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -160,5 +160,8 @@ int main(int argc, const char **argv) {
160160
using ElementK = bfloat16_t;
161161
using ElementV = bfloat16_t;
162162
#endif
163-
return FMHAConfig<false, ShapeQK, ShapePV, ShapeOut, SubgroupLayoutQK, void, PipelineStages, ElementQ, ElementK, ElementV>::run(options);
163+
164+
return options.is_causal ? FMHAConfig<true, ShapeQK, ShapePV, ShapeOut, SubgroupLayoutQK, void, PipelineStages, ElementQ, ElementK, ElementV>::run(options)
165+
: FMHAConfig<false, ShapeQK, ShapePV, ShapeOut, SubgroupLayoutQK, void, PipelineStages, ElementQ, ElementK, ElementV>::run(options);
166+
164167
}

include/cute/atom/mma_atom.hpp

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -236,6 +236,11 @@ struct TiledMMA : MMA_Atom
236236
return thr_layout_vmnk_;
237237
}
238238

239+
CUTE_HOST_DEVICE constexpr auto
240+
get_atom_layout_mnk() const {
241+
return AtomLayoutMNK{};
242+
}
243+
239244
// Tile a tensor or a layout from shape
240245
// (M,N,...)
241246
// to shape

0 commit comments

Comments
 (0)