Persistent SDPA kernel #608
base: main
Conversation
Maybe add the limitations of this algorithm in the code as well, especially the one involving atomics.

```cpp
if (args.kernel.shape.seq_len_qo > 1) {
  return false;
}
// current kernel only supports num batch heads less than total XeCore count
if (args.kernel.shape.batch * args.kernel.shape.num_heads_q > args.hw_info.sm_count) {
  return false;
}
```
@pengzhao-intel Added checks here in can_implement().
Pull Request Overview
This PR introduces a persistent SDPA (Scaled Dot Product Attention) kernel for decode scenarios that implements dynamic load balancing across XeCores. The key innovation is fixing the number of work groups to match total XeCores and dynamically splitting KV sequence length across all work groups for balanced workload distribution.
Key changes:
- New persistent tile scheduler (`XeFHMAIndividualPersistentTileScheduler`) that distributes work evenly across a fixed XeCore count
- New kernel implementation (`XeFMHAFwdDynamicSplitKernel`) with split-K reduction for partial results
- Support infrastructure including atomic operations (`atomicSub`, `atomicLoad`) for synchronization
Reviewed Changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| `include/cutlass/gpu_generics.h` | Adds atomic operations (`atomicSub`, `atomicLoad`) for synchronization primitives |
| `examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp` | Integrates persistent kernel selection and queries hardware XeCore count |
| `examples/06_bmg_flash_attention/CMakeLists.txt` | Adds build target for persistent kernel testing |
| `examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp` | Configures persistent kernel with appropriate tile sizes and subgroup layouts |
| `applications/flash_attention_v2/kernel/xe_tile_scheduler.hpp` | Implements persistent tile scheduler with dynamic work distribution |
| `applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp` | Implements dynamic split-K kernel with partial result reduction |
| `applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp` | Updates mainloop to use total block count for remainder masking |
```cpp
CUTLASS_DEVICE int atomicLoad(int *address) {
  int result = 0;
#if defined(__SYCL_DEVICE_ONLY__)
  auto atm = sycl::atomic_ref<int, sycl::memory_order::relaxed, sycl::memory_scope::device, sycl::access::address_space::generic_space>(address[0]);
```
Copilot AI (Nov 5, 2025):
The `atomic_ref` is constructed with `address[0]`, which dereferences the pointer. This should be `*address` for clarity and consistency with standard atomic operation patterns.
Suggested change:

```diff
- auto atm = sycl::atomic_ref<int, sycl::memory_order::relaxed, sycl::memory_scope::device, sycl::access::address_space::generic_space>(address[0]);
+ auto atm = sycl::atomic_ref<int, sycl::memory_order::relaxed, sycl::memory_scope::device, sycl::access::address_space::generic_space>(*address);
```
```cpp
merged_res(i + size(FragA{}.shape())) = tA_max(i);
merged_res(i + 1 + size(FragA{}.shape())) = tA_sum(i);
```
Suggested change:

```diff
- merged_res(i + size(FragA{}.shape())) = tA_max(i);
- merged_res(i + 1 + size(FragA{}.shape())) = tA_sum(i);
+ merged_res(2*i + size(FragA{}.shape())) = tA_max(i);
+ merged_res(2*i + 1 + size(FragA{}.shape())) = tA_sum(i);
```
Thanks for the catch. Updated.
```cpp
#define NUM_SG _16
#define KV_TILE_SIZE _256
#else
#define NUM_SG _16
```
Did you mean to change the tile configuration for non-persistent mode?
Suggested change:

```diff
- #define NUM_SG _16
+ #define NUM_SG _8
```
Changed back to _8 now.
Force-pushed from 038924d to 1c12b60, then from 1c12b60 to 532a50b.
The new kernel implements the method described below. As of now there are two limitations:
- decode only (`seq_len_qo == 1`)
- `batch_size * num_heads_q` must not exceed the number of total XeCores