
Commit 5d64dfb

update tile scheduler & add runtime check
1 parent 252b0d1 commit 5d64dfb

3 files changed: 30 additions, 61 deletions

applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp

Lines changed: 11 additions & 34 deletions
@@ -349,6 +349,14 @@ class XeFMHAFwdDynamicSplitKernel {
   }

   static bool can_implement(Arguments const &args) {
+    // current kernel only supports decode
+    if (args.kernel.shape.seq_len_qo > 1) {
+      return false;
+    }
+    // current kernel only supports num batch heads no larger than the total XeCore count
+    if (args.kernel.shape.batch * args.kernel.shape.num_heads_q > args.hw_info.sm_count) {
+      return false;
+    }
     return CollectiveMainloop::can_implement(args.mainloop)
         && CollectiveEpilogue::can_implement(args.epilogue);
   }
@@ -436,8 +444,6 @@ class XeFMHAFwdDynamicSplitKernel {
       out1(i) = out1(i) * broadcast<0>(rescale1, out1, i) + out2(i) * broadcast<0>(rescale2, out2, i);
     }

-#define DEBUG_PRINT 0
-
   CUTLASS_DEVICE
   void operator()(Params const &params, char *smem_buf)
   {
@@ -456,25 +462,19 @@ class XeFMHAFwdDynamicSplitKernel {
     int tid_in_sg = thr_id % intel::sg_size;
     int num_batch_heads = s.batch * s.num_heads_q;

-    TileScheduler tile_scheduler{params.scheduler};
-
     int local_k_blocks = cute::ceil_div(s.seq_len_kv, get<1>(TileShapeQK{}));
     // total number of blocks need to be processed across all wgs
     int total_k_blocks = local_k_blocks * num_batch_heads;
     // to guarantee all wg process similar number of blocks of KV
     int num_blocks_per_wg = cute::ceil_div(total_k_blocks, GridDimZ());

-#if DEBUG_PRINT
-    if (thr_id == 0 && wg_id == 0) {
-      cute::print("Debug>> total_k_blocks: %d, num_blocks_per_wg: %d, local_k_blocks: %d, num_batch_heads: %d\n",
-                  total_k_blocks, num_blocks_per_wg, local_k_blocks, num_batch_heads);
-    }
-#endif
+    TileScheduler tile_scheduler{params.scheduler, get<1>(TileShapeQK{}), local_k_blocks, num_batch_heads};

     CUTLASS_PRAGMA_NO_UNROLL
     for (; tile_scheduler.is_valid(); ++tile_scheduler) {
       // head_q, idx_b from tile scheduler will not be used
-      auto [blk_q, blk_v, head_q_unused, idx_b_unused] = tile_scheduler.get_block_coord(); // (Q,V,h,b)
+      // auto [blk_q, blk_v, head_q_unused, idx_b_unused] = tile_scheduler.get_block_coord(); // (Q,V,h,b)
+      auto [blk_q, blk_v, start_batch_head_id] = tile_scheduler.get_block_coord(); // (Q,V, batch_head_idx)
       auto blk_qv = make_coord(blk_q, blk_v);

       auto shape_Q = make_shape(s.seq_len_qo, s.head_size_qk, s.num_heads_q, s.batch);
@@ -495,23 +495,13 @@ class XeFMHAFwdDynamicSplitKernel {
       FragA tArA;
       FragARow tA_max, tA_sum;

-      // compute start/end batch head id for current wg
-      int start_batch_head_id = wg_id * num_blocks_per_wg / local_k_blocks;
-
       // compute num computed blocks for start batch head id
       int num_computed_blocks = (start_batch_head_id == 0) ? (wg_id * num_blocks_per_wg) : (wg_id * num_blocks_per_wg - start_batch_head_id * local_k_blocks);
       int start_blk, end_blk, head_q, idx_b, head_kv;
       // leader wg is also responsible for reducing partial results, while other
       // worker wg only to compute partial results
       bool is_leader_wg = wg_id < num_batch_heads;

-#if DEBUG_PRINT
-      if (thr_id == 0) {
-        cute::print("Debug>> wg id %d, start_batch_head_id: %d, num_computed_blocks: %d\n",
-                    wg_id, start_batch_head_id, num_computed_blocks);
-      }
-#endif
-
       if (thr_id == 0 && is_leader_wg) {
         // reset atomic counter before computation
         *(params.atomic_reduce_cnt_ptr + wg_id) = 0;
@@ -558,13 +548,6 @@ class XeFMHAFwdDynamicSplitKernel {
       // partition id of start batch head id in current wg
       int partition_id = get_partition_id(wg_id, batch_head_id, num_blocks_per_wg, local_k_blocks);

-#if DEBUG_PRINT
-      if (thr_id == 0) {
-        cute::print("Debug>> wg id %d, batch_head_id: %d, partition_id: %d\n",
-                    wg_id, batch_head_id, partition_id);
-      }
-#endif
-
       // store partial result: tArA, tA_max and tA_sum
       int offset = batch_head_id * max_num_partitions * num_elem_per_thead * SGPerWG::value * intel::sg_size
                  + partition_id * num_elem_per_thead * SGPerWG::value * intel::sg_size
@@ -601,12 +584,6 @@ class XeFMHAFwdDynamicSplitKernel {
       if (is_leader_wg) {
         int num_partitions = get_num_partitions(wg_id, num_blocks_per_wg, local_k_blocks);

-#if DEBUG_PRINT
-        if (thr_id == 0) {
-          cute::print("Debug>> wg id %d, num_partitions: %d\n", wg_id, num_partitions);
-        }
-#endif
-
         // check atomic to wait for partial results ready
         while(atomicLoad(params.atomic_reduce_cnt_ptr + wg_id) != num_partitions) {}

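The two can_implement() checks added above replace the asserts that this commit removes from the scheduler's to_underlying_arguments() (see the next file). Below is a minimal standalone sketch of the same guard logic; the shape values and the dynamic_split_supported helper are made up for illustration and are not part of the commit.

// Standalone sketch of the new runtime checks, with assumed shape values;
// it mirrors the can_implement() logic above but is not code from the commit.
#include <cstdio>

struct Shape { int seq_len_qo, batch, num_heads_q; };

bool dynamic_split_supported(Shape s, int xecore_count) {
  if (s.seq_len_qo > 1) return false;                       // decode only
  if (s.batch * s.num_heads_q > xecore_count) return false; // one leader wg per batch-head
  return true;
}

int main() {
  Shape decode{/*seq_len_qo=*/1, /*batch=*/1, /*num_heads_q=*/32};
  std::printf("decode, 32 heads, 64 XeCores -> %s\n",
              dynamic_split_supported(decode, 64) ? "supported" : "not supported");
  Shape prefill{/*seq_len_qo=*/1024, /*batch=*/1, /*num_heads_q=*/32};
  std::printf("prefill -> %s\n",
              dynamic_split_supported(prefill, 64) ? "supported" : "not supported");
}

A caller would query can_implement(args) before launch and pick a different kernel when it returns false, instead of hitting a host-side assert.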
applications/flash_attention_v2/kernel/xe_tile_scheduler.hpp

Lines changed: 18 additions & 27 deletions
@@ -101,9 +101,15 @@ struct XeFHMAIndividualPersistentTileScheduler {
 
   bool valid_ = true;
   Params params;
+  int kv_tile_size_;
+  // num of kv blocks for each head
+  int local_num_kv_blocks_;
+  int num_batch_heads_;

   CUTLASS_DEVICE
-  XeFHMAIndividualPersistentTileScheduler(Params const& params) : params(params) {}
+  XeFHMAIndividualPersistentTileScheduler(Params const& params, int kv_tile_size,
+                                          int local_num_kv_blocks, int num_batch_heads)
+    : params(params), kv_tile_size_(kv_tile_size), local_num_kv_blocks_(local_num_kv_blocks), num_batch_heads_(num_batch_heads) {}

   template <class ProblemShape, class TileShape>
   static Params to_underlying_arguments(
@@ -116,31 +122,8 @@ struct XeFHMAIndividualPersistentTileScheduler {
                      size(ceil_div(shape.seq_len_qo, get<0>(tile_shape))), // Q
                      size(shape.batch * shape.num_heads_q)); // (h,b) -- split later
     int num_heads = shape.num_heads_q;
-
-    auto total_wg = grid.x * grid.y * grid.z;
-    // FIXME: replace with runtime check
-    assert(shape.batch == 1);
-    assert((grid.z <= hw_info.sm_count / 2) && "XeFHMAIndividualPersistentTileScheduler only enabled for decode case where num batch heads samller than SM count");
-
-    // how many partitions each KV seq is split into
-    int num_partitions = hw_info.sm_count / grid.z;
-    // this is for the case where sm_count cannot be divisible by num_batch_heads,
-    // for some head/work group, the KV seq need to split into `num_partitions+1`
-    // partitions to occupy all xecores, here we assme first `tail_wg` work groups
-    // will handle one more partition
-    // for eample, num head is 8, sm_count is 20, so first 20%8=4 work groups
-    // will handle 3 partitions, the rest 4 work groups will handle 2 partitions
-    int num_tail_wg = hw_info.sm_count % grid.z;
-
-    // assume grid shape (1, 1, hw_info.sm_count) to use all xecores
     grid.z = hw_info.sm_count;
-    // int num_partitions = 4; // for 5/1
-    // grid.z *= num_partitions;
-    // num_heads *= num_partitions;
-
-    // FIXME: add fallback mechanism if given problem size doesn't meet requirement

-    std::cout << "Debug>> grid shape [" << grid.x << ", " << grid.y << ", " << grid.z << "]\n";
     return Params{grid, {num_heads}};
   }

@@ -157,10 +140,18 @@ struct XeFHMAIndividualPersistentTileScheduler {
   CUTLASS_DEVICE
   auto get_block_coord() {
     using namespace cute;
-    int idx_b = BlockIdxZ();
+    int wg_id = BlockIdxZ();
     int head;
-    params.divmod_num_heads(idx_b, head, idx_b);
-    return make_coord(BlockIdxY(), BlockIdxX(), head, idx_b);
+
+    // total number of blocks need to be processed across all wgs
+    int total_num_kv_blocks = local_num_kv_blocks_ * num_batch_heads_;
+    // guarantee all wg process similar number of blocks of KV (load balance)
+    int num_blocks_per_wg = cute::ceil_div(total_num_kv_blocks, GridDimZ());
+
+    // compute start batch head id for current wg
+    int start_batch_head_id = wg_id * num_blocks_per_wg / local_num_kv_blocks_;
+
+    return make_coord(BlockIdxY(), BlockIdxX(), start_batch_head_id);
   }

   CUTLASS_DEVICE
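The updated get_block_coord() maps each persistent work-group to a starting batch-head by splitting local_num_kv_blocks_ * num_batch_heads_ KV blocks evenly across GridDimZ() work-groups; the kernel uses the same num_blocks_per_wg value to derive num_computed_blocks and partition ids. Below is a standalone sketch of that arithmetic with assumed sizes (8 batch-heads, 24 KV blocks per head, 20 XeCores); none of these numbers come from the commit.

// Sketch only: reproduces the wg -> (start head, block offset) mapping used by
// the updated scheduler and kernel; the shape numbers below are assumptions.
#include <algorithm>
#include <cstdio>

int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
  int num_batch_heads = 8;   // batch * num_heads_q
  int local_k_blocks  = 24;  // ceil_div(seq_len_kv, KV tile size)
  int grid_z          = 20;  // hw_info.sm_count, one persistent wg per XeCore

  int total_k_blocks    = local_k_blocks * num_batch_heads;  // 192
  int num_blocks_per_wg = ceil_div(total_k_blocks, grid_z);  // 10

  for (int wg_id = 0; wg_id < grid_z; ++wg_id) {
    // Same formula as XeFHMAIndividualPersistentTileScheduler::get_block_coord()
    int start_batch_head_id = wg_id * num_blocks_per_wg / local_k_blocks;
    // Blocks of that head already covered by lower-numbered work-groups
    // (matches num_computed_blocks in the kernel)
    int num_computed_blocks = wg_id * num_blocks_per_wg - start_batch_head_id * local_k_blocks;
    int blocks_for_this_wg  = std::max(0, std::min(num_blocks_per_wg,
                                                   total_k_blocks - wg_id * num_blocks_per_wg));
    std::printf("wg %2d: start head %d, offset %2d, processes %2d KV block(s)\n",
                wg_id, start_batch_head_id, num_computed_blocks, blocks_for_this_wg);
  }
  return 0;
}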

examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp

Lines changed: 1 addition & 0 deletions
@@ -113,6 +113,7 @@ int main(int argc, const char **argv) {
 #define KV_TILE_SIZE _256
 #else
 #define NUM_SG _16
+#define KV_TILE_SIZE _512
 #endif

 #if HEAD_DIM == 16
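The branch that sets NUM_SG to _16 now also defines KV_TILE_SIZE as _512. Assuming this macro feeds get<1>(TileShapeQK{}), it determines the per-head KV block count the kernel hands to the scheduler; a quick illustration with an assumed sequence length (only the 512 comes from the change above):

// Illustration only: seq_len_kv = 8192 is an assumed value, not from the commit.
#include <cstdio>
int main() {
  int seq_len_kv     = 8192;                                            // assumed
  int kv_tile_size   = 512;                                             // KV_TILE_SIZE for NUM_SG == _16
  int local_k_blocks = (seq_len_kv + kv_tile_size - 1) / kv_tile_size;  // ceil_div -> 16 blocks per head
  std::printf("KV blocks per head: %d\n", local_k_blocks);
}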
