Skip to content

Commit 99ec14f

Browse files
committed
[ExecuTorch][WebGPU] GPU timestamp query profiling for SDPA
Pull Request resolved: #20167 SDPA-specific instrumentation layered on the general GPU-timestamp infrastructure (companion diff below): tag each fused SDPA dispatch with its `kernel_name` so the `WebGPUQueryPool` can attribute on-GPU time to the attention stage that produced it. `sdpa_with_kv_cache` runs four chained dispatches — `update_cache` -> QK (`attn_weights`) -> softmax -> AV (`compute_out`); `WebGPUGraph::execute()` brackets each compute pass with a timestamp when the pool is active, and this diff labels each dispatch so the per-pass durations map back to the right stage. Opt-in via the `WEBGPU_TIMESTAMP_QUERY` env var; off by default, so the production `execute()` path is byte-identical. This is the per-kernel hook a forthcoming SDPA kernel benchmark will read; the benchmark itself (and any comparative numbers) is a separate follow-up. Co-authored with Claude. ghstack-source-id: 392093463 @exported-using-ghexport Differential Revision: [D107678235](https://our.internmc.facebook.com/intern/diff/D107678235/)
1 parent b43121c commit 99ec14f

1 file changed

Lines changed: 11 additions & 6 deletions

File tree

  • backends/webgpu/runtime/ops/sdpa

backends/webgpu/runtime/ops/sdpa/Sdpa.cpp

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,8 @@ void build_dispatch(
156156
uint64_t uniform_size,
157157
uint32_t workgroup_count_x,
158158
uint32_t wg_size,
159-
bool retain_uniform = false) {
159+
bool retain_uniform = false,
160+
const char* kernel_name = "") {
160161
WGPUDevice device = graph.device();
161162

162163
WGPUShaderSourceWGSL wgsl_desc = {};
@@ -227,7 +228,7 @@ void build_dispatch(
227228
bg_desc.entries = bg_entries;
228229
WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);
229230

230-
graph.add_dispatch({pipeline, bind_group, workgroup_count_x});
231+
graph.add_dispatch({pipeline, bind_group, workgroup_count_x, kernel_name});
231232

232233
wgpuShaderModuleRelease(shader);
233234
wgpuBindGroupLayoutRelease(bgl);
@@ -269,7 +270,8 @@ static WGPUBuffer record_update_cache_dispatch(
269270
sizeof(uc),
270271
wgc,
271272
uc_wg,
272-
dynamic_pos);
273+
dynamic_pos,
274+
"update_cache");
273275
return ubuf;
274276
}
275277

@@ -473,7 +475,8 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
473475
sizeof(p),
474476
wgc,
475477
qk_wg,
476-
dynamic_pos);
478+
dynamic_pos,
479+
"sdpa_compute_attn_weights");
477480
qk_buf = ubuf;
478481
qk_idx = graph.num_dispatches() - 1;
479482
}
@@ -496,7 +499,8 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
496499
sizeof(p),
497500
wgc,
498501
0,
499-
dynamic_pos);
502+
dynamic_pos,
503+
"sdpa_softmax");
500504
softmax_buf = ubuf;
501505
}
502506

@@ -521,7 +525,8 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
521525
sizeof(p),
522526
wgc,
523527
av_wg,
524-
dynamic_pos);
528+
dynamic_pos,
529+
"sdpa_compute_out");
525530
av_buf = ubuf;
526531
}
527532

0 commit comments

Comments
 (0)