1 change: 1 addition & 0 deletions backends/webgpu/CMakeLists.txt
@@ -39,6 +39,7 @@ set(WEBGPU_SRCS
    runtime/ops/select_as_symint/SelectAsSymint.cpp
    runtime/ops/quantized_linear/QuantizedLinear.cpp
    runtime/ops/mul/BinaryOp.cpp
    runtime/ops/sigmoid/Sigmoid.cpp
    runtime/ops/embedding_q4gsw/EmbeddingQ4gsw.cpp
    runtime/ops/rope/RotaryEmbedding.cpp
    runtime/ops/prepack/Prepack.cpp
20 changes: 8 additions & 12 deletions backends/webgpu/README.md
@@ -2,7 +2,7 @@

Run ExecuTorch models on the GPU via [WebGPU](https://www.w3.org/TR/webgpu/). The backend compiles delegated subgraphs into WGSL compute shaders executed natively through [Dawn](https://dawn.googlesource.com/dawn), whose Tint compiler is the reference WGSL implementation (Metal on macOS, Vulkan on Linux/Windows).

> **Status: Prototype, under active development.** The backend runs the core of transformer inference today — `add`, `rms_norm`, fused scaled-dot-product attention with KV cache, and 4-bit weight-only quantized linear — plus quantized embedding, rotary embedding, and constant prepacking. See [Progress](#progress) for shipped milestones.
> **Status: Prototype, under active development.** The backend runs the core of transformer inference today — `add`, `mul`, `sigmoid`, `rms_norm`, fused scaled-dot-product attention with KV cache, and 4-bit weight-only quantized linear — plus quantized embedding, rotary embedding, and constant prepacking. See [Progress](#progress) for shipped milestones.

## Progress

@@ -20,14 +20,7 @@ Milestones landed on `main`:
| 2026-06 | Added the attention core of transformer inference — fused scaled-dot-product attention (`sdpa_with_kv_cache`) with an `update_cache` operator for autoregressive decode | [#20086](https://github.com/pytorch/executorch/pull/20086), [#20087](https://github.com/pytorch/executorch/pull/20087) |
| 2026-06 | Added on-GPU kernel timing via WebGPU timestamp queries, for true GPU-side profiling | [#20201](https://github.com/pytorch/executorch/pull/20201) |
| 2026-06 | Added the dominant compute in quantized LLMs — 4-bit weight-only quantized linear (`linear_q4gsw`), a dequantize-and-matmul kernel | [#20226](https://github.com/pytorch/executorch/pull/20226), [#20227](https://github.com/pytorch/executorch/pull/20227) |

In review:

| Milestone | Pull Request |
|---|---|
| Adds 4-bit quantized embedding (`embedding_q4gsw`) | [#20263](https://github.com/pytorch/executorch/pull/20263) |
| Adds rotary position embedding / RoPE (`apply_rotary_emb`) | [#20264](https://github.com/pytorch/executorch/pull/20264) |
| Adds constant prepacking (`prepack`) for end-to-end model weight handling | [#20265](https://github.com/pytorch/executorch/pull/20265) |
| 2026-06 | Added token embedding, rotary position embedding, and constant prepacking for end-to-end model weight handling | [#20414](https://github.com/pytorch/executorch/pull/20414) |

## Architecture

@@ -61,14 +54,17 @@ Key design choices:
| Operator | WGSL Shader | Notes |
|---|---|---|
| `aten.add.Tensor` | `binary_add.wgsl` | Element-wise with alpha: `out = in1 + alpha * in2` |
| `aten.mul.Tensor` | `binary_mul.wgsl` | Element-wise multiply with broadcasting |
| `aten.sigmoid.default` | `sigmoid.wgsl` | Element-wise sigmoid activation |
| `et_vk.rms_norm.default` | `rms_norm.wgsl` | Root-mean-square normalization |
| `sdpa_with_kv_cache.default` | `sdpa_compute_attn_weights.wgsl`, `sdpa_softmax.wgsl`, `sdpa_compute_out.wgsl` | Fused scaled-dot-product attention (QK / softmax / AV) with KV cache |
| `llama.update_cache.default` | `update_cache.wgsl` | In-place KV cache update for autoregressive decode |
| `et_vk.linear_q4gsw.default` | `q4gsw_linear.wgsl` | 4-bit weight-only quantized linear (dequantize + matmul) |
| `et_vk.embedding_q4gsw.default` | `embedding_q4gsw.wgsl` | 4-bit groupwise-symmetric quantized embedding |
| `et_vk.apply_rotary_emb.default` | `rotary_embedding.wgsl` | Interleaved rotary positional embedding |
| `et_vk.prepack.default` | N/A | Constant materialization into GPU buffers |

**In review:** quantized embedding (`embedding_q4gsw`), rotary embedding (`apply_rotary_emb`), and constant prepacking (`prepack`).

**Planned:** `mul`, `sigmoid`, shape ops (`view`, `permute`, `slice`, `select`, `cat`, `squeeze`/`unsqueeze`), and `index` — the remaining ops needed for end-to-end Llama 3.2 1B.
**Planned:** shape ops (`view`, `permute`, `slice`, `select`, `cat`, `squeeze`/`unsqueeze`) and `index` — the remaining ops needed for end-to-end Llama 3.2 1B.

## Quick Start

7 changes: 5 additions & 2 deletions backends/webgpu/TODO.md
@@ -1,7 +1,10 @@
# WebGPU Backend — TODO

## Current State (Prototype)
- Single op: `aten.add.Tensor` (fp32, buffer storage)
- Runtime support for transformer-oriented fp32 and LLM custom ops, including
  `aten.add.Tensor`, `aten.mul.Tensor`, `aten.sigmoid.default`,
  `et_vk.rms_norm.default`,
  fused SDPA with KV cache, 4-bit quantized linear/embedding, RoPE, and prepack.
- No Python AOT code — directly consumes Vulkan delegate (.pte exported via VulkanPartitioner)
- Reuses Vulkan FlatBuffer format (VH00 header + VK00 payload)
- Registers as `"VulkanBackend"` at runtime — mutually exclusive with Vulkan backend at link time
@@ -30,7 +33,7 @@ element-wise ops (add→relu→mul→clamp) at compile time. Embed via the exist
`shaders: [VkBytes]` field in schema.fbs.

## Next Steps
1. **More ops**: sub, mul, relu, linear (matmul), softmax, layer_norm
1. **More ops**: sub, relu, linear (matmul), softmax, layer_norm, shape ops
2. **fp16 support**: Feature-detect `shader-f16`, fallback to fp32
3. **Buffer pooling**: Reuse GPU buffers to avoid OOM at scale
4. **Pipeline caching**: Cache compiled pipelines across runs
137 changes: 137 additions & 0 deletions backends/webgpu/runtime/ops/sigmoid/Sigmoid.cpp
@@ -0,0 +1,137 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/webgpu/runtime/WebGPUGraph.h>
#include <executorch/backends/webgpu/runtime/WebGPUUtils.h>
#include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h>
#include <executorch/backends/webgpu/runtime/ops/TensorMeta.h>
#include <executorch/backends/webgpu/runtime/ops/sigmoid/sigmoid_wgsl.h>

#include <webgpu/webgpu.h>

#include <stdexcept>
#include <vector>

namespace executorch::backends::webgpu {

namespace {

void sigmoid_impl(WebGPUGraph& graph, const std::vector<int>& args) {
  // aten.sigmoid.default args: [in, out]
  const int in_id = args.at(0);
  const int out_id = args.at(1);

  WGPUDevice device = graph.device();

  const auto& in_tensor = graph.get_tensor(in_id);
  const auto& out_tensor = graph.get_tensor(out_id);

  // Element-wise op: input and output must have identical shapes.
  if (in_tensor.dims != out_tensor.dims) {
    throw std::runtime_error("sigmoid: input and output shapes must match");
  }

  TensorMeta out_meta;
  fill_tensor_meta(out_tensor, &out_meta);

  // The shader reads and writes array<f32>, so both buffers must hold
  // exactly numel 4-byte elements.
  if (out_tensor.nbytes !=
          static_cast<size_t>(out_meta.numel) * sizeof(float) ||
      in_tensor.nbytes != static_cast<size_t>(out_meta.numel) * sizeof(float)) {
    throw std::runtime_error("sigmoid: non-fp32 operand (nbytes != numel * 4)");
  }

  uint32_t wg_size =
      utils::clamp_workgroup_size(device, kSigmoidWorkgroupSizeX);
  uint32_t workgroup_count = utils::compute_1d_workgroup_count(
      device, out_meta.numel, wg_size, "sigmoid");

  // Passed to the shader's `override wg_size` pipeline constant.
  WGPUConstantEntry wg_size_constant = {};
  wg_size_constant.key = {"wg_size", WGPU_STRLEN};
  wg_size_constant.value = static_cast<double>(wg_size);

  WGPUBuffer out_meta_buf =
      utils::make_uniform(device, &out_meta, sizeof(TensorMeta));
  graph.add_uniform_buffer_bytes(sizeof(TensorMeta));

  WGPUShaderSourceWGSL wgsl_desc = {};
  wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL;
  wgsl_desc.code = {kSigmoidWGSL, WGPU_STRLEN};

  WGPUShaderModuleDescriptor shader_desc = {};
  shader_desc.nextInChain = &wgsl_desc.chain;
  WGPUShaderModule shader = wgpuDeviceCreateShaderModule(device, &shader_desc);

  // Bindings mirror sigmoid.wgsl: 0 = input (read-only storage),
  // 1 = output (storage), 2 = TensorMeta (uniform).
  WGPUBindGroupLayoutEntry entries[3] = {};

  entries[0].binding = 0;
  entries[0].visibility = WGPUShaderStage_Compute;
  entries[0].buffer.type = WGPUBufferBindingType_ReadOnlyStorage;

  entries[1].binding = 1;
  entries[1].visibility = WGPUShaderStage_Compute;
  entries[1].buffer.type = WGPUBufferBindingType_Storage;

  entries[2].binding = 2;
  entries[2].visibility = WGPUShaderStage_Compute;
  entries[2].buffer.type = WGPUBufferBindingType_Uniform;

  WGPUBindGroupLayoutDescriptor bgl_desc = {};
  bgl_desc.entryCount = 3;
  bgl_desc.entries = entries;
  WGPUBindGroupLayout bgl = wgpuDeviceCreateBindGroupLayout(device, &bgl_desc);

  WGPUPipelineLayoutDescriptor pl_desc = {};
  pl_desc.bindGroupLayoutCount = 1;
  pl_desc.bindGroupLayouts = &bgl;
  WGPUPipelineLayout pipeline_layout =
      wgpuDeviceCreatePipelineLayout(device, &pl_desc);

  WGPUComputePipelineDescriptor pipeline_desc = {};
  pipeline_desc.layout = pipeline_layout;
  pipeline_desc.compute.module = shader;
  pipeline_desc.compute.entryPoint = {"main", WGPU_STRLEN};
  pipeline_desc.compute.constantCount = 1;
  pipeline_desc.compute.constants = &wg_size_constant;
  WGPUComputePipeline pipeline =
      wgpuDeviceCreateComputePipeline(device, &pipeline_desc);

  WGPUBindGroupEntry bg_entries[3] = {};

  bg_entries[0].binding = 0;
  bg_entries[0].buffer = in_tensor.buffer;
  bg_entries[0].size = in_tensor.nbytes;

  bg_entries[1].binding = 1;
  bg_entries[1].buffer = out_tensor.buffer;
  bg_entries[1].size = out_tensor.nbytes;

  bg_entries[2].binding = 2;
  bg_entries[2].buffer = out_meta_buf;
  bg_entries[2].size = sizeof(TensorMeta);

  WGPUBindGroupDescriptor bg_desc = {};
  bg_desc.layout = bgl;
  bg_desc.entryCount = 3;
  bg_desc.entries = bg_entries;
  WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);

  graph.add_dispatch({pipeline, bind_group, workgroup_count});

  wgpuShaderModuleRelease(shader);
  wgpuBindGroupLayoutRelease(bgl);
  wgpuPipelineLayoutRelease(pipeline_layout);
  // Drop our ref; the bind group keeps the uniform alive until release.
  wgpuBufferRelease(out_meta_buf);
}

} // namespace

WEBGPU_REGISTER_OPERATORS {
WEBGPU_REGISTER_OP(aten.sigmoid.default, sigmoid_impl);
}

} // namespace executorch::backends::webgpu
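The host code above uploads `sizeof(TensorMeta)` bytes into a uniform buffer that the shader reads as the WGSL `TensorMeta` struct, so the two layouts have to agree byte for byte. The real C++ definition lives in `TensorMeta.h`, which is not part of this diff. As a rough sketch only, the hypothetical `TensorMetaMirror` below shows the layout that WGSL's alignment rules imply for that struct: `vec4<u32>` is 16-byte aligned, so 8 bytes of padding follow `numel`.

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical host-side mirror of the WGSL `TensorMeta` struct
// (illustration only; not the definition from TensorMeta.h).
struct alignas(16) TensorMetaMirror {
  uint32_t ndim;       // offset 0
  uint32_t numel;      // offset 4
  uint32_t pad_[2];    // offsets 8..15: vec4<u32> must start on a 16-byte boundary
  uint32_t sizes[4];   // offset 16, matches sizes: vec4<u32>
  uint32_t strides[4]; // offset 32, matches strides: vec4<u32>
};

// Compile-time checks that the host layout matches what the shader expects.
static_assert(offsetof(TensorMetaMirror, numel) == 4, "numel at byte 4");
static_assert(offsetof(TensorMetaMirror, sizes) == 16, "sizes at byte 16");
static_assert(offsetof(TensorMetaMirror, strides) == 32, "strides at byte 32");
static_assert(sizeof(TensorMetaMirror) == 48, "struct is 48 bytes");
```

Checks like these fail at compile time if a field is added on one side only, which is cheaper to debug than a shader silently reading the wrong `numel`.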
21 changes: 21 additions & 0 deletions backends/webgpu/runtime/ops/sigmoid/sigmoid.wgsl
@@ -0,0 +1,21 @@
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

struct TensorMeta {
  ndim: u32,
  numel: u32,
  sizes: vec4<u32>,
  strides: vec4<u32>,
}
@group(0) @binding(2) var<uniform> out_meta: TensorMeta;

override wg_size: u32 = 64u;

@compute @workgroup_size(wg_size, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  if (idx >= out_meta.numel) {
    return;
  }
  output[idx] = 1.0 / (1.0 + exp(-input[idx]));
}
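The shader computes `1 / (1 + exp(-x))` per element and skips invocations whose index is at or past `numel`. For checking GPU output, a CPU reference applying the same formula can be handy. The sketch below is not part of this PR; `sigmoid_reference` and `all_close` are hypothetical helper names, and the tolerance is left to the caller.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference: same formula as the shader body, applied to
// every element (the loop bound plays the role of the `idx >= numel` guard).
inline std::vector<float> sigmoid_reference(const std::vector<float>& input) {
  std::vector<float> output(input.size());
  for (std::size_t idx = 0; idx < input.size(); ++idx) {
    output[idx] = 1.0f / (1.0f + std::exp(-input[idx]));
  }
  return output;
}

// Hypothetical comparison helper: true when every pair of elements differs
// by at most `atol` in absolute value.
inline bool all_close(
    const std::vector<float>& a,
    const std::vector<float>& b,
    float atol) {
  assert(a.size() == b.size());
  for (std::size_t i = 0; i < a.size(); ++i) {
    if (std::fabs(a[i] - b[i]) > atol) {
      return false;
    }
  }
  return true;
}
```

A test could read the output buffer back from the GPU, then call `all_close(gpu_values, sigmoid_reference(inputs), atol)` with a tolerance suited to the device's `exp` precision.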
45 changes: 45 additions & 0 deletions backends/webgpu/runtime/ops/sigmoid/sigmoid_wgsl.h
@@ -0,0 +1,45 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <cstdint>

namespace executorch::backends::webgpu {

// @generated from sigmoid.wgsl - DO NOT EDIT.
// wgsl-sha256: 73a26ddce78d1cbd6cbb0c586791b338153cea9af13790dc1400516128a4c278
inline constexpr const char* kSigmoidWGSL = R"(
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

struct TensorMeta {
  ndim: u32,
  numel: u32,
  sizes: vec4<u32>,
  strides: vec4<u32>,
}
@group(0) @binding(2) var<uniform> out_meta: TensorMeta;

override wg_size: u32 = 64u;

@compute @workgroup_size(wg_size, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  if (idx >= out_meta.numel) {
    return;
  }
  output[idx] = 1.0 / (1.0 + exp(-input[idx]));
}
)";

inline constexpr uint32_t kSigmoidWorkgroupSizeX = 64;
inline constexpr uint32_t kSigmoidWorkgroupSizeY = 1;
inline constexpr uint32_t kSigmoidWorkgroupSizeZ = 1;

} // namespace executorch::backends::webgpu
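`Sigmoid.cpp` feeds `kSigmoidWorkgroupSizeX` through `utils::clamp_workgroup_size` and `utils::compute_1d_workgroup_count`, neither of which appears in this diff. One plausible reading is a ceiling division checked against the device's per-dimension dispatch limit. The sketch below works under that assumption: `workgroups_1d` and `max_groups` are hypothetical names, and 65535 is WebGPU's default `maxComputeWorkgroupsPerDimension`, not necessarily the limit the real helper queries.

```cpp
#include <cstdint>
#include <stdexcept>

// Hypothetical helper: number of workgroups needed so that
// count * wg_size >= numel in a 1D dispatch.
constexpr uint32_t workgroups_1d(
    uint32_t numel,
    uint32_t wg_size,
    uint32_t max_groups = 65535u) {
  if (wg_size == 0) {
    throw std::invalid_argument("wg_size must be nonzero");
  }
  // Ceiling division written to avoid overflow in (numel + wg_size - 1).
  const uint32_t count = numel / wg_size + (numel % wg_size != 0 ? 1u : 0u);
  if (count > max_groups) {
    throw std::runtime_error("dispatch exceeds per-dimension workgroup limit");
  }
  return count;
}
```

With a workgroup size of 64, a 1000-element tensor needs 16 workgroups (1024 invocations); the 24 surplus invocations are the ones the shader's `idx >= out_meta.numel` guard turns into no-ops.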
26 changes: 26 additions & 0 deletions backends/webgpu/test/TARGETS
@@ -17,6 +17,32 @@ python_unittest(
    ],
)

python_unittest(
    name = "test_mul",
    srcs = [
        "ops/mul/test_mul.py",
    ],
    deps = [
        "//caffe2:torch",
        "//executorch/backends/vulkan/partitioner:vulkan_partitioner",
        "//executorch/backends/vulkan:vulkan_preprocess",
        "//executorch/exir:lib",
    ],
)

python_unittest(
    name = "test_sigmoid",
    srcs = [
        "ops/sigmoid/test_sigmoid.py",
    ],
    deps = [
        "//caffe2:torch",
        "//executorch/backends/vulkan/partitioner:vulkan_partitioner",
        "//executorch/backends/vulkan:vulkan_preprocess",
        "//executorch/exir:lib",
    ],
)

runtime.python_library(
    name = "tester",
    srcs = ["tester.py"],