Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/webgpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ set(WEBGPU_SRCS
runtime/ops/slice/Slice.cpp
runtime/ops/permute/Permute.cpp
runtime/ops/cat/Cat.cpp
runtime/ops/index/Index.cpp
)

# Define the webgpu_backend library target from the sources listed in
# WEBGPU_SRCS.
add_library(webgpu_backend ${WEBGPU_SRCS})
Expand Down
188 changes: 188 additions & 0 deletions backends/webgpu/runtime/ops/index/Index.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/webgpu/runtime/WebGPUGraph.h>
#include <executorch/backends/webgpu/runtime/WebGPUUtils.h>
#include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h>
#include <executorch/backends/webgpu/runtime/ops/index/index_wgsl.h>

#include <webgpu/webgpu.h>

#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <vector>

namespace executorch::backends::webgpu {

namespace {

// Host-side image of the uniform block uploaded for the index shader.
// Only `numel` carries data; the trailing words exist purely so the struct
// occupies a full 16 bytes.
struct IndexParams {
  // Number of output elements to produce.
  uint32_t numel;
  // Unused filler that rounds the struct up to 16 bytes.
  uint32_t _pad[3];
};
// Enforce the size promised above at compile time rather than by comment.
static_assert(
    sizeof(IndexParams) == 16,
    "IndexParams must be exactly 16 bytes");

// aten.index.Tensor 1D-self gather out[i]=self[index[i]] (mirrors Vulkan).
//
// Validates the operands, uploads the element count as a uniform, builds the
// compute pipeline and bind group for the index shader, and records a single
// dispatch on `graph`.
//
// args: [self, indices (Tensor?[] -> ValueList), out]. The last entry is
// always taken as `out`.
//
// Throws std::runtime_error when:
//   - self/out are not tensors or the indices arg is not a ValueList;
//   - the list does not contain exactly one non-Null tensor;
//   - any buffer size is not a multiple of 4 bytes;
//   - out and index element counts differ;
//   - the element count does not fit in uint32;
//   - the uniform buffer cannot be created or mapped.
// Throws std::out_of_range (from std::vector::at) when `args` has fewer than
// two entries.
void index_impl(WebGPUGraph& graph, const std::vector<int>& args) {
  const int self_id = args.at(0);
  const int list_id = args.at(1);
  const int out_id = args.at(args.size() - 1);

  if (graph.get_value_type(self_id) != WebGPUGraph::ValueType::Tensor) {
    throw std::runtime_error("index: self arg is not a tensor");
  }
  if (graph.get_value_type(out_id) != WebGPUGraph::ValueType::Tensor) {
    throw std::runtime_error("index: out arg is not a tensor");
  }
  if (graph.get_value_type(list_id) != WebGPUGraph::ValueType::ValueList) {
    throw std::runtime_error("index: indices arg is not a ValueList");
  }

  // Exactly one non-Null index tensor (mirror Vulkan IndexTensor.cpp:67-69).
  const std::vector<int>& ids = graph.get_value_list(list_id);
  int index_id = -1;
  for (int id : ids) {
    if (graph.get_value_type(id) == WebGPUGraph::ValueType::Null) {
      continue;
    }
    if (graph.get_value_type(id) != WebGPUGraph::ValueType::Tensor) {
      throw std::runtime_error("index: index list element is not a tensor");
    }
    if (index_id != -1) {
      throw std::runtime_error("index: expected exactly one index tensor");
    }
    index_id = id;
  }
  if (index_id == -1) {
    throw std::runtime_error("index: no index tensor provided");
  }

  WGPUDevice device = graph.device();

  const auto& self_tensor = graph.get_tensor(self_id);
  const auto& index_tensor = graph.get_tensor(index_id);
  const auto& out_tensor = graph.get_tensor(out_id);

  // Element counts are derived from byte sizes, so a size that is not a
  // multiple of 4 cannot be an fp32 (or int32) buffer.
  const size_t out_numel = out_tensor.nbytes / sizeof(float);
  if (out_tensor.nbytes != out_numel * sizeof(float) ||
      self_tensor.nbytes % sizeof(float) != 0) {
    throw std::runtime_error("index: non-fp32 self/out (nbytes != numel * 4)");
  }
  // Index is the int32 downcast of the int64 advanced index (downcast_64_bit).
  const size_t index_numel = index_tensor.nbytes / sizeof(int32_t);
  if (index_tensor.nbytes != index_numel * sizeof(int32_t)) {
    throw std::runtime_error("index: index buffer is not int32 (nbytes % 4)");
  }
  // out is one self element per index element (row_width == 1, 1D self).
  if (out_numel != index_numel) {
    throw std::runtime_error("index: out numel != index numel");
  }
  // The uniform carries the count as a u32; reject sizes that would be
  // silently truncated by the narrowing cast below.
  if (static_cast<uint64_t>(out_numel) > UINT32_MAX) {
    throw std::runtime_error("index: out numel exceeds uint32 range");
  }

  const uint32_t num_elements = static_cast<uint32_t>(out_numel);
  const uint32_t wg_size =
      utils::clamp_workgroup_size(device, kIndexWorkgroupSizeX);
  const uint32_t workgroup_count =
      utils::compute_1d_workgroup_count(device, num_elements, wg_size, "index");

  // Override the shader's `wg_size` constant with the clamped value.
  WGPUConstantEntry wg_size_constant = {};
  wg_size_constant.key = {"wg_size", WGPU_STRLEN};
  wg_size_constant.value = static_cast<double>(wg_size);

  IndexParams params = {};
  params.numel = num_elements;

  // Create the uniform buffer already mapped so the params can be written
  // directly, without a separate queue write.
  WGPUBufferDescriptor uniform_desc = {};
  uniform_desc.size = sizeof(IndexParams);
  uniform_desc.usage = WGPUBufferUsage_Uniform | WGPUBufferUsage_CopyDst;
  uniform_desc.mappedAtCreation = true;
  WGPUBuffer uniform_buffer = wgpuDeviceCreateBuffer(device, &uniform_desc);
  if (uniform_buffer == nullptr) {
    throw std::runtime_error("index: failed to create uniform buffer");
  }
  void* mapped =
      wgpuBufferGetMappedRange(uniform_buffer, 0, sizeof(IndexParams));
  if (mapped == nullptr) {
    // Do not leak the buffer handle on the failure path.
    wgpuBufferRelease(uniform_buffer);
    throw std::runtime_error("index: failed to map uniform buffer");
  }
  std::memcpy(mapped, &params, sizeof(IndexParams));
  wgpuBufferUnmap(uniform_buffer);
  graph.add_uniform_buffer_bytes(sizeof(IndexParams));

  WGPUShaderSourceWGSL wgsl_desc = {};
  wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL;
  wgsl_desc.code = {kIndexWGSL, WGPU_STRLEN};
  WGPUShaderModuleDescriptor shader_desc = {};
  shader_desc.nextInChain = &wgsl_desc.chain;
  WGPUShaderModule shader = wgpuDeviceCreateShaderModule(device, &shader_desc);

  // self (read), out (read_write), index (read i32), params (uniform).
  WGPUBindGroupLayoutEntry entries[4] = {};
  entries[0].binding = 0;
  entries[0].visibility = WGPUShaderStage_Compute;
  entries[0].buffer.type = WGPUBufferBindingType_ReadOnlyStorage;
  entries[1].binding = 1;
  entries[1].visibility = WGPUShaderStage_Compute;
  entries[1].buffer.type = WGPUBufferBindingType_Storage;
  entries[2].binding = 2;
  entries[2].visibility = WGPUShaderStage_Compute;
  entries[2].buffer.type = WGPUBufferBindingType_ReadOnlyStorage;
  entries[3].binding = 3;
  entries[3].visibility = WGPUShaderStage_Compute;
  entries[3].buffer.type = WGPUBufferBindingType_Uniform;

  WGPUBindGroupLayoutDescriptor bgl_desc = {};
  bgl_desc.entryCount = 4;
  bgl_desc.entries = entries;
  WGPUBindGroupLayout bgl = wgpuDeviceCreateBindGroupLayout(device, &bgl_desc);

  WGPUPipelineLayoutDescriptor pl_desc = {};
  pl_desc.bindGroupLayoutCount = 1;
  pl_desc.bindGroupLayouts = &bgl;
  WGPUPipelineLayout pipeline_layout =
      wgpuDeviceCreatePipelineLayout(device, &pl_desc);

  WGPUComputePipelineDescriptor pipeline_desc = {};
  pipeline_desc.layout = pipeline_layout;
  pipeline_desc.compute.module = shader;
  pipeline_desc.compute.entryPoint = {"main", WGPU_STRLEN};
  pipeline_desc.compute.constantCount = 1;
  pipeline_desc.compute.constants = &wg_size_constant;
  WGPUComputePipeline pipeline =
      wgpuDeviceCreateComputePipeline(device, &pipeline_desc);

  // Bind each buffer over its full byte size, in the same slot order as the
  // layout entries above.
  WGPUBindGroupEntry bg_entries[4] = {};
  bg_entries[0].binding = 0;
  bg_entries[0].buffer = self_tensor.buffer;
  bg_entries[0].size = self_tensor.nbytes;
  bg_entries[1].binding = 1;
  bg_entries[1].buffer = out_tensor.buffer;
  bg_entries[1].size = out_tensor.nbytes;
  bg_entries[2].binding = 2;
  bg_entries[2].buffer = index_tensor.buffer;
  bg_entries[2].size = index_tensor.nbytes;
  bg_entries[3].binding = 3;
  bg_entries[3].buffer = uniform_buffer;
  bg_entries[3].size = sizeof(IndexParams);

  WGPUBindGroupDescriptor bg_desc = {};
  bg_desc.layout = bgl;
  bg_desc.entryCount = 4;
  bg_desc.entries = bg_entries;
  WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);

  graph.add_dispatch({pipeline, bind_group, workgroup_count});

  // Drop the local references to the intermediate objects; the pipeline and
  // bind group were handed to the graph above.
  wgpuShaderModuleRelease(shader);
  wgpuBindGroupLayoutRelease(bgl);
  wgpuPipelineLayoutRelease(pipeline_layout);
  // The bind group keeps the uniform buffer alive until release.
  wgpuBufferRelease(uniform_buffer);
}

} // namespace

// Registers index_impl as the handler for the aten.index.Tensor operator.
WEBGPU_REGISTER_OPERATORS {
WEBGPU_REGISTER_OP(aten.index.Tensor, index_impl);
}

} // namespace executorch::backends::webgpu
22 changes: 22 additions & 0 deletions backends/webgpu/runtime/ops/index/index.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Gather kernel: output[i] = input[index[i]] for every i below params.numel.
// binding 0: source values, read-only.
@group(0) @binding(0) var<storage, read> input: array<f32>;
// binding 1: destination values, one per index element.
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
// binding 2: i32 positions into `input`, read-only.
@group(0) @binding(2) var<storage, read> index: array<i32>;

// Uniform parameters; numel is the number of output elements to write.
struct Params {
numel: u32,
}
@group(0) @binding(3) var<uniform> params: Params;

// Workgroup width; pipeline-overridable, defaults to 64.
override wg_size: u32 = 64;

@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
// One invocation per output element; invocations at or past numel do nothing.
let out_bufi = gid.x;
if (out_bufi >= params.numel) {
return;
}

// 1D-self gather out[i]=self[index[i]] (mirrors Vulkan index_tensor_buffer.glsl).
// NOTE(review): u32(i) reinterprets a negative i as a large unsigned value, so
// negative indices are not wrapped here — confirm they are normalized upstream.
let i = index[out_bufi];
output[out_bufi] = input[u32(i)];
}
46 changes: 46 additions & 0 deletions backends/webgpu/runtime/ops/index/index_wgsl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <cstdint>

namespace executorch::backends::webgpu {

// @generated from index.wgsl - DO NOT EDIT.
// wgsl-sha256: daed48e60bfcf2b7420d277576d794137d3bff383aef4f68464c98c8a7235c8e
// WGSL source text of the index gather compute shader, embedded as a
// NUL-terminated string.
inline constexpr const char* kIndexWGSL = R"(
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
@group(0) @binding(2) var<storage, read> index: array<i32>;

struct Params {
numel: u32,
}
@group(0) @binding(3) var<uniform> params: Params;

override wg_size: u32 = 64;

@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let out_bufi = gid.x;
if (out_bufi >= params.numel) {
return;
}

// 1D-self gather out[i]=self[index[i]] (mirrors Vulkan index_tensor_buffer.glsl).
let i = index[out_bufi];
output[out_bufi] = input[u32(i)];
}
)";

// Workgroup dimensions for the index shader. X equals the shader's default
// `wg_size` of 64; Y and Z are 1.
inline constexpr uint32_t kIndexWorkgroupSizeX{64};
inline constexpr uint32_t kIndexWorkgroupSizeY{1};
inline constexpr uint32_t kIndexWorkgroupSizeZ{1};

} // namespace executorch::backends::webgpu
Loading