Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 34 additions & 14 deletions backends/webgpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -151,19 +151,9 @@ function(add_webgpu_native_test test_name test_src)
endfunction()

if(EXECUTORCH_BUILD_WEBGPU_TEST)
add_webgpu_native_test(webgpu_native_test test/test_webgpu_native.cpp)
add_webgpu_native_test(
webgpu_dispatch_order_test test/native/test_dispatch_order.cpp
)
add_webgpu_native_test(
webgpu_scratch_buffer_test test/native/test_scratch_buffer.cpp
)
add_webgpu_native_test(
webgpu_update_cache_test test/native/test_update_cache.cpp
)

# Manifest-driven op-test framework: a generic gtest driver (webgpu_op_test) +
# its device-free util unit test. GTest needs -DEXECUTORCH_BUILD_TESTS=ON.
# All WebGPU native tests use GTest (device-dependent ones bring up the device
# in their own main(); the fold unit test is device-free via gtest_main).
# GTest needs -DEXECUTORCH_BUILD_TESTS=ON.
if(NOT TARGET GTest::gtest)
find_package(GTest QUIET)
endif()
Expand Down Expand Up @@ -194,6 +184,36 @@ if(EXECUTORCH_BUILD_WEBGPU_TEST)
)
target_compile_options(webgpu_op_test_util_test PRIVATE -fexceptions)
set_property(TARGET webgpu_op_test_util_test PROPERTY CXX_STANDARD 17)

# Device-dependent native tests: each has its own main() that brings up the
# device once, then RUN_ALL_TESTS(); link GTest::gtest (not gtest_main).
add_webgpu_native_test(webgpu_native_test test/test_webgpu_native.cpp)
target_link_libraries(webgpu_native_test PRIVATE GTest::gtest)
add_webgpu_native_test(
webgpu_dispatch_order_test test/native/test_dispatch_order.cpp
)
target_link_libraries(webgpu_dispatch_order_test PRIVATE GTest::gtest)
add_webgpu_native_test(
webgpu_scratch_buffer_test test/native/test_scratch_buffer.cpp
)
target_link_libraries(webgpu_scratch_buffer_test PRIVATE GTest::gtest)
add_webgpu_native_test(
webgpu_update_cache_test test/native/test_update_cache.cpp
)
target_link_libraries(webgpu_update_cache_test PRIVATE GTest::gtest)
add_webgpu_native_test(
webgpu_dynamic_shape_test test/native/test_dynamic_shape.cpp
)
target_link_libraries(webgpu_dynamic_shape_test PRIVATE GTest::gtest)
add_webgpu_native_test(webgpu_index_test test/native/test_index.cpp)
target_link_libraries(webgpu_index_test PRIVATE GTest::gtest)

# Device-free fold unit test (gtest_main provides main; no device needed).
add_webgpu_native_test(
webgpu_dispatch_2d_test test/native/test_dispatch_2d.cpp
)
target_link_libraries(
webgpu_dispatch_2d_test PRIVATE GTest::gtest GTest::gtest_main
)
endif()
add_webgpu_native_test(webgpu_index_test test/native/test_index.cpp)
endif()
4 changes: 3 additions & 1 deletion backends/webgpu/runtime/WebGPUDevice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ WebGPUContext create_webgpu_context() {

// TimedWaitAny lets webgpu_wait() block on futures via wgpuInstanceWaitAny.
WGPUInstanceDescriptor instance_desc = {};
#if defined(__EMSCRIPTEN__)
// Vendored (buck) Dawn uses the older capabilities.* API; the rig's native
// Dawn and emscripten's emdawnwebgpu (emcc 4.0.19+) use requiredFeatures.
#if defined(WEBGPU_DAWN_INSTANCE_CAPABILITIES)
instance_desc.capabilities.timedWaitAnyEnable = true;
instance_desc.capabilities.timedWaitAnyMaxCount = 1;
#else
Expand Down
9 changes: 6 additions & 3 deletions backends/webgpu/runtime/WebGPUGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -814,15 +814,15 @@ void WebGPUGraph::execute() {
wgpuComputePassEncoderSetBindGroup(
pass, 0, dispatch.bind_group, 0, nullptr);
wgpuComputePassEncoderDispatchWorkgroups(
pass, dispatch.workgroup_count_x, 1, 1);
pass, dispatch.workgroup_count_x, dispatch.workgroup_count_y, 1);
wgpuComputePassEncoderEnd(pass);
wgpuComputePassEncoderRelease(pass);
#ifdef WGPU_BACKEND_ENABLE_PROFILING
if (qp) {
qp->record(
static_cast<uint32_t>(i),
dispatch.kernel_name,
{dispatch.workgroup_count_x, 1, 1},
{dispatch.workgroup_count_x, dispatch.workgroup_count_y, 1},
{1, 1, 1});
}
#endif // WGPU_BACKEND_ENABLE_PROFILING
Expand Down Expand Up @@ -894,7 +894,10 @@ void WebGPUGraph::execute() {
wgpuComputePassEncoderSetBindGroup(
pass, 0, dispatches_[i].bind_group, 0, nullptr);
wgpuComputePassEncoderDispatchWorkgroups(
pass, dispatches_[i].workgroup_count_x, 1, 1);
pass,
dispatches_[i].workgroup_count_x,
dispatches_[i].workgroup_count_y,
1);
wgpuComputePassEncoderEnd(pass);
wgpuComputePassEncoderRelease(pass);
}
Expand Down
1 change: 1 addition & 0 deletions backends/webgpu/runtime/WebGPUGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ struct WebGPUDispatch {
WGPUBindGroup bind_group = nullptr;
uint32_t workgroup_count_x = 1;
std::string kernel_name; // bench label
uint32_t workgroup_count_y = 1; // 2D fold (>65535); 1 = unchanged 1D path
// DMA copy command; default Compute keeps existing positional inits valid.
enum class Kind { Compute, Copy };
Kind kind = Kind::Compute;
Expand Down
64 changes: 57 additions & 7 deletions backends/webgpu/runtime/WebGPUUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <webgpu/webgpu.h>

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <stdexcept>
Expand Down Expand Up @@ -47,27 +48,76 @@ inline uint32_t clamp_workgroup_size(WGPUDevice device, uint32_t desired) {
return desired;
}

// Workgroup grid extents for a (possibly 2D-folded) compute dispatch.
// Stays an aggregate (C++14+), so `{x, y}` initialization keeps working; the
// defaults make a default-constructed value a well-defined 1x1 grid instead of
// leaving both members indeterminate.
struct WgCount {
  uint32_t x = 1u; // workgroups along the X dispatch dimension
  uint32_t y = 1u; // workgroups along the Y dispatch dimension; 1 = plain 1D
};

// Returns the device's maxComputeWorkgroupsPerDimension limit. Falls back to
// 65535 (the WebGPU spec-default value for that limit) when the limits query
// does not succeed or reports 0, so the result is never 0.
inline uint32_t queried_max_workgroups(WGPUDevice device) {
  constexpr uint32_t kSpecDefaultMaxWorkgroups = 65535u;

  WGPULimits limits = {};
  if (wgpuDeviceGetLimits(device, &limits) != WGPUStatus_Success) {
    return kSpecDefaultMaxWorkgroups;
  }

  const uint32_t reported = limits.maxComputeWorkgroupsPerDimension;
  if (reported == 0u) {
    return kSpecDefaultMaxWorkgroups;
  }
  return reported;
}

// Pure 2D fold of a 1D workgroup count (device-free, unit-testable).
//
// Returns {count, 1} when count <= max_count (this includes count == 0).
// Otherwise returns a near-square {x, y} with
//   x = min(ceil(sqrt(count)), max_count),  y = ceil(count / x),
// so x * y >= count and the launched grid stays close to count. A flat
// {max, div_up(count, max)} split would leave up to ~half the workgroups
// inactive when count just exceeds max, and inactive workgroups still cost
// launch/scheduling; the near-square split keeps the waste to O(sqrt(count)).
// The shader reconstructs the linear index from @builtin(num_workgroups), so
// any x/y factoring works.
//
// op_name is only used to label error messages; a null pointer is tolerated.
//
// Throws std::runtime_error when:
//   - max_count is 0 while count > 0 (no grid can cover the work), or
//   - even a max_count * max_count grid is too small (would need a 3rd
//     dispatch dimension, which is out of scope).
inline WgCount fold_workgroup_count_2d(
    uint32_t count,
    uint32_t max_count,
    const char* op_name) {
  if (count <= max_count) {
    return {count, 1u};
  }
  // Building a std::string from a null pointer is undefined behaviour, so
  // substitute a placeholder label for the error paths below.
  const char* const label = op_name != nullptr ? op_name : "<unnamed op>";
  if (max_count == 0u) {
    // count > 0 here. Without this guard x would clamp to 0 and the ceil-div
    // below would divide by zero.
    throw std::runtime_error(
        std::string("WebGPU ") + label +
        ": per-dimension workgroup limit is 0");
  }
  uint32_t x =
      static_cast<uint32_t>(std::ceil(std::sqrt(static_cast<double>(count))));
  x = std::min(x, max_count);
  // ceil-div written overflow-safe (count >= 1 here) as count nears UINT32_MAX.
  const uint32_t y = 1u + (count - 1u) / x;
  if (y > max_count) {
    throw std::runtime_error(
        std::string("WebGPU ") + label +
        ": workgroup count needs a 3rd dispatch dimension (unsupported)");
  }
  return {x, y};
}

// 1D dispatch count (mirrors Vulkan div_up): the number of workgroups of
// `workgroup_size` threads needed for `num_threads`, as computed by div_up.
// Throws std::runtime_error, labelled with op_name, if that count exceeds the
// per-dimension limit returned by queried_max_workgroups(device). Callers that
// need to go past the limit use the 2D variant instead.
inline uint32_t compute_1d_workgroup_count(
    WGPUDevice device,
    uint32_t num_threads,
    uint32_t workgroup_size,
    const char* op_name) {
  const uint32_t count = div_up(num_threads, workgroup_size);
  // The limit lookup (including its 65535 fallback) lives in
  // queried_max_workgroups so the 1D and 2D paths agree on the same cap.
  if (count > queried_max_workgroups(device)) {
    throw std::runtime_error(
        std::string("WebGPU ") + op_name +
        ": workgroup count exceeds the 1D dispatch limit");
  }
  return count;
}

// 2D dispatch count. Computes the linear workgroup count with div_up, then
// hands it to fold_workgroup_count_2d together with the device's per-dimension
// limit. A count within the limit comes back as {count, 1} (same result as the
// 1D helper); a larger one is split across x and y rather than rejected, which
// lifts the 1D cap (e.g. for SDPA prefill). op_name labels any error raised by
// the fold.
inline WgCount compute_2d_workgroup_count(
    WGPUDevice device,
    uint32_t num_threads,
    uint32_t workgroup_size,
    const char* op_name) {
  const uint32_t linear_count = div_up(num_threads, workgroup_size);
  const uint32_t per_dim_limit = queried_max_workgroups(device);
  return fold_workgroup_count_2d(linear_count, per_dim_limit, op_name);
}

// Create a uniform buffer mapped-at-creation, copy `size` bytes in, and unmap.
inline WGPUBuffer
make_uniform(WGPUDevice device, const void* data, size_t size) {
Expand Down
60 changes: 29 additions & 31 deletions backends/webgpu/runtime/ops/add/BinaryOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ void add_impl(WebGPUGraph& graph, const std::vector<int>& args) {

uint32_t wg_size =
utils::clamp_workgroup_size(device, kBinaryAddWorkgroupSizeX);
uint32_t workgroup_count =
utils::compute_1d_workgroup_count(device, num_elements, wg_size, "add");
utils::WgCount workgroup_count =
utils::compute_2d_workgroup_count(device, num_elements, wg_size, "add");

WGPUConstantEntry wg_size_constant = {};
wg_size_constant.key = {"wg_size", WGPU_STRLEN};
Expand Down Expand Up @@ -158,40 +158,38 @@ void add_impl(WebGPUGraph& graph, const std::vector<int>& args) {
bg_desc.entries = bg_entries;
WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);

graph.add_dispatch({pipeline, bind_group, workgroup_count});
graph.add_dispatch(
{pipeline, bind_group, workgroup_count.x, "", workgroup_count.y});
const size_t dispatch_idx = graph.num_dispatches() - 1;

// Dynamic shapes: recompute numel/dispatch; out follows the larger operand.
WGPUBuffer params_buf = uniform_buffer;
auto add_resize = [in1_id,
in2_id,
out_id,
alpha,
wg_size,
dispatch_idx,
params_buf](WebGPUGraph& g) {
const auto& d1 = g.cur_dims(in1_id);
const auto& d2 = g.cur_dims(in2_id);
const uint64_t n1 = utils::numel_of(d1);
const uint64_t n2 = utils::numel_of(d2);
const uint64_t numel = n2 > n1 ? n2 : n1;
const uint64_t n_min = n2 > n1 ? n1 : n2;
// The flat add follows the larger operand and broadcasts the smaller; valid
// only when the smaller tiles evenly into it (rejects e.g. [4,1] vs [1,3],
// whose true [4,3] result this flat kernel cannot produce).
if (n_min == 0u || numel % n_min != 0u) {
throw std::runtime_error(
"add(resize): operands are not broadcast-compatible by numel");
}
g.set_cur_dims(out_id, n2 > n1 ? d2 : d1);
AddParams p = {};
p.num_elements = static_cast<uint32_t>(numel);
p.alpha = alpha;
wgpuQueueWriteBuffer(g.queue(), params_buf, 0, &p, sizeof(p));
g.dispatch_at(dispatch_idx).workgroup_count_x =
utils::compute_1d_workgroup_count(
auto add_resize =
[in1_id, in2_id, out_id, alpha, wg_size, dispatch_idx, params_buf](
WebGPUGraph& g) {
const auto& d1 = g.cur_dims(in1_id);
const auto& d2 = g.cur_dims(in2_id);
const uint64_t n1 = utils::numel_of(d1);
const uint64_t n2 = utils::numel_of(d2);
const uint64_t numel = n2 > n1 ? n2 : n1;
const uint64_t n_min = n2 > n1 ? n1 : n2;
// The flat add follows the larger operand and broadcasts the smaller;
// valid only when the smaller tiles evenly into it (rejects e.g. [4,1]
// vs [1,3], whose true [4,3] result this flat kernel cannot produce).
if (n_min == 0u || numel % n_min != 0u) {
throw std::runtime_error(
"add(resize): operands are not broadcast-compatible by numel");
}
g.set_cur_dims(out_id, n2 > n1 ? d2 : d1);
AddParams p = {};
p.num_elements = static_cast<uint32_t>(numel);
p.alpha = alpha;
wgpuQueueWriteBuffer(g.queue(), params_buf, 0, &p, sizeof(p));
const utils::WgCount wgc = utils::compute_2d_workgroup_count(
g.device(), static_cast<uint32_t>(numel), wg_size, "add(resize)");
};
g.dispatch_at(dispatch_idx).workgroup_count_x = wgc.x;
g.dispatch_at(dispatch_idx).workgroup_count_y = wgc.y;
};
graph.add_tensor_resize_hook(in1_id, add_resize);
graph.add_tensor_resize_hook(in2_id, add_resize);

Expand Down
6 changes: 4 additions & 2 deletions backends/webgpu/runtime/ops/add/binary_add.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ struct Params {
override wg_size: u32 = 256;

@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let idx = gid.x;
fn main(
@builtin(global_invocation_id) gid: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>) {
let idx = gid.x + gid.y * (num_workgroups.x * wg_size);
if (idx >= params.num_elements) {
return;
}
Expand Down
8 changes: 5 additions & 3 deletions backends/webgpu/runtime/ops/add/binary_add_wgsl.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
namespace executorch::backends::webgpu {

// @generated from binary_add.wgsl - DO NOT EDIT.
// wgsl-sha256: c1ceec80c8d4d3d56986ad91ce0d7f9a57cd8467b8c3aa07a28da70e51d141d9
// wgsl-sha256: e66bd67465c2a0296e09668df54f87605a4c91015a615f3734cdd0f140a74477
inline constexpr const char* kBinaryAddWGSL = R"(
@group(0) @binding(0) var<storage, read> input1: array<f32>;
@group(0) @binding(1) var<storage, read> input2: array<f32>;
Expand All @@ -28,8 +28,10 @@ struct Params {
override wg_size: u32 = 256;

@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let idx = gid.x;
fn main(
@builtin(global_invocation_id) gid: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>) {
let idx = gid.x + gid.y * (num_workgroups.x * wg_size);
if (idx >= params.num_elements) {
return;
}
Expand Down
17 changes: 9 additions & 8 deletions backends/webgpu/runtime/ops/mul/BinaryOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ void mul_impl(WebGPUGraph& graph, const std::vector<int>& args) {
const auto& in2_tensor = graph.get_tensor(in2_id);
const auto& out_tensor = graph.get_tensor(out_id);

// Rank guard (NCHW backend is <= 4 dims; 1D dispatch only).
// Rank guard (NCHW backend is <= 4 dims).
if (out_tensor.dims.size() > kTensorMetaMaxNdim ||
in1_tensor.dims.size() > kTensorMetaMaxNdim ||
in2_tensor.dims.size() > kTensorMetaMaxNdim) {
Expand Down Expand Up @@ -63,8 +63,8 @@ void mul_impl(WebGPUGraph& graph, const std::vector<int>& args) {

uint32_t wg_size =
utils::clamp_workgroup_size(device, kBinaryMulWorkgroupSizeX);
uint32_t workgroup_count =
utils::compute_1d_workgroup_count(device, out_meta.numel, wg_size, "mul");
utils::WgCount workgroup_count =
utils::compute_2d_workgroup_count(device, out_meta.numel, wg_size, "mul");

WGPUConstantEntry wg_size_constant = {};
wg_size_constant.key = {"wg_size", WGPU_STRLEN};
Expand Down Expand Up @@ -165,8 +165,8 @@ void mul_impl(WebGPUGraph& graph, const std::vector<int>& args) {
bg_desc.entries = bg_entries;
WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);

const size_t dispatch_idx =
graph.add_dispatch({pipeline, bind_group, workgroup_count});
const size_t dispatch_idx = graph.add_dispatch(
{pipeline, bind_group, workgroup_count.x, "mul", workgroup_count.y});

// Dynamic shapes: rebuild all 3 broadcast TensorMeta UBOs + dispatch.
WGPUBuffer o_buf = out_meta_buf, a_buf = in1_meta_buf, b_buf = in2_meta_buf;
Expand Down Expand Up @@ -199,9 +199,10 @@ void mul_impl(WebGPUGraph& graph, const std::vector<int>& args) {
wgpuQueueWriteBuffer(g.queue(), o_buf, 0, &om, sizeof(om));
wgpuQueueWriteBuffer(g.queue(), a_buf, 0, &am, sizeof(am));
wgpuQueueWriteBuffer(g.queue(), b_buf, 0, &bm, sizeof(bm));
g.dispatch_at(dispatch_idx).workgroup_count_x =
utils::compute_1d_workgroup_count(
g.device(), om.numel, wg_size, "mul(resize)");
const utils::WgCount wgc = utils::compute_2d_workgroup_count(
g.device(), om.numel, wg_size, "mul(resize)");
g.dispatch_at(dispatch_idx).workgroup_count_x = wgc.x;
g.dispatch_at(dispatch_idx).workgroup_count_y = wgc.y;
};
graph.add_tensor_resize_hook(in1_id, mul_resize);
graph.add_tensor_resize_hook(in2_id, mul_resize);
Expand Down
7 changes: 5 additions & 2 deletions backends/webgpu/runtime/ops/mul/binary_mul.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@ struct TensorMeta {
override wg_size: u32 = 64u;

@compute @workgroup_size(wg_size, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let idx = gid.x;
fn main(
@builtin(global_invocation_id) gid: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>) {
// 2D-folded flat index (lifts the 65535 1D-dispatch cap for large numel).
let idx = gid.x + gid.y * (num_workgroups.x * wg_size);
if (idx >= out_meta.numel) {
return;
}
Expand Down
9 changes: 6 additions & 3 deletions backends/webgpu/runtime/ops/mul/binary_mul_wgsl.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
namespace executorch::backends::webgpu {

// @generated from binary_mul.wgsl - DO NOT EDIT.
// wgsl-sha256: e7f77426cbaf48e6085e0d882522c027302ec97ef017b86a2275eed9820f7891
// wgsl-sha256: cca69c3428f37f293942637e23f664225dec81a56f184bcb63185b6629dd155e
inline constexpr const char* kBinaryMulWGSL = R"(
@group(0) @binding(0) var<storage, read> input1: array<f32>;
@group(0) @binding(1) var<storage, read> input2: array<f32>;
Expand All @@ -32,8 +32,11 @@ struct TensorMeta {
override wg_size: u32 = 64u;

@compute @workgroup_size(wg_size, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let idx = gid.x;
fn main(
@builtin(global_invocation_id) gid: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>) {
// 2D-folded flat index (lifts the 65535 1D-dispatch cap for large numel).
let idx = gid.x + gid.y * (num_workgroups.x * wg_size);
if (idx >= out_meta.numel) {
return;
}
Expand Down
Loading
Loading