Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 34 additions & 14 deletions backends/webgpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -151,19 +151,9 @@ function(add_webgpu_native_test test_name test_src)
endfunction()

if(EXECUTORCH_BUILD_WEBGPU_TEST)
add_webgpu_native_test(webgpu_native_test test/test_webgpu_native.cpp)
add_webgpu_native_test(
webgpu_dispatch_order_test test/native/test_dispatch_order.cpp
)
add_webgpu_native_test(
webgpu_scratch_buffer_test test/native/test_scratch_buffer.cpp
)
add_webgpu_native_test(
webgpu_update_cache_test test/native/test_update_cache.cpp
)

# Manifest-driven op-test framework: a generic gtest driver (webgpu_op_test) +
# its device-free util unit test. GTest needs -DEXECUTORCH_BUILD_TESTS=ON.
# All WebGPU native tests use GTest (device-dependent ones bring up the device
# in their own main(); the fold unit test is device-free via gtest_main).
# GTest needs -DEXECUTORCH_BUILD_TESTS=ON.
if(NOT TARGET GTest::gtest)
find_package(GTest QUIET)
endif()
Expand Down Expand Up @@ -194,6 +184,36 @@ if(EXECUTORCH_BUILD_WEBGPU_TEST)
)
target_compile_options(webgpu_op_test_util_test PRIVATE -fexceptions)
set_property(TARGET webgpu_op_test_util_test PROPERTY CXX_STANDARD 17)

# Device-dependent native tests: each has its own main() that brings up the
# device once, then RUN_ALL_TESTS(); link GTest::gtest (not gtest_main).
add_webgpu_native_test(webgpu_native_test test/test_webgpu_native.cpp)
target_link_libraries(webgpu_native_test PRIVATE GTest::gtest)
add_webgpu_native_test(
webgpu_dispatch_order_test test/native/test_dispatch_order.cpp
)
target_link_libraries(webgpu_dispatch_order_test PRIVATE GTest::gtest)
add_webgpu_native_test(
webgpu_scratch_buffer_test test/native/test_scratch_buffer.cpp
)
target_link_libraries(webgpu_scratch_buffer_test PRIVATE GTest::gtest)
add_webgpu_native_test(
webgpu_update_cache_test test/native/test_update_cache.cpp
)
target_link_libraries(webgpu_update_cache_test PRIVATE GTest::gtest)
add_webgpu_native_test(
webgpu_dynamic_shape_test test/native/test_dynamic_shape.cpp
)
target_link_libraries(webgpu_dynamic_shape_test PRIVATE GTest::gtest)
add_webgpu_native_test(webgpu_index_test test/native/test_index.cpp)
target_link_libraries(webgpu_index_test PRIVATE GTest::gtest)

# Device-free fold unit test (gtest_main provides main; no device needed).
add_webgpu_native_test(
webgpu_dispatch_2d_test test/native/test_dispatch_2d.cpp
)
target_link_libraries(
webgpu_dispatch_2d_test PRIVATE GTest::gtest GTest::gtest_main
)
endif()
add_webgpu_native_test(webgpu_index_test test/native/test_index.cpp)
endif()
38 changes: 31 additions & 7 deletions backends/webgpu/runtime/WebGPUBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@

#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
#include <executorch/runtime/platform/log.h>

#include <vector>

#include <new>

namespace executorch {
Expand All @@ -35,6 +38,7 @@ using executorch::runtime::Error;
using executorch::runtime::EValue;
using executorch::runtime::FreeableBuffer;
using executorch::runtime::register_backend;
using executorch::runtime::resize_tensor;
using executorch::runtime::Result;
using executorch::runtime::Span;

Expand Down Expand Up @@ -100,19 +104,39 @@ Error WebGPUBackend::execute(
// Copy inputs from EValue tensors to GPU buffers
std::vector<InputData> inputs;
inputs.reserve(num_inputs);
for (size_t i = 0; i < num_inputs; i++) {
const auto& tensor = args[i]->toTensor();
const bool host_is_int64 =
tensor.scalar_type() == executorch::aten::ScalarType::Long;
inputs.push_back({tensor.const_data_ptr(), tensor.nbytes(), host_is_int64});
}
// Fail loud as a runtime Error so a throw never crosses the backend boundary.
try {
// Build the input list and, for dynamic shapes, shrink each input to its
// live sizes before upload (mirrors Vulkan maybe_resize_input). No-op when
// unchanged, so a static graph is byte-identical.
for (size_t i = 0; i < num_inputs; i++) {
const auto& tensor = args[i]->toTensor();
const bool host_is_int64 =
tensor.scalar_type() == executorch::aten::ScalarType::Long;
inputs.push_back(
{tensor.const_data_ptr(), tensor.nbytes(), host_is_int64});
const auto sizes = tensor.sizes();
std::vector<int64_t> new_dims(sizes.begin(), sizes.end());
graph->resize_input(graph->input_ids()[i], new_dims);
}
graph->copy_inputs(inputs);
graph->update_symints_from_inputs(inputs);
graph->propagate_resize();
// Resize each output EValue to its live shape so the readback length is
// correct (mirrors Vulkan maybe_resize_output).
for (size_t i = 0; i < num_outputs; i++) {
const auto& cd = graph->cur_dims(graph->output_ids()[i]);
std::vector<executorch::aten::SizesType> osizes(cd.begin(), cd.end());
Error e = resize_tensor(
args[num_inputs + i]->toTensor(),
ArrayRef<executorch::aten::SizesType>(osizes.data(), osizes.size()));
if (e != Error::Ok) {
ET_LOG(Error, "WebGPU: output %zu resize failed", i);
return Error::Internal;
}
}
} catch (const std::exception& e) {
ET_LOG(Error, "WebGPU input copy / symint refresh failed: %s", e.what());
ET_LOG(Error, "WebGPU input/output resize / copy failed: %s", e.what());
return Error::Internal;
}

Expand Down
4 changes: 3 additions & 1 deletion backends/webgpu/runtime/WebGPUDevice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ WebGPUContext create_webgpu_context() {

// TimedWaitAny lets webgpu_wait() block on futures via wgpuInstanceWaitAny.
WGPUInstanceDescriptor instance_desc = {};
#if defined(__EMSCRIPTEN__)
// Vendored (buck) Dawn uses the older capabilities.* API; the rig's native
// Dawn and emscripten's emdawnwebgpu (emcc 4.0.19+) use requiredFeatures.
#if defined(WEBGPU_DAWN_INSTANCE_CAPABILITIES)
instance_desc.capabilities.timedWaitAnyEnable = true;
instance_desc.capabilities.timedWaitAnyMaxCount = 1;
#else
Expand Down
134 changes: 112 additions & 22 deletions backends/webgpu/runtime/WebGPUGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <executorch/backends/webgpu/runtime/WebGPUCompat.h>
#include <executorch/backends/webgpu/runtime/WebGPUDevice.h>

#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <stdexcept>
Expand Down Expand Up @@ -62,6 +63,18 @@ bool vk_datatype_is_int(vkgraph::VkDataType dtype) {
}
}

// Normalize a possibly-negative dim against rank; throws (fail-loud) if OOR.
int normalize_dim(int dim, int rank, const char* op) {
if (dim < 0) {
dim += rank;
}
if (dim < 0 || dim >= rank) {
throw std::runtime_error(
std::string("WebGPU ") + op + ": dim out of range");
}
return dim;
}

} // namespace

WebGPUGraph::WebGPUGraph() = default;
Expand Down Expand Up @@ -104,11 +117,10 @@ void WebGPUGraph::update_symints_from_inputs(
throw std::runtime_error(
"select_as_symint: source tensor is not a graph input");
}
const auto& dims = tensors_[src.input_tensor_id].dims;
int dim = src.dim < 0 ? src.dim + static_cast<int>(dims.size()) : src.dim;
if (dim < 0 || dim >= static_cast<int>(dims.size())) {
throw std::runtime_error("select_as_symint: dim out of range");
}
// Live cur_dims: the source may be a dynamic-shape input.
const auto& dims = tensors_[src.input_tensor_id].cur_dims;
int dim = normalize_dim(
src.dim, static_cast<int>(dims.size()), "select_as_symint");
int index = src.index;
if (index < 0) {
index += static_cast<int>(dims[dim]);
Expand All @@ -129,20 +141,26 @@ void WebGPUGraph::update_symints_from_inputs(
}
// Reads the [0,..,index,..,0] element; symint sources are scalar-ish.
const int64_t offset = static_cast<int64_t>(index) * stride;
// elem_size back-derived from build-time numel (sources are static-shaped).
const void* host = inputs[pos].data;
const size_t elem_size = inputs[pos].nbytes / static_cast<size_t>(numel);
// Interpret the HOST buffer by its scalar type, not the tensor's serialized
// elem_size: copy_inputs narrows an int64 host input to an int32 buffer, so
// elem_size (buffer-derived) would misread int64 host data as int32.
int32_t val;
if (elem_size == sizeof(int64_t)) {
if (inputs[pos].host_is_int64) {
val = static_cast<int32_t>(static_cast<const int64_t*>(host)[offset]);
} else if (elem_size == sizeof(int32_t)) {
val = static_cast<const int32_t*>(host)[offset];
} else {
throw std::runtime_error(
"select_as_symint: unsupported input element size");
val = static_cast<const int32_t*>(host)[offset];
}
set_symint(src.symint_id, val);
}
// sym_size.int: SymInt = a tensor's live dim (cur_dims). Usually unused (ops
// read cur_dims directly); for an intermediate source cur_dims is the build
// max here (hooks run later in propagate_resize), which is fine while unused.
for (const auto& s : symint_dim_sources_) {
const auto& d = tensors_[s.tensor_id].cur_dims;
int dim = normalize_dim(s.dim, static_cast<int>(d.size()), "sym_size");
set_symint(s.symint_id, static_cast<int32_t>(d[dim]));
}
}

void WebGPUGraph::set_symint(int id, int32_t val) {
Expand All @@ -158,16 +176,78 @@ void WebGPUGraph::set_symint(int id, int32_t val) {
}
}

void WebGPUGraph::set_cur_dims(
int value_id,
const std::vector<int64_t>& new_dims) {
auto& t = tensors_[value_id];
if (new_dims.size() != t.dims.size()) {
throw std::runtime_error("WebGPU resize: tensor rank changed");
}
size_t numel = 1;
for (size_t d = 0; d < new_dims.size(); d++) {
// 0-sized dims unsupported: live shapes are always in [1, max] per dim.
if (new_dims[d] <= 0) {
throw std::runtime_error("WebGPU resize: new dim must be positive");
}
if (new_dims[d] > t.dims[d]) {
throw std::runtime_error(
"WebGPU resize: new dim exceeds the max (serialized) allocation");
}
numel *= static_cast<size_t>(new_dims[d]);
}
const size_t new_nbytes = numel * t.elem_size;
if (t.cur_dims != new_dims) {
t.cur_dims = new_dims;
t.cur_nbytes = new_nbytes;
dirty_tensors_.insert(value_id);
}
}

void WebGPUGraph::resize_input(
int value_id,
const std::vector<int64_t>& new_dims) {
if (std::find(input_ids_.begin(), input_ids_.end(), value_id) ==
input_ids_.end()) {
throw std::runtime_error(
"WebGPUGraph::resize_input: value_id is not a graph input");
}
set_cur_dims(value_id, new_dims);
}

void WebGPUGraph::propagate_resize() {
if (dirty_symints_.empty()) {
if (dirty_symints_.empty() && dirty_tensors_.empty()) {
return;
}
// Hooks fire in registration (topological) order: operands update first.
for (auto& hook : resize_hooks_) {
if (dirty_symints_.count(hook.symint_id) != 0) {
hook.fn(*this);
}
}
dirty_symints_.clear();
// Tensor hooks: bounded fixpoint. A hook may dirty its output (cascading to a
// consumer); each pass handles the currently-dirty set. A forward DAG
// converges in <= depth passes (set_cur_dims re-dirties only on a change).
for (size_t pass = 0;
!dirty_tensors_.empty() && pass <= tensor_resize_hooks_.size();
pass++) {
std::unordered_set<int> processing;
processing.swap(dirty_tensors_);
for (auto& hook : tensor_resize_hooks_) {
if (processing.count(hook.trigger_tensor_id) != 0) {
hook.fn(*this);
}
}
}
if (!dirty_tensors_.empty()) {
throw std::runtime_error(
"WebGPU resize: tensor resize hooks did not converge");
}
// Tensor hooks must not set_symint (dirty_symints_ already drained above).
if (!dirty_symints_.empty()) {
throw std::runtime_error(
"WebGPU resize: a tensor resize hook set a SymInt; not supported");
}
}

WebGPUGraph::~WebGPUGraph() {
Expand Down Expand Up @@ -322,6 +402,10 @@ void WebGPUGraph::build(
tensor.elem_size = vk_datatype_size(vk_tensor->datatype());
tensor.is_int = vk_datatype_is_int(vk_tensor->datatype());
tensor.nbytes = numel * tensor.elem_size;
// Live dims start == max (serialized upper bound); resize_input shrinks
// them per call. Static graphs keep cur == max forever.
tensor.cur_dims = tensor.dims;
tensor.cur_nbytes = tensor.nbytes;

int constant_id = vk_tensor->constant_id();
int mem_obj_id = vk_tensor->mem_obj_id();
Expand Down Expand Up @@ -624,17 +708,20 @@ void WebGPUGraph::copy_inputs(const std::vector<InputData>& inputs) {
}
int tid = input_ids_[i];
const auto& tensor = tensors_[tid];
// Upload only the live (cur) bytes, not the max allocation; cur_nbytes ==
// nbytes on a static graph, so this is byte-identical there.
const size_t live_nbytes = tensor.cur_nbytes;

// Fast path: host and GPU element types match byte-for-byte.
if (in.nbytes == tensor.nbytes) {
wgpuQueueWriteBuffer(queue_, tensor.buffer, 0, in.data, tensor.nbytes);
if (in.nbytes == live_nbytes) {
wgpuQueueWriteBuffer(queue_, tensor.buffer, 0, in.data, live_nbytes);
continue;
}

// Narrow int64 host indices into the int32 buffer (mirrors Vulkan).
const bool buffer_is_int32 = tensor.is_int && tensor.elem_size == 4;
if (in.host_is_int64 && buffer_is_int32 && in.nbytes == tensor.nbytes * 2) {
const size_t numel = tensor.nbytes / 4;
if (in.host_is_int64 && buffer_is_int32 && in.nbytes == live_nbytes * 2) {
const size_t numel = live_nbytes / 4;
const int64_t* src = static_cast<const int64_t*>(in.data);
std::vector<int32_t> narrowed(numel);
for (size_t e = 0; e < numel; e++) {
Expand All @@ -648,15 +735,15 @@ void WebGPUGraph::copy_inputs(const std::vector<InputData>& inputs) {
narrowed[e] = static_cast<int32_t>(src[e]);
}
wgpuQueueWriteBuffer(
queue_, tensor.buffer, 0, narrowed.data(), tensor.nbytes);
queue_, tensor.buffer, 0, narrowed.data(), live_nbytes);
continue;
}

throw std::runtime_error(
"WebGPU: unsupported input copy for input " + std::to_string(i) +
" (host " + std::to_string(in.nbytes) + " bytes" +
(in.host_is_int64 ? " int64" : "") + " vs buffer " +
std::to_string(tensor.nbytes) + " bytes)");
std::to_string(live_nbytes) + " bytes)");
}
}

Expand Down Expand Up @@ -727,15 +814,15 @@ void WebGPUGraph::execute() {
wgpuComputePassEncoderSetBindGroup(
pass, 0, dispatch.bind_group, 0, nullptr);
wgpuComputePassEncoderDispatchWorkgroups(
pass, dispatch.workgroup_count_x, 1, 1);
pass, dispatch.workgroup_count_x, dispatch.workgroup_count_y, 1);
wgpuComputePassEncoderEnd(pass);
wgpuComputePassEncoderRelease(pass);
#ifdef WGPU_BACKEND_ENABLE_PROFILING
if (qp) {
qp->record(
static_cast<uint32_t>(i),
dispatch.kernel_name,
{dispatch.workgroup_count_x, 1, 1},
{dispatch.workgroup_count_x, dispatch.workgroup_count_y, 1},
{1, 1, 1});
}
#endif // WGPU_BACKEND_ENABLE_PROFILING
Expand Down Expand Up @@ -807,7 +894,10 @@ void WebGPUGraph::execute() {
wgpuComputePassEncoderSetBindGroup(
pass, 0, dispatches_[i].bind_group, 0, nullptr);
wgpuComputePassEncoderDispatchWorkgroups(
pass, dispatches_[i].workgroup_count_x, 1, 1);
pass,
dispatches_[i].workgroup_count_x,
dispatches_[i].workgroup_count_y,
1);
wgpuComputePassEncoderEnd(pass);
wgpuComputePassEncoderRelease(pass);
}
Expand Down
Loading
Loading