Skip to content

Commit 1c12b60

Browse files
committed
fix index & spelling
1 parent 5d64dfb commit 1c12b60

File tree

2 files changed

+19
-19
lines changed

2 files changed

+19
-19
lines changed

applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ class XeFMHAFwdDynamicSplitKernel {
295295

296296
// Important: make sure each copy is a multiple of 16 elements
297297
// this is for storing partial results from different KV partitions
298-
static constexpr int num_elem_per_thead = (size(FragA{}.shape()) + 2 * size(FragARow{}.shape()) + 15) / 16 * 16;
298+
static constexpr int num_elem_per_thread = (size(FragA{}.shape()) + 2 * size(FragARow{}.shape()) + 15) / 16 * 16;
299299
// FIXME: may need more than 4 partitions???
300300
static const int max_num_partitions = 8;
301301

@@ -367,7 +367,7 @@ class XeFMHAFwdDynamicSplitKernel {
367367
const int wg_size = SGPerWG::value * intel::sg_size;
368368

369369
// partial attn outputs, exp sum and max logits
370-
ws_size += (max_num_partitions * num_batch_heads) * wg_size * num_elem_per_thead * sizeof(ElementA);
370+
ws_size += (max_num_partitions * num_batch_heads) * wg_size * num_elem_per_thread * sizeof(ElementA);
371371
// atomic counter
372372
ws_size += num_batch_heads * sizeof(int32_t);
373373
return ws_size;
@@ -549,21 +549,21 @@ class XeFMHAFwdDynamicSplitKernel {
549549
int partition_id = get_partition_id(wg_id, batch_head_id, num_blocks_per_wg, local_k_blocks);
550550

551551
// store partial result: tArA, tA_max and tA_sum
552-
int offset = batch_head_id * max_num_partitions * num_elem_per_thead * SGPerWG::value * intel::sg_size
553-
+ partition_id * num_elem_per_thead * SGPerWG::value * intel::sg_size
554-
+ sg_id * intel::sg_size * num_elem_per_thead
555-
+ tid_in_sg * num_elem_per_thead;
556-
Tensor tPartial = make_tensor(params.partial_results_ptr + offset, make_shape(Int<num_elem_per_thead>{}));
557-
Tensor merged_res = make_tensor<ElementA>(Int<num_elem_per_thead>{});
552+
int offset = batch_head_id * max_num_partitions * num_elem_per_thread * SGPerWG::value * intel::sg_size
553+
+ partition_id * num_elem_per_thread * SGPerWG::value * intel::sg_size
554+
+ sg_id * intel::sg_size * num_elem_per_thread
555+
+ tid_in_sg * num_elem_per_thread;
556+
Tensor tPartial = make_tensor(params.partial_results_ptr + offset, make_shape(Int<num_elem_per_thread>{}));
557+
Tensor merged_res = make_tensor<ElementA>(Int<num_elem_per_thread>{});
558558

559559
CUTLASS_PRAGMA_UNROLL
560560
for(int i = 0; i < size(FragA{}.shape()); ++i) {
561561
merged_res(i) = tArA(i);
562562
}
563563
CUTLASS_PRAGMA_UNROLL
564564
for (int i = 0; i < size(FragARow{}.shape()); ++i) {
565-
merged_res(i + size(FragA{}.shape())) = tA_max(i);
566-
merged_res(i + 1 + size(FragA{}.shape())) = tA_sum(i);
565+
merged_res(2 * i + size(FragA{}.shape())) = tA_max(i);
566+
merged_res(2 * i + 1 + size(FragA{}.shape())) = tA_sum(i);
567567
}
568568
copy(merged_res, tPartial);
569569

@@ -592,12 +592,12 @@ class XeFMHAFwdDynamicSplitKernel {
592592
clear(tA_sum);
593593

594594
for (int i = 0; i < num_partitions; ++i) {
595-
int offset = wg_id * max_num_partitions * SGPerWG::value * intel::sg_size * num_elem_per_thead
596-
+ i * SGPerWG::value * intel::sg_size * num_elem_per_thead
597-
+ sg_id * intel::sg_size * num_elem_per_thead
598-
+ tid_in_sg * num_elem_per_thead;
599-
Tensor tPartial = make_tensor(params.partial_results_ptr + offset, make_shape(Int<num_elem_per_thead>{}));
600-
Tensor merged_res = make_tensor<ElementA>(Int<num_elem_per_thead>{});
595+
int offset = wg_id * max_num_partitions * SGPerWG::value * intel::sg_size * num_elem_per_thread
596+
+ i * SGPerWG::value * intel::sg_size * num_elem_per_thread
597+
+ sg_id * intel::sg_size * num_elem_per_thread
598+
+ tid_in_sg * num_elem_per_thread;
599+
Tensor tPartial = make_tensor(params.partial_results_ptr + offset, make_shape(Int<num_elem_per_thread>{}));
600+
Tensor merged_res = make_tensor<ElementA>(Int<num_elem_per_thread>{});
601601
copy(tPartial, merged_res);
602602

603603
if (i == 0) {
@@ -608,8 +608,8 @@ class XeFMHAFwdDynamicSplitKernel {
608608

609609
CUTLASS_PRAGMA_UNROLL
610610
for (int i = 0; i < size(FragARow{}.shape()); ++i) {
611-
tA_max(i) = merged_res(i + size(FragA{}.shape()));
612-
tA_sum(i) = merged_res(i + 1 + size(FragA{}.shape()));
611+
tA_max(i) = merged_res(2 * i + size(FragA{}.shape()));
612+
tA_sum(i) = merged_res(2 * i + 1 + size(FragA{}.shape()));
613613
}
614614

615615
continue;

examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ int main(int argc, const char **argv) {
112112
#define NUM_SG _16
113113
#define KV_TILE_SIZE _256
114114
#else
115-
#define NUM_SG _16
115+
#define NUM_SG _8
116116
#define KV_TILE_SIZE _512
117117
#endif
118118

0 commit comments

Comments
 (0)