Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/webgpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ set(WEBGPU_SRCS
runtime/ops/add/BinaryOp.cpp
runtime/ops/rms_norm/RmsNorm.cpp
runtime/ops/update_cache/UpdateCache.cpp
runtime/ops/select_as_symint/SelectAsSymint.cpp
)

add_library(webgpu_backend ${WEBGPU_SRCS})
Expand Down
9 changes: 9 additions & 0 deletions backends/webgpu/runtime/WebGPUBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,15 @@ Error WebGPUBackend::execute(
}
graph->copy_inputs(inputs);

// Fail loud as a runtime Error so a throw never crosses the backend boundary.
try {
graph->update_symints_from_inputs(inputs);
graph->propagate_resize();
} catch (const std::exception& e) {
ET_LOG(Error, "WebGPU symint refresh/resize failed: %s", e.what());
return Error::Internal;
}

// Execute the compute graph
graph->execute();

Expand Down
111 changes: 111 additions & 0 deletions backends/webgpu/runtime/WebGPUGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,86 @@ WGPUBuffer WebGPUGraph::create_scratch_buffer(size_t nbytes) {
return buffer;
}

void WebGPUGraph::update_symints_from_inputs(
const std::vector<std::pair<const void*, size_t>>& inputs) {
for (const auto& src : symint_sources_) {
int pos = -1;
for (size_t i = 0; i < input_ids_.size(); i++) {
if (input_ids_[i] == src.input_tensor_id) {
pos = static_cast<int>(i);
break;
}
}
if (pos < 0 || pos >= static_cast<int>(inputs.size())) {
throw std::runtime_error(
"select_as_symint: source tensor is not a graph input");
}
const auto& dims = tensors_[src.input_tensor_id].dims;
int dim = src.dim < 0 ? src.dim + static_cast<int>(dims.size()) : src.dim;
if (dim < 0 || dim >= static_cast<int>(dims.size())) {
throw std::runtime_error("select_as_symint: dim out of range");
}
int index = src.index;
if (index < 0) {
index += static_cast<int>(dims[dim]);
}
if (index < 0 || index >= static_cast<int>(dims[dim])) {
throw std::runtime_error("select_as_symint: index out of range");
}
int64_t numel = 1;
for (int64_t d : dims) {
numel *= d;
}
if (numel <= 0) {
throw std::runtime_error("select_as_symint: empty input tensor");
}
int64_t stride = 1;
for (size_t i = static_cast<size_t>(dim) + 1; i < dims.size(); i++) {
stride *= dims[i];
}
// Reads the [0,..,index,..,0] element; symint sources are scalar-ish.
const int64_t offset = static_cast<int64_t>(index) * stride;
// elem_size back-derived from build-time numel (sources are static-shaped).
const void* host = inputs[pos].first;
const size_t elem_size = inputs[pos].second / static_cast<size_t>(numel);
int32_t val;
if (elem_size == sizeof(int64_t)) {
val = static_cast<int32_t>(static_cast<const int64_t*>(host)[offset]);
} else if (elem_size == sizeof(int32_t)) {
val = static_cast<const int32_t*>(host)[offset];
} else {
throw std::runtime_error(
"select_as_symint: unsupported input element size");
}
set_symint(src.symint_id, val);
}
}

void WebGPUGraph::set_symint(int id, int32_t val) {
auto it = symints_.find(id);
if (it == symints_.end()) {
throw std::runtime_error("WebGPUGraph::set_symint: id is not a SymInt");
}
if (it->second.value != val) {
it->second.value = val;
wgpuQueueWriteBuffer(
queue_, it->second.buffer, 0, &it->second.value, sizeof(int32_t));
dirty_symints_.insert(id);
}
}

void WebGPUGraph::propagate_resize() {
if (dirty_symints_.empty()) {
return;
}
for (auto& hook : resize_hooks_) {
if (dirty_symints_.count(hook.symint_id) != 0) {
hook.fn(*this);
}
}
dirty_symints_.clear();
}

WebGPUGraph::~WebGPUGraph() {
for (size_t i = 0; i < tensors_.size(); i++) {
if (tensors_[i].buffer &&
Expand All @@ -76,6 +156,16 @@ WebGPUGraph::~WebGPUGraph() {
wgpuBufferRelease(buf);
}
}
for (auto& buf : owned_uniform_buffers_) {
if (buf) {
wgpuBufferRelease(buf);
}
}
for (auto& kv : symints_) {
if (kv.second.buffer) {
wgpuBufferRelease(kv.second.buffer);
}
}
for (auto& buf : output_staging_buffers_) {
if (buf) {
wgpuBufferRelease(buf);
Expand Down Expand Up @@ -236,6 +326,27 @@ void WebGPUGraph::build(
bools_[i] = val->value_as_Bool()->bool_val();
break;
}
case vkgraph::GraphTypes::SymInt: {
// Live scalar: small Uniform buffer the CPU rewrites per execute.
value_types_[i] = ValueType::SymInt;
SymIntSlot slot;
slot.value = static_cast<int32_t>(val->value_as_SymInt()->value());
// 16B matches the backend uniform-struct alignment; int32 in first 4.
constexpr size_t kSymIntUniformBytes = 16;
WGPUBufferDescriptor d = {};
d.size = kSymIntUniformBytes;
d.usage = WGPUBufferUsage_Uniform | WGPUBufferUsage_CopyDst;
d.mappedAtCreation = true;
slot.buffer = wgpuDeviceCreateBuffer(device_, &d);
void* mapped =
wgpuBufferGetMappedRange(slot.buffer, 0, kSymIntUniformBytes);
std::memset(mapped, 0, kSymIntUniformBytes);
std::memcpy(mapped, &slot.value, sizeof(int32_t));
wgpuBufferUnmap(slot.buffer);
symints_[i] = slot;
add_uniform_buffer_bytes(kSymIntUniformBytes);
break;
}
default:
value_types_[i] = ValueType::Null;
break;
Expand Down
74 changes: 73 additions & 1 deletion backends/webgpu/runtime/WebGPUGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
#include <webgpu/webgpu.h>

#include <cstdint>
#include <functional>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include <executorch/runtime/core/named_data_map.h>
Expand Down Expand Up @@ -104,6 +106,52 @@ class WebGPUGraph {
return ints_[id];
}

// Live-scalar (SymInt) API; mirrors the Vulkan SymInt/ParamsBuffer UBO.
// set_symint writes the buffer + marks dirty only if the value changed.
void set_symint(int id, int32_t val);
// read_symint throws (fail-loud) if id is not a SymInt.
int32_t read_symint(int id) const {
return symints_.at(id).value;
}
// symint_buffer throws (fail-loud) if id is not a SymInt.
WGPUBuffer symint_buffer(int id) const {
return symints_.at(id).buffer;
}

// Records that a SymInt's value is read from input_tensor[index] along dim.
struct SymIntSource {
int symint_id;
int input_tensor_id;
int dim;
int index;
};
void
add_symint_source(int symint_id, int input_tensor_id, int dim, int index) {
symint_sources_.push_back({symint_id, input_tensor_id, dim, index});
}
const std::vector<SymIntSource>& symint_sources() const {
return symint_sources_;
}

// Execute-time select_as_symint read; mirrors Vulkan select_as_symint_impl.
void update_symints_from_inputs(
const std::vector<std::pair<const void*, size_t>>& inputs);

// Per-SymInt resize hook; mirrors Vulkan DynamicDispatchNode::trigger_resize.
void add_resize_hook(int symint_id, std::function<void(WebGPUGraph&)> fn) {
resize_hooks_.push_back({symint_id, std::move(fn)});
}
// Run hooks for changed SymInts then clear; call before execute().
void propagate_resize();

// Mutable dispatch access for resize hooks (to rewrite workgroup_count_x).
WebGPUDispatch& dispatch_at(size_t i) {
return dispatches_[i];
}
size_t num_dispatches() const {
return dispatches_.size();
}

WGPUDevice device() const {
return device_;
}
Expand All @@ -119,6 +167,11 @@ class WebGPUGraph {
uniform_buffer_bytes_ += bytes;
}

// Keep a uniform alive for the graph's lifetime; released in the dtor.
void own_uniform_buffer(WGPUBuffer buffer) {
owned_uniform_buffers_.push_back(buffer);
}

// Graph-owned scratch storage buffer for fused-op intermediates (e.g. SDPA).
WGPUBuffer create_scratch_buffer(size_t nbytes);

Expand Down Expand Up @@ -149,7 +202,7 @@ class WebGPUGraph {
return static_cast<int>(value_types_.size());
}

enum class ValueType { Tensor, Int, Double, Bool, Null, String };
enum class ValueType { Tensor, Int, Double, Bool, Null, String, SymInt };

ValueType get_value_type(int id) const {
return value_types_[id];
Expand All @@ -168,6 +221,22 @@ class WebGPUGraph {
std::vector<double> doubles_;
std::vector<bool> bools_;

// SymInt (live scalar): id -> {live Uniform buffer, current value}, sparse.
struct SymIntSlot {
WGPUBuffer buffer = nullptr;
int32_t value = 0;
};
std::unordered_map<int, SymIntSlot> symints_;
std::vector<SymIntSource> symint_sources_;

// Resize hooks + the set of SymInts changed since the last propagate_resize.
struct ResizeHook {
int symint_id;
std::function<void(WebGPUGraph&)> fn;
};
std::vector<ResizeHook> resize_hooks_;
std::unordered_set<int> dirty_symints_;

std::vector<int> input_ids_;
std::vector<int> output_ids_;

Expand All @@ -179,6 +248,9 @@ class WebGPUGraph {
// Long-lived scratch storage buffers for fused ops (e.g. SDPA temporaries).
std::vector<WGPUBuffer> scratch_buffers_;

// Uniform buffers owned for the graph's lifetime; released in the dtor.
std::vector<WGPUBuffer> owned_uniform_buffers_;

// Staging buffers for reading back outputs (MapRead | CopyDst).
std::vector<WGPUBuffer> output_staging_buffers_;

Expand Down
41 changes: 41 additions & 0 deletions backends/webgpu/runtime/ops/select_as_symint/SelectAsSymint.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/webgpu/runtime/WebGPUGraph.h>
#include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h>

#include <stdexcept>

namespace executorch::backends::webgpu {

namespace {

// et_vk.select_as_symint: out SymInt = x[index] along dim; read at execute.
void select_as_symint_impl(WebGPUGraph& graph, const std::vector<int>& args) {
const int x_id = args.at(0);
const int dim_id = args.at(1);
const int index_id = args.at(2);
const int out_id = args.at(3);

if (graph.get_value_type(out_id) != WebGPUGraph::ValueType::SymInt) {
throw std::runtime_error("select_as_symint: output is not a SymInt");
}
graph.add_symint_source(
out_id,
x_id,
static_cast<int>(graph.get_int(dim_id)),
static_cast<int>(graph.get_int(index_id)));
}

} // namespace

WEBGPU_REGISTER_OPERATORS {
WEBGPU_REGISTER_OP(et_vk.select_as_symint.default, select_as_symint_impl);
}

} // namespace executorch::backends::webgpu
Loading