Skip to content

Commit fd16aad

Browse files
committed
[ExecuTorch][WebGPU] Dynamic resize hooks for add and mul
Pull Request resolved: #20577 **Make the elementwise add and mul ops serve any live shape from one graph.** **Problem:** `aten.add.Tensor` and `aten.mul.Tensor` baked their element count + param UBO(s) + output shape at `build()` for the max shape. On a dynamic-shape graph at a smaller live shape they would over-dispatch and leave the output sized at the max. **Solution:** - Before: one fixed dispatch sized for the build-time shape. - After: each registers a resize hook on BOTH operands (the dynamic one may be either operand by arg order). The hook recomputes the live element count, rewrites the param UBO(s), updates the dispatch `workgroup_count_x`, and sets the output `cur_dims`. Inert until an operand is resized. **Implementation:** - `add`: out follows the larger operand (robust when one input is a static residual and the other is the dynamic-S tensor); rewrites `AddParams`. - `mul`: recomputes the broadcast output shape and rebuilds all three `TensorMeta` UBOs via `fill_tensor_meta_broadcast`. - Each keeps its uniform buffer(s) alive via `own_uniform_buffer` instead of releasing at build. - Mirrors Vulkan per-op `resize_*_node` (recompute sizes + dispatch each execute). **Constraints:** Behavior-neutral on static graphs (the hook fires only when an operand's live shape differs from the max). No kernel/WGSL/numerics change. Co-authored-with: Claude Code. ghstack-source-id: 399812828 @exported-using-ghexport Differential Revision: [D109906093](https://our.internmc.facebook.com/intern/diff/D109906093/)
1 parent ae22389 commit fd16aad

2 files changed

Lines changed: 82 additions & 7 deletions

File tree

backends/webgpu/runtime/ops/add/BinaryOp.cpp

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,13 +159,48 @@ void add_impl(WebGPUGraph& graph, const std::vector<int>& args) {
159159
WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);
160160

161161
graph.add_dispatch({pipeline, bind_group, workgroup_count});
162+
const size_t dispatch_idx = graph.num_dispatches() - 1;
163+
164+
// Dynamic shapes: recompute numel/dispatch; out follows the larger operand.
165+
WGPUBuffer params_buf = uniform_buffer;
166+
auto add_resize = [in1_id,
167+
in2_id,
168+
out_id,
169+
alpha,
170+
wg_size,
171+
dispatch_idx,
172+
params_buf](WebGPUGraph& g) {
173+
const auto& d1 = g.cur_dims(in1_id);
174+
const auto& d2 = g.cur_dims(in2_id);
175+
const uint64_t n1 = utils::numel_of(d1);
176+
const uint64_t n2 = utils::numel_of(d2);
177+
const uint64_t numel = n2 > n1 ? n2 : n1;
178+
const uint64_t n_min = n2 > n1 ? n1 : n2;
179+
// The flat add follows the larger operand and broadcasts the smaller; valid
180+
// only when the smaller tiles evenly into it (rejects e.g. [4,1] vs [1,3],
181+
// whose true [4,3] result this flat kernel cannot produce).
182+
if (n_min == 0u || numel % n_min != 0u) {
183+
throw std::runtime_error(
184+
"add(resize): operands are not broadcast-compatible by numel");
185+
}
186+
g.set_cur_dims(out_id, n2 > n1 ? d2 : d1);
187+
AddParams p = {};
188+
p.num_elements = static_cast<uint32_t>(numel);
189+
p.alpha = alpha;
190+
wgpuQueueWriteBuffer(g.queue(), params_buf, 0, &p, sizeof(p));
191+
g.dispatch_at(dispatch_idx).workgroup_count_x =
192+
utils::compute_1d_workgroup_count(
193+
g.device(), static_cast<uint32_t>(numel), wg_size, "add(resize)");
194+
};
195+
graph.add_tensor_resize_hook(in1_id, add_resize);
196+
graph.add_tensor_resize_hook(in2_id, add_resize);
162197

163198
// Release intermediate objects (pipeline and bind_group are kept by dispatch)
164199
wgpuShaderModuleRelease(shader);
165200
wgpuBindGroupLayoutRelease(bgl);
166201
wgpuPipelineLayoutRelease(pipeline_layout);
167-
// Drop our ref; the bind group keeps the uniform buffer alive until release.
168-
wgpuBufferRelease(uniform_buffer);
202+
// Graph owns it so a resize hook can rewrite it; freed in the dtor.
203+
graph.own_uniform_buffer(uniform_buffer);
169204
}
170205

171206
} // namespace

backends/webgpu/runtime/ops/mul/BinaryOp.cpp

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include <webgpu/webgpu.h>
1616

17+
#include <algorithm>
1718
#include <stdexcept>
1819
#include <vector>
1920

@@ -164,15 +165,54 @@ void mul_impl(WebGPUGraph& graph, const std::vector<int>& args) {
164165
bg_desc.entries = bg_entries;
165166
WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);
166167

167-
graph.add_dispatch({pipeline, bind_group, workgroup_count});
168+
const size_t dispatch_idx =
169+
graph.add_dispatch({pipeline, bind_group, workgroup_count});
170+
171+
// Dynamic shapes: rebuild all 3 broadcast TensorMeta UBOs + dispatch.
172+
WGPUBuffer o_buf = out_meta_buf, a_buf = in1_meta_buf, b_buf = in2_meta_buf;
173+
auto mul_resize =
174+
[in1_id, in2_id, out_id, wg_size, dispatch_idx, o_buf, a_buf, b_buf](
175+
WebGPUGraph& g) {
176+
const auto& a = g.cur_dims(in1_id);
177+
const auto& b = g.cur_dims(in2_id);
178+
const size_t r = std::max(a.size(), b.size());
179+
std::vector<int64_t> out_d(r, 1);
180+
for (size_t i = 0; i < r; i++) {
181+
const int64_t av = (i + a.size() < r) ? 1 : a[i - (r - a.size())];
182+
const int64_t bv = (i + b.size() < r) ? 1 : b[i - (r - b.size())];
183+
if (av != bv && av != 1 && bv != 1) {
184+
throw std::runtime_error(
185+
"mul(resize): operands are not broadcast-compatible");
186+
}
187+
out_d[i] = av > bv ? av : bv;
188+
}
189+
g.set_cur_dims(out_id, out_d);
190+
const uint32_t out_ndim = static_cast<uint32_t>(r);
191+
WebGPUTensor ta, tb, to;
192+
ta.dims = a;
193+
tb.dims = b;
194+
to.dims = out_d;
195+
TensorMeta om, am, bm;
196+
fill_tensor_meta_broadcast(to, out_ndim, &om);
197+
fill_tensor_meta_broadcast(ta, out_ndim, &am);
198+
fill_tensor_meta_broadcast(tb, out_ndim, &bm);
199+
wgpuQueueWriteBuffer(g.queue(), o_buf, 0, &om, sizeof(om));
200+
wgpuQueueWriteBuffer(g.queue(), a_buf, 0, &am, sizeof(am));
201+
wgpuQueueWriteBuffer(g.queue(), b_buf, 0, &bm, sizeof(bm));
202+
g.dispatch_at(dispatch_idx).workgroup_count_x =
203+
utils::compute_1d_workgroup_count(
204+
g.device(), om.numel, wg_size, "mul(resize)");
205+
};
206+
graph.add_tensor_resize_hook(in1_id, mul_resize);
207+
graph.add_tensor_resize_hook(in2_id, mul_resize);
168208

169209
wgpuShaderModuleRelease(shader);
170210
wgpuBindGroupLayoutRelease(bgl);
171211
wgpuPipelineLayoutRelease(pipeline_layout);
172-
// Drop our refs; the bind group keeps the uniforms alive until release.
173-
wgpuBufferRelease(out_meta_buf);
174-
wgpuBufferRelease(in1_meta_buf);
175-
wgpuBufferRelease(in2_meta_buf);
212+
// Graph owns them so a resize hook can rewrite them; freed in the dtor.
213+
graph.own_uniform_buffer(out_meta_buf);
214+
graph.own_uniform_buffer(in1_meta_buf);
215+
graph.own_uniform_buffer(in2_meta_buf);
176216
}
177217

178218
} // namespace

0 commit comments

Comments
 (0)