Skip to content

Commit b2dc380

Browse files
committed
[ExecuTorch][WebGPU] 2D-fold mul + permute dispatch (lift 65535 1D cap)
Pull Request resolved: #20651 **Lift the 65535 workgroup-per-dim cap for `mul` and `permute` so they run at any numel.** `mul.Tensor` and `permute` still used `compute_1d_workgroup_count`, which throws once `numel / wg_size > 65535` — hit by a realistic Llama-3.2-1B LoRA layer (`mul` over `[2048, 8192]` = 262k workgroups; `permute` of `[2048, 2048]` = 65536). `add`/`sub`/`div`/`fill`/`sdpa` already use the 2D fold; this brings `mul` + `permute` in line. Key changes: - `mul/BinaryOp.cpp`, `permute/Permute.cpp` — `compute_1d_workgroup_count` → `compute_2d_workgroup_count` (returns `utils::WgCount`); dispatch + resize hook now set both `workgroup_count_x` and `workgroup_count_y`. - `binary_mul.wgsl`, `permute.wgsl` — `main` takes `@builtin(num_workgroups)`; flat index `gid.x + gid.y * (num_workgroups.x * wg_size)` (regenerated `*_wgsl.h`). Mirrors the landed `add` op fold (`runtime/ops/add/{BinaryOp.cpp,binary_add.wgsl}`). Co-authored-with: Claude Code. ghstack-source-id: 399812930 @exported-using-ghexport Differential Revision: [D110149677](https://our.internmc.facebook.com/intern/diff/D110149677/)
1 parent db65e54 commit b2dc380

6 files changed

Lines changed: 34 additions & 20 deletions

File tree

backends/webgpu/runtime/ops/mul/BinaryOp.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ void mul_impl(WebGPUGraph& graph, const std::vector<int>& args) {
3434
const auto& in2_tensor = graph.get_tensor(in2_id);
3535
const auto& out_tensor = graph.get_tensor(out_id);
3636

37-
// Rank guard (NCHW backend is <= 4 dims; 1D dispatch only).
37+
// Rank guard (NCHW backend is <= 4 dims).
3838
if (out_tensor.dims.size() > kTensorMetaMaxNdim ||
3939
in1_tensor.dims.size() > kTensorMetaMaxNdim ||
4040
in2_tensor.dims.size() > kTensorMetaMaxNdim) {
@@ -63,8 +63,8 @@ void mul_impl(WebGPUGraph& graph, const std::vector<int>& args) {
6363

6464
uint32_t wg_size =
6565
utils::clamp_workgroup_size(device, kBinaryMulWorkgroupSizeX);
66-
uint32_t workgroup_count =
67-
utils::compute_1d_workgroup_count(device, out_meta.numel, wg_size, "mul");
66+
utils::WgCount workgroup_count =
67+
utils::compute_2d_workgroup_count(device, out_meta.numel, wg_size, "mul");
6868

6969
WGPUConstantEntry wg_size_constant = {};
7070
wg_size_constant.key = {"wg_size", WGPU_STRLEN};
@@ -165,8 +165,8 @@ void mul_impl(WebGPUGraph& graph, const std::vector<int>& args) {
165165
bg_desc.entries = bg_entries;
166166
WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);
167167

168-
const size_t dispatch_idx =
169-
graph.add_dispatch({pipeline, bind_group, workgroup_count});
168+
const size_t dispatch_idx = graph.add_dispatch(
169+
{pipeline, bind_group, workgroup_count.x, "mul", workgroup_count.y});
170170

171171
// Dynamic shapes: rebuild all 3 broadcast TensorMeta UBOs + dispatch.
172172
WGPUBuffer o_buf = out_meta_buf, a_buf = in1_meta_buf, b_buf = in2_meta_buf;
@@ -199,9 +199,10 @@ void mul_impl(WebGPUGraph& graph, const std::vector<int>& args) {
199199
wgpuQueueWriteBuffer(g.queue(), o_buf, 0, &om, sizeof(om));
200200
wgpuQueueWriteBuffer(g.queue(), a_buf, 0, &am, sizeof(am));
201201
wgpuQueueWriteBuffer(g.queue(), b_buf, 0, &bm, sizeof(bm));
202-
g.dispatch_at(dispatch_idx).workgroup_count_x =
203-
utils::compute_1d_workgroup_count(
204-
g.device(), om.numel, wg_size, "mul(resize)");
202+
const utils::WgCount wgc = utils::compute_2d_workgroup_count(
203+
g.device(), om.numel, wg_size, "mul(resize)");
204+
g.dispatch_at(dispatch_idx).workgroup_count_x = wgc.x;
205+
g.dispatch_at(dispatch_idx).workgroup_count_y = wgc.y;
205206
};
206207
graph.add_tensor_resize_hook(in1_id, mul_resize);
207208
graph.add_tensor_resize_hook(in2_id, mul_resize);

backends/webgpu/runtime/ops/mul/binary_mul.wgsl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,11 @@ struct TensorMeta {
1515
override wg_size: u32 = 64u;
1616

1717
@compute @workgroup_size(wg_size, 1, 1)
18-
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
19-
let idx = gid.x;
18+
fn main(
19+
@builtin(global_invocation_id) gid: vec3<u32>,
20+
@builtin(num_workgroups) num_workgroups: vec3<u32>) {
21+
// 2D-folded flat index (lifts the 65535 1D-dispatch cap for large numel).
22+
let idx = gid.x + gid.y * (num_workgroups.x * wg_size);
2023
if (idx >= out_meta.numel) {
2124
return;
2225
}

backends/webgpu/runtime/ops/mul/binary_mul_wgsl.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
namespace executorch::backends::webgpu {
1414

1515
// @generated from binary_mul.wgsl - DO NOT EDIT.
16-
// wgsl-sha256: e7f77426cbaf48e6085e0d882522c027302ec97ef017b86a2275eed9820f7891
16+
// wgsl-sha256: cca69c3428f37f293942637e23f664225dec81a56f184bcb63185b6629dd155e
1717
inline constexpr const char* kBinaryMulWGSL = R"(
1818
@group(0) @binding(0) var<storage, read> input1: array<f32>;
1919
@group(0) @binding(1) var<storage, read> input2: array<f32>;
@@ -32,8 +32,11 @@ struct TensorMeta {
3232
override wg_size: u32 = 64u;
3333
3434
@compute @workgroup_size(wg_size, 1, 1)
35-
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
36-
let idx = gid.x;
35+
fn main(
36+
@builtin(global_invocation_id) gid: vec3<u32>,
37+
@builtin(num_workgroups) num_workgroups: vec3<u32>) {
38+
// 2D-folded flat index (lifts the 65535 1D-dispatch cap for large numel).
39+
let idx = gid.x + gid.y * (num_workgroups.x * wg_size);
3740
if (idx >= out_meta.numel) {
3841
return;
3942
}

backends/webgpu/runtime/ops/permute/Permute.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ void permute_impl(WebGPUGraph& graph, const std::vector<int>& args) {
9292

9393
uint32_t wg_size =
9494
utils::clamp_workgroup_size(device, kPermuteWorkgroupSizeX);
95-
uint32_t workgroup_count = utils::compute_1d_workgroup_count(
95+
utils::WgCount workgroup_count = utils::compute_2d_workgroup_count(
9696
device, out_meta.numel, wg_size, "permute");
9797

9898
WGPUConstantEntry wg_size_constant = {};
@@ -176,7 +176,8 @@ void permute_impl(WebGPUGraph& graph, const std::vector<int>& args) {
176176
bg_desc.entries = bg_entries;
177177
WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);
178178

179-
graph.add_dispatch({pipeline, bind_group, workgroup_count});
179+
graph.add_dispatch(
180+
{pipeline, bind_group, workgroup_count.x, "permute", workgroup_count.y});
180181

181182
wgpuShaderModuleRelease(shader);
182183
wgpuBindGroupLayoutRelease(bgl);

backends/webgpu/runtime/ops/permute/permute.wgsl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,11 @@ struct Params {
1818
override wg_size: u32 = 64u;
1919

2020
@compute @workgroup_size(wg_size, 1, 1)
21-
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
22-
let out_bufi = gid.x;
21+
fn main(
22+
@builtin(global_invocation_id) gid: vec3<u32>,
23+
@builtin(num_workgroups) num_workgroups: vec3<u32>) {
24+
// 2D-folded flat index (lifts the 65535 1D-dispatch cap for large numel).
25+
let out_bufi = gid.x + gid.y * (num_workgroups.x * wg_size);
2326
if (out_bufi >= out_meta.numel) {
2427
return;
2528
}

backends/webgpu/runtime/ops/permute/permute_wgsl.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
namespace executorch::backends::webgpu {
1414

1515
// @generated from permute.wgsl - DO NOT EDIT.
16-
// wgsl-sha256: d34f59730cda7317589b6ed5691a1ccab8666b9c94e17ac2cb3658b036300197
16+
// wgsl-sha256: 05884aeb14426c979ea037b066266d8cab11f4fed76ee21ee8778e7fc13ad84e
1717
inline constexpr const char* kPermuteWGSL = R"(
1818
@group(0) @binding(0) var<storage, read> input: array<f32>;
1919
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
@@ -35,8 +35,11 @@ struct Params {
3535
override wg_size: u32 = 64u;
3636
3737
@compute @workgroup_size(wg_size, 1, 1)
38-
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
39-
let out_bufi = gid.x;
38+
fn main(
39+
@builtin(global_invocation_id) gid: vec3<u32>,
40+
@builtin(num_workgroups) num_workgroups: vec3<u32>) {
41+
// 2D-folded flat index (lifts the 65535 1D-dispatch cap for large numel).
42+
let out_bufi = gid.x + gid.y * (num_workgroups.x * wg_size);
4043
if (out_bufi >= out_meta.numel) {
4144
return;
4245
}

0 commit comments

Comments
 (0)