Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions backends/webgpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ set(WEBGPU_SRCS
runtime/ops/prepack/Prepack.cpp
runtime/ops/view_copy/ViewCopy.cpp
runtime/ops/select/Select.cpp
runtime/ops/sigmoid/UnaryOp.cpp
runtime/ops/squeeze/Squeeze.cpp
runtime/ops/unsqueeze/Unsqueeze.cpp
)

add_library(webgpu_backend ${WEBGPU_SRCS})
Expand Down
165 changes: 165 additions & 0 deletions backends/webgpu/runtime/ops/sigmoid/UnaryOp.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/webgpu/runtime/WebGPUGraph.h>
#include <executorch/backends/webgpu/runtime/WebGPUUtils.h>
#include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h>
#include <executorch/backends/webgpu/runtime/ops/sigmoid/sigmoid_wgsl.h>

#include <webgpu/webgpu.h>

#include <stdexcept>
#include <string>
#include <vector>

namespace executorch::backends::webgpu {

namespace {

// Uniform buffer layout matching the WGSL Params struct; 16-byte aligned.
struct UnaryParams {
uint32_t num_elements;
uint32_t _pad[3];
};

// Generic elementwise unary op; mirrors Vulkan add_unary_op_node (UnaryOp.cpp).
void add_unary_op(
WebGPUGraph& graph,
int in_id,
int out_id,
const char* wgsl_source,
uint32_t wg_size_x,
const char* op_name) {
WGPUDevice device = graph.device();

const auto& in_tensor = graph.get_tensor(in_id);
const auto& out_tensor = graph.get_tensor(out_id);
if (in_tensor.buffer == nullptr || out_tensor.buffer == nullptr) {
throw std::runtime_error(std::string(op_name) + ": null buffer binding");
}

// 4-byte (fp32) alignment guard on both operands; also the dtype guard.
if (in_tensor.nbytes % sizeof(float) != 0 ||
out_tensor.nbytes % sizeof(float) != 0) {
throw std::runtime_error(
std::string(op_name) + ": operand not 4-byte aligned");
}
if (in_tensor.nbytes != out_tensor.nbytes) {
throw std::runtime_error(
std::string(op_name) + ": input/output size mismatch");
}

uint32_t num_elements =
static_cast<uint32_t>(out_tensor.nbytes / sizeof(float));

uint32_t wg_size = utils::clamp_workgroup_size(device, wg_size_x);
uint32_t workgroup_count =
utils::compute_1d_workgroup_count(device, num_elements, wg_size, op_name);

WGPUConstantEntry wg_size_constant = {};
wg_size_constant.key = {"wg_size", WGPU_STRLEN};
wg_size_constant.value = static_cast<double>(wg_size);

UnaryParams params = {};
params.num_elements = num_elements;

WGPUBuffer uniform_buffer =
utils::make_uniform(device, &params, sizeof(UnaryParams));
graph.add_uniform_buffer_bytes(sizeof(UnaryParams));

WGPUShaderSourceWGSL wgsl_desc = {};
wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL;
wgsl_desc.code = {wgsl_source, WGPU_STRLEN};

WGPUShaderModuleDescriptor shader_desc = {};
shader_desc.nextInChain = &wgsl_desc.chain;
WGPUShaderModule shader = wgpuDeviceCreateShaderModule(device, &shader_desc);

// Bind group layout: input (read storage) + output (storage) + params.
WGPUBindGroupLayoutEntry entries[3] = {};

entries[0].binding = 0;
entries[0].visibility = WGPUShaderStage_Compute;
entries[0].buffer.type = WGPUBufferBindingType_ReadOnlyStorage;

entries[1].binding = 1;
entries[1].visibility = WGPUShaderStage_Compute;
entries[1].buffer.type = WGPUBufferBindingType_Storage;

entries[2].binding = 2;
entries[2].visibility = WGPUShaderStage_Compute;
entries[2].buffer.type = WGPUBufferBindingType_Uniform;

WGPUBindGroupLayoutDescriptor bgl_desc = {};
bgl_desc.entryCount = 3;
bgl_desc.entries = entries;
WGPUBindGroupLayout bgl = wgpuDeviceCreateBindGroupLayout(device, &bgl_desc);

WGPUPipelineLayoutDescriptor pl_desc = {};
pl_desc.bindGroupLayoutCount = 1;
pl_desc.bindGroupLayouts = &bgl;
WGPUPipelineLayout pipeline_layout =
wgpuDeviceCreatePipelineLayout(device, &pl_desc);

WGPUComputePipelineDescriptor pipeline_desc = {};
pipeline_desc.layout = pipeline_layout;
pipeline_desc.compute.module = shader;
pipeline_desc.compute.entryPoint = {"main", WGPU_STRLEN};
pipeline_desc.compute.constantCount = 1;
pipeline_desc.compute.constants = &wg_size_constant;
WGPUComputePipeline pipeline =
wgpuDeviceCreateComputePipeline(device, &pipeline_desc);

WGPUBindGroupEntry bg_entries[3] = {};

bg_entries[0].binding = 0;
bg_entries[0].buffer = in_tensor.buffer;
bg_entries[0].size = in_tensor.nbytes;

bg_entries[1].binding = 1;
bg_entries[1].buffer = out_tensor.buffer;
bg_entries[1].size = out_tensor.nbytes;

bg_entries[2].binding = 2;
bg_entries[2].buffer = uniform_buffer;
bg_entries[2].size = sizeof(UnaryParams);

WGPUBindGroupDescriptor bg_desc = {};
bg_desc.layout = bgl;
bg_desc.entryCount = 3;
bg_desc.entries = bg_entries;
WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);

graph.add_dispatch({pipeline, bind_group, workgroup_count});

// Release intermediates (pipeline + bind_group are kept by dispatch).
wgpuShaderModuleRelease(shader);
wgpuBindGroupLayoutRelease(bgl);
wgpuPipelineLayoutRelease(pipeline_layout);
// Drop our ref; the bind group keeps the uniform buffer alive until release.
wgpuBufferRelease(uniform_buffer);
}

void sigmoid_impl(WebGPUGraph& graph, const std::vector<int>& args) {
// aten.sigmoid.default args: [in, out]
add_unary_op(
graph,
args.at(0),
args.at(1),
kSigmoidWGSL,
kSigmoidWorkgroupSizeX,
"sigmoid");
}

} // namespace

WEBGPU_REGISTER_OPERATORS {
WEBGPU_REGISTER_OP(aten.sigmoid.default, sigmoid_impl);
}

} // namespace executorch::backends::webgpu
18 changes: 18 additions & 0 deletions backends/webgpu/runtime/ops/sigmoid/sigmoid.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

struct Params {
num_elements: u32,
}
@group(0) @binding(2) var<uniform> params: Params;

override wg_size: u32 = 64u;

@compute @workgroup_size(wg_size, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let idx = gid.x;
if (idx >= params.num_elements) {
return;
}
output[idx] = 1.0 / (1.0 + exp(-input[idx]));
}
42 changes: 42 additions & 0 deletions backends/webgpu/runtime/ops/sigmoid/sigmoid_wgsl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <cstdint>

namespace executorch::backends::webgpu {

// @generated from sigmoid.wgsl - DO NOT EDIT.
// wgsl-sha256: 70395dbb107b8b95ae13c0a6fb12a8415c561c645da0347294c92904314ae84c
inline constexpr const char* kSigmoidWGSL = R"(
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

struct Params {
num_elements: u32,
}
@group(0) @binding(2) var<uniform> params: Params;

override wg_size: u32 = 64u;

@compute @workgroup_size(wg_size, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let idx = gid.x;
if (idx >= params.num_elements) {
return;
}
output[idx] = 1.0 / (1.0 + exp(-input[idx]));
}
)";

inline constexpr uint32_t kSigmoidWorkgroupSizeX = 64;
inline constexpr uint32_t kSigmoidWorkgroupSizeY = 1;
inline constexpr uint32_t kSigmoidWorkgroupSizeZ = 1;

} // namespace executorch::backends::webgpu
31 changes: 31 additions & 0 deletions backends/webgpu/runtime/ops/squeeze/Squeeze.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/webgpu/runtime/WebGPUGraph.h>
#include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h>
#include <executorch/backends/webgpu/runtime/ops/view_copy/view_copy.h>

#include <vector>

namespace executorch::backends::webgpu {

namespace {

// squeeze_copy.dims = numel-preserving flat copy (Vulkan Squeeze.cpp:102-104).
void squeeze_copy_dims_impl(WebGPUGraph& graph, const std::vector<int>& args) {
// args: [self, dims, out]; dims ignored (out shape fixed AOT).
add_flat_copy(graph, args.at(0), args.at(args.size() - 1));
}

} // namespace

WEBGPU_REGISTER_OPERATORS {
WEBGPU_REGISTER_OP(aten.squeeze_copy.dims, squeeze_copy_dims_impl);
}

} // namespace executorch::backends::webgpu
31 changes: 31 additions & 0 deletions backends/webgpu/runtime/ops/unsqueeze/Unsqueeze.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/webgpu/runtime/WebGPUGraph.h>
#include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h>
#include <executorch/backends/webgpu/runtime/ops/view_copy/view_copy.h>

#include <vector>

namespace executorch::backends::webgpu {

namespace {

// unsqueeze_copy = numel-preserving flat copy (Vulkan Unsqueeze.cpp:101-103).
void unsqueeze_copy_impl(WebGPUGraph& graph, const std::vector<int>& args) {
// args: [self, dim, out]; dim ignored (out shape fixed AOT, like view_copy).
add_flat_copy(graph, args.at(0), args.at(args.size() - 1));
}

} // namespace

WEBGPU_REGISTER_OPERATORS {
WEBGPU_REGISTER_OP(aten.unsqueeze_copy.default, unsqueeze_copy_impl);
}

} // namespace executorch::backends::webgpu
67 changes: 67 additions & 0 deletions backends/webgpu/test/op_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,21 @@
CONFIGS as _SELECT_CONFIGS,
SelectModule,
)
from executorch.backends.webgpu.test.ops.test_sigmoid import (
_det_input as _sigmoid_det_input,
N as _SIGMOID_N,
SigmoidModule,
)

from executorch.backends.webgpu.test.ops.test_squeeze import (
CONFIGS as _SQUEEZE_CONFIGS,
SqueezeModule,
)

from executorch.backends.webgpu.test.ops.test_unsqueeze import (
CONFIGS as _UNSQUEEZE_CONFIGS,
UnsqueezeModule,
)
from executorch.backends.webgpu.test.ops.test_view_copy import (
CONFIGS as _VIEW_CONFIGS,
ViewModule,
Expand Down Expand Up @@ -153,3 +168,55 @@ def _view_copy_suite() -> WebGPUTestSuite:
@register_op_test("select")
def _select_suite() -> WebGPUTestSuite:
return _fn_config_suite(SelectModule, _SELECT_CONFIGS)


def _sigmoid_full_range(_shape) -> torch.Tensor:
# Reuses the monolith's saturation-tail input (linspace(-12, 12)).
return _sigmoid_det_input()


@register_op_test("sigmoid")
def _sigmoid_suite() -> WebGPUTestSuite:
# sigmoid has no CONFIGS table; cover unary shapes directly (tol 1e-4).
return WebGPUTestSuite(
module_factory=lambda: SigmoidModule(),
cases=[
Case(name="vec", inputs=((M1,),)),
Case(name="mat", inputs=((M1, M2),)),
Case(name="rank3", inputs=((S1, M1, M2),)),
Case(name="rank4", inputs=((S1, S2, S2, M2),)),
# Saturation tails sigmoid(+-12) (~6e-6 / 0.999994) that randn shapes miss.
Case(
name="saturation",
inputs=(InputSpec(shape=(_SIGMOID_N,), gen=_sigmoid_full_range),),
),
],
atol=1e-4,
rtol=1e-4,
)


@register_op_test("squeeze")
def _squeeze_suite() -> WebGPUTestSuite:
# CONFIGS: name -> (shape, dim) where dim is an int or a tuple.
return WebGPUTestSuite(
module_factory=lambda dim: SqueezeModule(dim),
cases=[
Case(name=n, construct={"dim": dim}, inputs=(shape,))
for n, (shape, dim) in _SQUEEZE_CONFIGS.items()
],
golden_dtype="float32", # reshape copies values; fp64 bit-identical
)


@register_op_test("unsqueeze")
def _unsqueeze_suite() -> WebGPUTestSuite:
# CONFIGS: name -> (shape, dim).
return WebGPUTestSuite(
module_factory=lambda dim: UnsqueezeModule(dim),
cases=[
Case(name=n, construct={"dim": dim}, inputs=(shape,))
for n, (shape, dim) in _UNSQUEEZE_CONFIGS.items()
],
golden_dtype="float32", # reshape copies values; fp64 bit-identical
)
Loading
Loading