@@ -295,7 +295,7 @@ class XeFMHAFwdDynamicSplitKernel {
295295
296296 // Important: make sure the element count for each copy is a multiple of 16
297297 // this is for storing partial results from different KV partitions
298- static constexpr int num_elem_per_thead = (size(FragA{}.shape()) + 2 * size (FragARow{}.shape()) + 15) / 16 * 16;
298+ static constexpr int num_elem_per_thread = (size(FragA{}.shape()) + 2 * size (FragARow{}.shape()) + 15) / 16 * 16;
299299 // FIXME: may exceed 4 partitions???
300300 static const int max_num_partitions = 8 ;
301301
@@ -367,7 +367,7 @@ class XeFMHAFwdDynamicSplitKernel {
367367 const int wg_size = SGPerWG::value * intel::sg_size;
368368
369369 // partial attn outputs, exp sum and max logits
370- ws_size += (max_num_partitions * num_batch_heads) * wg_size * num_elem_per_thead * sizeof (ElementA);
370+ ws_size += (max_num_partitions * num_batch_heads) * wg_size * num_elem_per_thread * sizeof (ElementA);
371371 // atomic counter
372372 ws_size += num_batch_heads * sizeof (int32_t );
373373 return ws_size;
@@ -549,21 +549,21 @@ class XeFMHAFwdDynamicSplitKernel {
549549 int partition_id = get_partition_id (wg_id, batch_head_id, num_blocks_per_wg, local_k_blocks);
550550
551551 // store partial result: tArA, tA_max and tA_sum
552- int offset = batch_head_id * max_num_partitions * num_elem_per_thead * SGPerWG::value * intel::sg_size
553- + partition_id * num_elem_per_thead * SGPerWG::value * intel::sg_size
554- + sg_id * intel::sg_size * num_elem_per_thead
555- + tid_in_sg * num_elem_per_thead ;
556- Tensor tPartial = make_tensor (params.partial_results_ptr + offset, make_shape (Int<num_elem_per_thead >{}));
557- Tensor merged_res = make_tensor<ElementA>(Int<num_elem_per_thead >{});
552+ int offset = batch_head_id * max_num_partitions * num_elem_per_thread * SGPerWG::value * intel::sg_size
553+ + partition_id * num_elem_per_thread * SGPerWG::value * intel::sg_size
554+ + sg_id * intel::sg_size * num_elem_per_thread
555+ + tid_in_sg * num_elem_per_thread ;
556+ Tensor tPartial = make_tensor (params.partial_results_ptr + offset, make_shape (Int<num_elem_per_thread >{}));
557+ Tensor merged_res = make_tensor<ElementA>(Int<num_elem_per_thread >{});
558558
559559 CUTLASS_PRAGMA_UNROLL
560560 for (int i = 0 ; i < size (FragA{}.shape ()); ++i) {
561561 merged_res (i) = tArA (i);
562562 }
563563 CUTLASS_PRAGMA_UNROLL
564564 for (int i = 0 ; i < size (FragARow{}.shape ()); ++i) {
565- merged_res (i + size (FragA{}.shape ())) = tA_max (i);
566- merged_res (i + 1 + size (FragA{}.shape ())) = tA_sum (i);
565+ merged_res (2 * i + size (FragA{}.shape ())) = tA_max (i);
566+ merged_res (2 * i + 1 + size (FragA{}.shape ())) = tA_sum (i);
567567 }
568568 copy (merged_res, tPartial);
569569
@@ -592,12 +592,12 @@ class XeFMHAFwdDynamicSplitKernel {
592592 clear (tA_sum);
593593
594594 for (int i = 0 ; i < num_partitions; ++i) {
595- int offset = wg_id * max_num_partitions * SGPerWG::value * intel::sg_size * num_elem_per_thead
596- + i * SGPerWG::value * intel::sg_size * num_elem_per_thead
597- + sg_id * intel::sg_size * num_elem_per_thead
598- + tid_in_sg * num_elem_per_thead ;
599- Tensor tPartial = make_tensor (params.partial_results_ptr + offset, make_shape (Int<num_elem_per_thead >{}));
600- Tensor merged_res = make_tensor<ElementA>(Int<num_elem_per_thead >{});
595+ int offset = wg_id * max_num_partitions * SGPerWG::value * intel::sg_size * num_elem_per_thread
596+ + i * SGPerWG::value * intel::sg_size * num_elem_per_thread
597+ + sg_id * intel::sg_size * num_elem_per_thread
598+ + tid_in_sg * num_elem_per_thread ;
599+ Tensor tPartial = make_tensor (params.partial_results_ptr + offset, make_shape (Int<num_elem_per_thread >{}));
600+ Tensor merged_res = make_tensor<ElementA>(Int<num_elem_per_thread >{});
601601 copy (tPartial, merged_res);
602602
603603 if (i == 0 ) {
@@ -608,8 +608,8 @@ class XeFMHAFwdDynamicSplitKernel {
608608
609609 CUTLASS_PRAGMA_UNROLL
610610 for (int i = 0 ; i < size (FragARow{}.shape ()); ++i) {
611- tA_max (i) = merged_res (i + size (FragA{}.shape ()));
612- tA_sum (i) = merged_res (i + 1 + size (FragA{}.shape ()));
611+ tA_max (i) = merged_res (2 * i + size (FragA{}.shape ()));
612+ tA_sum (i) = merged_res (2 * i + 1 + size (FragA{}.shape ()));
613613 }
614614
615615 continue ;
0 commit comments