Skip to content

Commit d84f370

Browse files
committed
[ExecuTorch][WebGPU] Enable FlashDecoding by default for decode SDPA (runtime shape gate)
Pull Request resolved: #20544 **Makes split-KV FlashDecoding the default decode-attention path** (it was shipped dormant behind a default-OFF compile flag). FD is the fastest WebGPU SDPA decode arm (**+178% vs naive**, M4 Pro, isolated op); this turns it on for production and selects it at runtime by a shape-capability predicate. {F1991715077} **Problem:** the FD kernel is correct and measured (+178%) but compile-gated OFF, so no production build used it. A device-limit gate (web-llm-style `maxStorageBufferBindingSize`) was considered but is dead code here: FD's resource needs (workgroup size 64, 512 B shared memory, 5 storage bindings) are all below WebGPU's baseline minimum limits, and FD binds the same K/V caches as the materialized fallback — so no spec-compliant device can run materialized decode but fail FD. The only selection criterion with real effect is shape. **Solution:** enable FD by default and select it at runtime on shape, not device. - **Before:** `EXECUTORCH_BUILD_WEBGPU_SDPA_FD` default OFF; FD code unlinked; every decode used the materialized QK/softmax/AV path. - **After:** flag default ON (kept as a build-time kill-switch); decode (`S == 1`, static input_pos) with head dim `<= kSdpaFdMaxHeadDim` uses FD; other shapes (including head dim > 128) fall through to the materialized path. **Implementation:** - `Sdpa.cpp`: extend the FD selection predicate with `D <= kSdpaFdMaxHeadDim` so unsupported head dims fall through instead of throwing. - `SdpaFdDecode.h`: expose `kSdpaFdMaxHeadDim` (FD's lane-owns-D reach) as the single source of truth; `SdpaFdDecode.cpp` ties it to `WG_SIZE * MAX_D_PER_LANE` with a `static_assert`. - `CMakeLists.txt` (fbcode + xplat): flip the option default to ON; OFF remains a kill-switch that drops all FlashDecoding code. - `test_webgpu_native_ci.sh`: drop the now-redundant explicit `=ON` flag so CI builds and tests the default. - Mirrors Vulkan `backends/vulkan/runtime/graph/ops/impl/SDPA.cpp` shape-based kernel selection (`is_single_token`); no device-adaptive gate, matching the Vulkan delegate. **Constraints:** decode-only (`S == 1`), static input_pos (dynamic-pos decode still uses the materialized path); fp32, buffer-only; the FD kernels are unchanged by this diff. Co-authored with Claude Code. ghstack-source-id: 397454762 @exported-using-ghexport Differential Revision: [D109520722](https://our.internmc.facebook.com/intern/diff/D109520722/)
1 parent 1227757 commit d84f370

8 files changed

Lines changed: 768 additions & 3 deletions

File tree

backends/webgpu/CMakeLists.txt

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,20 @@ if(EXECUTORCH_BUILD_WEBGPU_PROFILING)
9999
)
100100
endif()
101101

102+
# Split-KV FlashDecoding decode path (sdpa_fd_decode). Default ON: selected at
103+
# runtime for decode (S==1) shapes it supports (head dim <= kSdpaFdMaxHeadDim);
104+
# other shapes use the materialized SDPA path. Set OFF as a kill-switch to drop
105+
# all FlashDecoding code from the build.
106+
option(EXECUTORCH_BUILD_WEBGPU_SDPA_FD
107+
"Enable split-KV FlashDecoding SDPA decode path" ON
108+
)
109+
if(EXECUTORCH_BUILD_WEBGPU_SDPA_FD)
110+
target_sources(
111+
webgpu_backend PRIVATE runtime/ops/sdpa_fd_decode/SdpaFdDecode.cpp
112+
)
113+
target_compile_definitions(webgpu_backend PRIVATE WEBGPU_SDPA_FD)
114+
endif()
115+
102116
# Link with --whole-archive for static registration of backend + ops
103117
executorch_target_link_options_shared_lib(webgpu_backend)
104118

backends/webgpu/runtime/ops/sdpa/Sdpa.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
#include <executorch/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h>
1313
#include <executorch/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h>
1414
#include <executorch/backends/webgpu/runtime/ops/sdpa/sdpa_softmax_wgsl.h>
15+
#if defined(WEBGPU_SDPA_FD)
16+
#include <executorch/backends/webgpu/runtime/ops/sdpa_fd_decode/SdpaFdDecode.h>
17+
#endif
1518
#include <executorch/backends/webgpu/runtime/ops/update_cache/update_cache_wgsl.h>
1619

1720
#include <webgpu/webgpu.h>
@@ -427,9 +430,6 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
427430
static_cast<uint64_t>(S) *
428431
static_cast<uint64_t>(dynamic_pos ? Cmax : context_len);
429432
const uint64_t aw_bytes = aw_cap_floats * sizeof(float);
430-
// Prefill scratch scales as Hq·S·Cmax; can be large for long-context prefill.
431-
WGPUBuffer attn_weights = graph.create_scratch_buffer(aw_bytes);
432-
WGPUBuffer attn_weights_softmax = graph.create_scratch_buffer(aw_bytes);
433433

434434
// Dynamic input_pos: the resize hook rewrites these per step.
435435
WGPUBuffer uc_k_buf = nullptr, uc_v_buf = nullptr, qk_buf = nullptr,
@@ -473,6 +473,20 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
473473
dynamic_pos,
474474
"update_cache(V)");
475475

476+
#ifdef WEBGPU_SDPA_FD
477+
// FlashDecoding decode (S==1, static pos). Shapes FD can't handle (head dim
478+
// > kSdpaFdMaxHeadDim) fall through to the materialized path below.
479+
if (S == 1 && !dynamic_pos && D <= kSdpaFdMaxHeadDim) {
480+
sdpa_fd_decode_dispatch(
481+
graph, q, k_cache, v_cache, out, Hq, Hkv, D, context_len, g, scale);
482+
return;
483+
}
484+
#endif
485+
486+
// QK/softmax scratch — allocated only on the non-FD path (Hq*S*Cmax prefill).
487+
WGPUBuffer attn_weights = graph.create_scratch_buffer(aw_bytes);
488+
WGPUBuffer attn_weights_softmax = graph.create_scratch_buffer(aw_bytes);
489+
476490
// --- Dispatch 3: QK -> attn_weights. One thread per TM x TN tile.
477491
{
478492
if (aw_floats > UINT32_MAX) {
Lines changed: 273 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
// Split-KV FlashDecoding decode dispatch (split + reduce passes).
10+
11+
#include <executorch/backends/webgpu/runtime/WebGPUGraph.h>
12+
#include <executorch/backends/webgpu/runtime/WebGPUUtils.h>
13+
#include <executorch/backends/webgpu/runtime/ops/sdpa_fd_decode/SdpaFdDecode.h>
14+
#include <executorch/backends/webgpu/runtime/ops/sdpa_fd_decode/sdpa_fd_reduce_wgsl.h>
15+
#include <executorch/backends/webgpu/runtime/ops/sdpa_fd_decode/sdpa_fd_split_wgsl.h>
16+
17+
#include <webgpu/webgpu.h>
18+
19+
#include <cstdint>
20+
#include <cstring>
21+
#include <stdexcept>
22+
#include <string>
23+
24+
namespace executorch::backends::webgpu {
25+
26+
namespace {
27+
28+
// MUST match the .wgsl: MAX_SPLITS and WG_SIZE*MAX_D_PER_LANE.
29+
constexpr uint32_t kSdpaFdSplitTile = 64; // KV positions per split
30+
constexpr uint32_t kSdpaFdMaxSplits = 128; // == MAX_SPLITS in both .wgsl files
31+
// Public head-dim limit (kSdpaFdMaxHeadDim) must equal the kernel's lane-owns-D
32+
// reach; tie them so a WG_SIZE change can't silently desync the Sdpa.cpp gate.
33+
static_assert(
34+
kSdpaFdMaxHeadDim == kSdpaFdSplitWorkgroupSizeX * 2u,
35+
"kSdpaFdMaxHeadDim must match WG_SIZE * MAX_D_PER_LANE");
36+
37+
struct FdSplitParams {
38+
uint32_t _pad0; // 16B-alignment pad (head index derived from workgroup_id)
39+
uint32_t Hkv;
40+
uint32_t D;
41+
uint32_t context_len;
42+
uint32_t g;
43+
uint32_t num_splits;
44+
uint32_t split_len;
45+
float scale;
46+
};
47+
static_assert(sizeof(FdSplitParams) == 32, "FdSplitParams must be 32B");
48+
49+
struct FdReduceParams {
50+
uint32_t D;
51+
uint32_t num_splits;
52+
uint32_t _pad0;
53+
uint32_t _pad1;
54+
};
55+
static_assert(sizeof(FdReduceParams) == 16, "FdReduceParams must be 16B");
56+
57+
struct BufferBinding {
58+
WGPUBuffer buffer;
59+
uint64_t size;
60+
};
61+
62+
WGPUBuffer
63+
make_uniform_buffer(WebGPUGraph& graph, const void* data, size_t size) {
64+
WGPUDevice device = graph.device();
65+
WGPUBufferDescriptor desc = {};
66+
desc.size = size;
67+
desc.usage = WGPUBufferUsage_Uniform | WGPUBufferUsage_CopyDst;
68+
desc.mappedAtCreation = true;
69+
WGPUBuffer buffer = wgpuDeviceCreateBuffer(device, &desc);
70+
void* mapped = wgpuBufferGetMappedRange(buffer, 0, size);
71+
std::memcpy(mapped, data, size);
72+
wgpuBufferUnmap(buffer);
73+
graph.add_uniform_buffer_bytes(size);
74+
return buffer;
75+
}
76+
77+
// Mirrors Sdpa.cpp build_dispatch; n_rw leading bindings are read_write.
78+
void build_dispatch(
79+
WebGPUGraph& graph,
80+
const char* wgsl_source,
81+
const BufferBinding* storage_bindings,
82+
uint32_t n_storage,
83+
uint32_t n_rw,
84+
WGPUBuffer uniform_buffer,
85+
uint64_t uniform_size,
86+
uint32_t workgroup_count_x,
87+
const char* kernel_name) {
88+
WGPUDevice device = graph.device();
89+
90+
WGPUShaderSourceWGSL wgsl_desc = {};
91+
wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL;
92+
wgsl_desc.code = {wgsl_source, WGPU_STRLEN};
93+
WGPUShaderModuleDescriptor shader_desc = {};
94+
shader_desc.nextInChain = &wgsl_desc.chain;
95+
WGPUShaderModule shader = wgpuDeviceCreateShaderModule(device, &shader_desc);
96+
97+
constexpr uint32_t kMaxEntries = 8;
98+
if (n_storage + 1u > kMaxEntries) {
99+
throw std::runtime_error(
100+
"WebGPU sdpa FlashDecoding: bind group entry count exceeds kMaxEntries");
101+
}
102+
WGPUBindGroupLayoutEntry bgl_entries[kMaxEntries] = {};
103+
const uint32_t uniform_binding = n_storage;
104+
for (uint32_t i = 0; i < n_storage; i++) {
105+
bgl_entries[i].binding = i;
106+
bgl_entries[i].visibility = WGPUShaderStage_Compute;
107+
bgl_entries[i].buffer.type = (i < n_rw)
108+
? WGPUBufferBindingType_Storage
109+
: WGPUBufferBindingType_ReadOnlyStorage;
110+
}
111+
bgl_entries[uniform_binding].binding = uniform_binding;
112+
bgl_entries[uniform_binding].visibility = WGPUShaderStage_Compute;
113+
bgl_entries[uniform_binding].buffer.type = WGPUBufferBindingType_Uniform;
114+
115+
WGPUBindGroupLayoutDescriptor bgl_desc = {};
116+
bgl_desc.entryCount = n_storage + 1;
117+
bgl_desc.entries = bgl_entries;
118+
WGPUBindGroupLayout bgl = wgpuDeviceCreateBindGroupLayout(device, &bgl_desc);
119+
120+
WGPUPipelineLayoutDescriptor pl_desc = {};
121+
pl_desc.bindGroupLayoutCount = 1;
122+
pl_desc.bindGroupLayouts = &bgl;
123+
WGPUPipelineLayout pipeline_layout =
124+
wgpuDeviceCreatePipelineLayout(device, &pl_desc);
125+
126+
WGPUComputePipelineDescriptor pipeline_desc = {};
127+
pipeline_desc.layout = pipeline_layout;
128+
pipeline_desc.compute.module = shader;
129+
pipeline_desc.compute.entryPoint = {"main", WGPU_STRLEN};
130+
WGPUComputePipeline pipeline =
131+
wgpuDeviceCreateComputePipeline(device, &pipeline_desc);
132+
133+
WGPUBindGroupEntry bg_entries[kMaxEntries] = {};
134+
for (uint32_t i = 0; i < n_storage; i++) {
135+
bg_entries[i].binding = i;
136+
bg_entries[i].buffer = storage_bindings[i].buffer;
137+
bg_entries[i].size = storage_bindings[i].size;
138+
}
139+
bg_entries[uniform_binding].binding = uniform_binding;
140+
bg_entries[uniform_binding].buffer = uniform_buffer;
141+
bg_entries[uniform_binding].size = uniform_size;
142+
143+
WGPUBindGroupDescriptor bg_desc = {};
144+
bg_desc.layout = bgl;
145+
bg_desc.entryCount = n_storage + 1;
146+
bg_desc.entries = bg_entries;
147+
WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);
148+
149+
graph.add_dispatch({pipeline, bind_group, workgroup_count_x, kernel_name});
150+
151+
wgpuShaderModuleRelease(shader);
152+
wgpuBindGroupLayoutRelease(bgl);
153+
wgpuPipelineLayoutRelease(pipeline_layout);
154+
wgpuBufferRelease(uniform_buffer);
155+
}
156+
157+
} // namespace
158+
159+
void sdpa_fd_decode_dispatch(
160+
WebGPUGraph& graph,
161+
const WebGPUTensor& q,
162+
const WebGPUTensor& k_cache,
163+
const WebGPUTensor& v_cache,
164+
const WebGPUTensor& out,
165+
int64_t Hq,
166+
int64_t Hkv,
167+
int64_t D,
168+
int64_t context_len,
169+
int64_t g,
170+
float scale) {
171+
// Defensive contract guard: the Sdpa.cpp gate only routes D <= this here, but
172+
// keep the check (lane-owns-D reach) so a future caller can't silently overrun.
173+
if (D > kSdpaFdMaxHeadDim) {
174+
throw std::runtime_error(
175+
"WebGPU sdpa FlashDecoding: head dim must be <= " +
176+
std::to_string(kSdpaFdMaxHeadDim));
177+
}
178+
if (D % 4 != 0) {
179+
throw std::runtime_error(
180+
"WebGPU sdpa FlashDecoding: head dim must be a multiple of 4");
181+
}
182+
// context_len 0 -> split_len 0 -> empty KV loop -> silent zero output; the
183+
// Sdpa.cpp gate guarantees ctx >= 1, but fail loud if called directly.
184+
if (context_len <= 0) {
185+
throw std::runtime_error(
186+
"WebGPU sdpa FlashDecoding: context_len must be positive");
187+
}
188+
189+
// Split factor: one split per kSdpaFdSplitTile KV rows, capped.
190+
uint32_t num_splits = static_cast<uint32_t>(
191+
(context_len + kSdpaFdSplitTile - 1) / kSdpaFdSplitTile);
192+
if (num_splits > kSdpaFdMaxSplits) {
193+
num_splits = kSdpaFdMaxSplits;
194+
}
195+
const uint32_t split_len =
196+
static_cast<uint32_t>((context_len + num_splits - 1) / num_splits);
197+
198+
// Scratch: per-(head,split) partials at kSdpaFdMaxSplits stride.
199+
const uint64_t po_floats = static_cast<uint64_t>(Hq) *
200+
static_cast<uint64_t>(kSdpaFdMaxSplits) * static_cast<uint64_t>(D);
201+
const uint64_t pml_floats = static_cast<uint64_t>(Hq) *
202+
static_cast<uint64_t>(kSdpaFdMaxSplits) * 2ull;
203+
WGPUBuffer part_o = graph.create_scratch_buffer(po_floats * sizeof(float));
204+
WGPUBuffer part_ml = graph.create_scratch_buffer(pml_floats * sizeof(float));
205+
206+
// Pass 1: split (Hq*num_splits WGs) -> writes part_o, part_ml.
207+
FdSplitParams sp = {};
208+
sp.Hkv = static_cast<uint32_t>(Hkv);
209+
sp.D = static_cast<uint32_t>(D);
210+
sp.context_len = static_cast<uint32_t>(context_len);
211+
sp.g = static_cast<uint32_t>(g);
212+
sp.num_splits = num_splits;
213+
sp.split_len = split_len;
214+
sp.scale = scale;
215+
WGPUBuffer ub_split = make_uniform_buffer(graph, &sp, sizeof(sp));
216+
BufferBinding split_bindings[5] = {
217+
{part_o, po_floats * sizeof(float)},
218+
{part_ml, pml_floats * sizeof(float)},
219+
{q.buffer, q.nbytes},
220+
{k_cache.buffer, k_cache.nbytes},
221+
{v_cache.buffer, v_cache.nbytes}};
222+
// Compute the thread product in 64-bit + guard before the u32 cast, mirroring
223+
// the Sdpa.cpp aw_floats > UINT32_MAX guards.
224+
const uint64_t split_threads = static_cast<uint64_t>(Hq) *
225+
static_cast<uint64_t>(num_splits) *
226+
static_cast<uint64_t>(kSdpaFdSplitWorkgroupSizeX);
227+
if (split_threads > UINT32_MAX) {
228+
throw std::runtime_error(
229+
"WebGPU sdpa FlashDecoding: split thread count exceeds uint32 max");
230+
}
231+
const uint32_t wgc_split = utils::compute_1d_workgroup_count(
232+
graph.device(),
233+
static_cast<uint32_t>(split_threads),
234+
kSdpaFdSplitWorkgroupSizeX,
235+
"fd_split");
236+
build_dispatch(
237+
graph,
238+
kSdpaFdSplitWGSL,
239+
split_bindings,
240+
5,
241+
2,
242+
ub_split,
243+
sizeof(sp),
244+
wgc_split,
245+
"fd_split");
246+
247+
// Pass 2: reduce (Hq WGs) -> reads part_o, part_ml; writes out.
248+
FdReduceParams rp = {};
249+
rp.D = static_cast<uint32_t>(D);
250+
rp.num_splits = num_splits;
251+
WGPUBuffer ub_reduce = make_uniform_buffer(graph, &rp, sizeof(rp));
252+
BufferBinding reduce_bindings[3] = {
253+
{out.buffer, out.nbytes},
254+
{part_o, po_floats * sizeof(float)},
255+
{part_ml, pml_floats * sizeof(float)}};
256+
const uint32_t wgc_reduce = utils::compute_1d_workgroup_count(
257+
graph.device(),
258+
static_cast<uint32_t>(Hq) * kSdpaFdReduceWorkgroupSizeX,
259+
kSdpaFdReduceWorkgroupSizeX,
260+
"fd_reduce");
261+
build_dispatch(
262+
graph,
263+
kSdpaFdReduceWGSL,
264+
reduce_bindings,
265+
3,
266+
1,
267+
ub_reduce,
268+
sizeof(rp),
269+
wgc_reduce,
270+
"fd_reduce");
271+
}
272+
273+
} // namespace executorch::backends::webgpu
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/backends/webgpu/runtime/WebGPUGraph.h>
12+
13+
#include <cstdint>
14+
15+
namespace executorch::backends::webgpu {
16+
17+
// FlashDecoding's lane-owns-D layout covers head dims up to WG_SIZE(64) *
18+
// MAX_D_PER_LANE(2). Decode shapes above this fall through to the materialized
19+
// SDPA path (the FD selection predicate in Sdpa.cpp checks this).
20+
constexpr int64_t kSdpaFdMaxHeadDim = 128;
21+
22+
// Split-KV FlashDecoding decode dispatch (S==1): a split pass over
23+
// Hq*num_splits workgroups + a reduce pass over Hq workgroups. Called from the
24+
// Sdpa.cpp WEBGPU_SDPA_FD branch.
25+
void sdpa_fd_decode_dispatch(
26+
WebGPUGraph& graph,
27+
const WebGPUTensor& q,
28+
const WebGPUTensor& k_cache,
29+
const WebGPUTensor& v_cache,
30+
const WebGPUTensor& out,
31+
int64_t Hq,
32+
int64_t Hkv,
33+
int64_t D,
34+
int64_t context_len,
35+
int64_t g,
36+
float scale);
37+
38+
} // namespace executorch::backends::webgpu

0 commit comments

Comments
 (0)