
Conversation

@wuxun-zhang

The new kernel implements the method below. Key points (a rough sketch follows the list):

  • the number of work groups is fixed to the total number of XeCores
  • the KV sequence length from all sequences is split dynamically across all work groups
  • each XeCore gets a balanced number of work units
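
A minimal host-side sketch of the balancing idea, assuming the split is done at KV-tile granularity; the names and the struct are illustrative only, not the PR's scheduler code:

// Illustrative only: divide the total number of KV tiles (over all batches and
// query heads) as evenly as possible across a fixed number of work groups,
// one per XeCore. Work groups with index < remainder take one extra tile.
struct SplitPlan {
  int tiles_per_wg;  // minimum number of KV tiles handled by each work group
  int remainder;     // this many work groups take one additional tile
};

inline SplitPlan balance_kv_tiles(int batch, int num_heads_q, int seq_len_kv,
                                  int kv_tile_size, int xecore_count) {
  int tiles_per_seq = (seq_len_kv + kv_tile_size - 1) / kv_tile_size;  // ceil-div
  int total_tiles   = batch * num_heads_q * tiles_per_seq;
  return {total_tiles / xecore_count, total_tiles % xecore_count};
}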

As of now there are two limitations:

  • only decode is supported (seq_len_qo == 1)
  • batch_size * num_heads_q must not exceed the total number of XeCores

@pengzhao-intel

Maybe add the limitations of this algorithm in the code as well, especially the one involving atomics.

Comment on lines +353 to +359
// current kernel only supports decode (seq_len_qo == 1)
if (args.kernel.shape.seq_len_qo > 1) {
  return false;
}
// current kernel only supports batch * num_heads_q no greater than the total XeCore count
if (args.kernel.shape.batch * args.kernel.shape.num_heads_q > args.hw_info.sm_count) {
  return false;
}
@wuxun-zhang (Author)


@pengzhao-intel Added checks here in can_implement().

@Antonyvance requested a review from Copilot on November 5, 2025 07:29

Copilot AI left a comment


Pull Request Overview

This PR introduces a persistent SDPA (Scaled Dot Product Attention) kernel for decode scenarios that implements dynamic load balancing across XeCores. The key innovation is fixing the number of work groups to match total XeCores and dynamically splitting KV sequence length across all work groups for balanced workload distribution.

Key changes:

  • New persistent tile scheduler (XeFHMAIndividualPersistentTileScheduler) that distributes work evenly across fixed XeCore count
  • New kernel implementation (XeFMHAFwdDynamicSplitKernel) with split-K reduction for partial results
  • Support infrastructure including atomic operations (atomicSub, atomicLoad) for synchronization
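
One common way such atomics can back the split-K step is a completion counter per (batch, head) pair: each work group decrements it after writing its partial result, and the group that brings it to zero runs the final reduction. A minimal sketch under that assumption; the function name and counter convention are illustrative, not the PR's code:

#include <sycl/sycl.hpp>

// Illustrative only: returns true for exactly one caller, the last work group to
// finish its split for this head. `counter` points to device memory preset to the
// number of splits assigned to the head.
inline bool signal_partial_done(int *counter) {
  sycl::atomic_ref<int, sycl::memory_order::acq_rel,
                   sycl::memory_scope::device,
                   sycl::access::address_space::global_space>
      remaining(*counter);
  // fetch_sub returns the value held before the decrement, so observing 1 means
  // this work group completed the last outstanding partial result.
  return remaining.fetch_sub(1) == 1;
}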

Reviewed Changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 7 comments.

Summary of changes per file:

  • include/cutlass/gpu_generics.h: Adds atomic operations (atomicSub, atomicLoad) for synchronization primitives
  • examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp: Integrates persistent kernel selection and queries the hardware XeCore count
  • examples/06_bmg_flash_attention/CMakeLists.txt: Adds a build target for persistent kernel testing
  • examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp: Configures the persistent kernel with appropriate tile sizes and subgroup layouts
  • applications/flash_attention_v2/kernel/xe_tile_scheduler.hpp: Implements the persistent tile scheduler with dynamic work distribution
  • applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp: Implements the dynamic split-K kernel with partial result reduction
  • applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp: Updates the mainloop to use the total block count for remainder masking

CUTLASS_DEVICE int atomicLoad(int *address) {
int result = 0;
#if defined(__SYCL_DEVICE_ONLY__)
auto atm = sycl::atomic_ref<int, sycl::memory_order::relaxed, sycl::memory_scope::device, sycl::access::address_space::generic_space>(address[0]);

Copilot AI Nov 5, 2025


The atomic_ref is constructed with address[0] which dereferences the pointer. This should be *address for clarity and consistency with standard atomic operations patterns.

Suggested change
auto atm = sycl::atomic_ref<int, sycl::memory_order::relaxed, sycl::memory_scope::device, sycl::access::address_space::generic_space>(address[0]);
auto atm = sycl::atomic_ref<int, sycl::memory_order::relaxed, sycl::memory_scope::device, sycl::access::address_space::generic_space>(*address);
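
For reference, with this suggestion applied the complete helper would look roughly as follows; the load call and the behavior when __SYCL_DEVICE_ONLY__ is not defined are assumptions, since only the first few lines are quoted from the diff:

CUTLASS_DEVICE int atomicLoad(int *address) {
  int result = 0;
#if defined(__SYCL_DEVICE_ONLY__)
  auto atm = sycl::atomic_ref<int, sycl::memory_order::relaxed,
                              sycl::memory_scope::device,
                              sycl::access::address_space::generic_space>(*address);
  result = atm.load();  // relaxed atomic read of the shared counter
#endif
  return result;
}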

Comment on lines 565 to 566
merged_res(i + size(FragA{}.shape())) = tA_max(i);
merged_res(i + 1 + size(FragA{}.shape())) = tA_sum(i);


Suggested change
merged_res(i + size(FragA{}.shape())) = tA_max(i);
merged_res(i + 1 + size(FragA{}.shape())) = tA_sum(i);
merged_res(2*i + size(FragA{}.shape())) = tA_max(i);
merged_res(2*i + 1 + size(FragA{}.shape())) = tA_sum(i);

@wuxun-zhang (Author)


Thanks for the catch. Updated.
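
As background for why each split stores both a running max and a running sum next to the output fragment (and why the two need distinct slots, hence the 2*i stride in the fix above): the partial results are later combined with a log-sum-exp style rescale. A minimal host-side sketch of that merge, assuming one accumulator row per query; this illustrates the math only and is not the PR's device code:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Illustrative only: merge two partial attention results for the same query row.
// Each partial carries an unnormalized output accumulator plus the max and sum of
// its local softmax; the final output is acc / row_sum after all merges.
struct Partial {
  std::vector<float> acc;  // unnormalized output accumulator
  float row_max;           // max of the attention logits seen by this split
  float row_sum;           // sum of exp(logit - row_max) for this split
};

inline void merge_partial(Partial &into, const Partial &from) {
  float new_max = std::max(into.row_max, from.row_max);
  float scale_a = std::exp(into.row_max - new_max);
  float scale_b = std::exp(from.row_max - new_max);
  for (std::size_t i = 0; i < into.acc.size(); ++i)
    into.acc[i] = into.acc[i] * scale_a + from.acc[i] * scale_b;
  into.row_sum = into.row_sum * scale_a + from.row_sum * scale_b;
  into.row_max = new_max;
}

The kernel's split-K reduction presumably performs an equivalent rescaling directly on its partial tiles.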

#define NUM_SG _16
#define KV_TILE_SIZE _256
#else
#define NUM_SG _16


Did you mean to change the tile configuration for non-persistent mode?

Suggested change
#define NUM_SG _16
#define NUM_SG _8

@wuxun-zhang (Author)


Changed back to _8 now.
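
A sketch of the resolved configuration after this change; the guard name is hypothetical since the quoted diff does not show the enclosing condition, and the non-persistent KV_TILE_SIZE is omitted because it is not visible here:

#if PERSISTENT_KERNEL   // hypothetical guard; the real condition is outside the quoted lines
#define NUM_SG _16      // persistent kernel runs with 16 subgroups
#define KV_TILE_SIZE _256
#else
#define NUM_SG _8       // non-persistent path restored to the original 8 subgroups
#endif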

@wuxun-zhang force-pushed the wuxun/persistent-sdpa branch 2 times, most recently from 038924d to 1c12b60 on November 6, 2025 02:21
@wuxun-zhang force-pushed the wuxun/persistent-sdpa branch from 1c12b60 to 532a50b on November 6, 2025 02:27