Skip to content

Commit a47bb14

Browse files
[ExecuTorch][WebGPU] Add aten.index.Tensor (1D-self gather)
Pull Request resolved: #20464

Adds the WebGPU delegate handler for aten.index.Tensor, the 1D-self advanced-index gather out[i] = self[index[i]] (output shape == index shape). This is the form the VulkanPartitioner delegates -- it requires a 1D self and exactly one non-None index (op_registry.py); 2D mask/freqs gathers stay on CPU. It mirrors the Vulkan delegate's index_tensor op (IndexTensor.cpp + index_tensor_buffer.glsl) as a single compute dispatch over the output elements, each invocation reading the int32 index and gathering the corresponding fp32 self element.

The op is composed as:
- index.wgsl: one pass with one invocation per output element, out[i] = self[u32(index[i])], guarded by a numel bound; buffer-only, fp32 self/out, int32 index, 1D dispatch via the shared WebGPUUtils helpers (clamp workgroup size + 1D count).
- Index.cpp: validates the args (self/out tensors; indices ValueList with exactly one index tensor; fp32 self/out; int32 index; out numel == index numel), failing loud on any violation, then records the dispatch. row_width is dropped (always 1 for 1D self).

ghstack-source-id: 397756251
@exported-using-ghexport
@diff-train-skip-merge
Differential Revision: [D109478967](https://our.internmc.facebook.com/intern/diff/D109478967/)
1 parent 5168df7 commit a47bb14

4 files changed

Lines changed: 258 additions & 0 deletions

File tree

backends/webgpu/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ set(WEBGPU_SRCS
5050
runtime/ops/slice/Slice.cpp
5151
runtime/ops/permute/Permute.cpp
5252
runtime/ops/cat/Cat.cpp
53+
runtime/ops/index/Index.cpp
5354
)
5455

5556
add_library(webgpu_backend ${WEBGPU_SRCS})
Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/webgpu/runtime/WebGPUGraph.h>
10+
#include <executorch/backends/webgpu/runtime/WebGPUUtils.h>
11+
#include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h>
12+
#include <executorch/backends/webgpu/runtime/ops/index/index_wgsl.h>
13+
14+
#include <webgpu/webgpu.h>
15+
16+
#include <cstdint>
17+
#include <stdexcept>
18+
#include <vector>
19+
20+
namespace executorch::backends::webgpu {
21+
22+
namespace {
23+
24+
// Host-side image of the uniform block bound at binding 3 of the index
// shader (read there as `params`).
struct IndexParams {
  // Number of output elements; the shader returns early for invocations whose
  // global id is >= this value.
  uint32_t numel;
  uint32_t _pad[3]; // pad to 16 bytes
};

// The uniform buffer is created and bound with sizeof(IndexParams); enforce
// the 16-byte layout at compile time instead of relying on the comment alone.
static_assert(
    sizeof(IndexParams) == 16,
    "IndexParams must be exactly 16 bytes (uniform buffer size)");
28+
29+
// aten.index.Tensor 1D-self gather out[i]=self[index[i]] (mirrors Vulkan).
30+
void index_impl(WebGPUGraph& graph, const std::vector<int>& args) {
31+
// args: [self, indices (Tensor?[] -> ValueList), out].
32+
const int self_id = args.at(0);
33+
const int list_id = args.at(1);
34+
const int out_id = args.at(args.size() - 1);
35+
36+
if (graph.get_value_type(self_id) != WebGPUGraph::ValueType::Tensor) {
37+
throw std::runtime_error("index: self arg is not a tensor");
38+
}
39+
if (graph.get_value_type(out_id) != WebGPUGraph::ValueType::Tensor) {
40+
throw std::runtime_error("index: out arg is not a tensor");
41+
}
42+
if (graph.get_value_type(list_id) != WebGPUGraph::ValueType::ValueList) {
43+
throw std::runtime_error("index: indices arg is not a ValueList");
44+
}
45+
46+
// Exactly one non-Null index tensor (mirror Vulkan IndexTensor.cpp:67-69).
47+
const std::vector<int>& ids = graph.get_value_list(list_id);
48+
int index_id = -1;
49+
for (int id : ids) {
50+
if (graph.get_value_type(id) == WebGPUGraph::ValueType::Null) {
51+
continue;
52+
}
53+
if (graph.get_value_type(id) != WebGPUGraph::ValueType::Tensor) {
54+
throw std::runtime_error("index: index list element is not a tensor");
55+
}
56+
if (index_id != -1) {
57+
throw std::runtime_error("index: expected exactly one index tensor");
58+
}
59+
index_id = id;
60+
}
61+
if (index_id == -1) {
62+
throw std::runtime_error("index: no index tensor provided");
63+
}
64+
65+
WGPUDevice device = graph.device();
66+
67+
const auto& self_tensor = graph.get_tensor(self_id);
68+
const auto& index_tensor = graph.get_tensor(index_id);
69+
const auto& out_tensor = graph.get_tensor(out_id);
70+
71+
if (self_tensor.buffer == nullptr || index_tensor.buffer == nullptr ||
72+
out_tensor.buffer == nullptr) {
73+
throw std::runtime_error("index: null buffer binding");
74+
}
75+
// 1D-self gather: the kernel flat-indexes self by a scalar; fail loud on a
76+
// higher-rank self (mirrors Vulkan index_tensor_buffer's 1D-self contract).
77+
if (self_tensor.dims.size() != 1) {
78+
throw std::runtime_error("index: only 1D self is supported");
79+
}
80+
81+
const size_t out_numel = out_tensor.nbytes / sizeof(float);
82+
if (out_tensor.nbytes != out_numel * sizeof(float) ||
83+
self_tensor.nbytes % sizeof(float) != 0) {
84+
throw std::runtime_error("index: non-fp32 self/out (nbytes != numel * 4)");
85+
}
86+
// Index is the int32 downcast of the int64 advanced index (downcast_64_bit).
87+
const size_t index_numel = index_tensor.nbytes / sizeof(int32_t);
88+
if (index_tensor.nbytes != index_numel * sizeof(int32_t)) {
89+
throw std::runtime_error("index: index buffer is not int32 (nbytes % 4)");
90+
}
91+
// out is one self element per index element (row_width == 1, 1D self).
92+
if (out_numel != index_numel) {
93+
throw std::runtime_error("index: out numel != index numel");
94+
}
95+
96+
uint32_t num_elements = static_cast<uint32_t>(out_numel);
97+
uint32_t wg_size = utils::clamp_workgroup_size(device, kIndexWorkgroupSizeX);
98+
uint32_t workgroup_count =
99+
utils::compute_1d_workgroup_count(device, num_elements, wg_size, "index");
100+
101+
WGPUConstantEntry wg_size_constant = {};
102+
wg_size_constant.key = {"wg_size", WGPU_STRLEN};
103+
wg_size_constant.value = static_cast<double>(wg_size);
104+
105+
IndexParams params = {};
106+
params.numel = num_elements;
107+
108+
WGPUBuffer uniform_buffer =
109+
utils::make_uniform(device, &params, sizeof(IndexParams));
110+
graph.add_uniform_buffer_bytes(sizeof(IndexParams));
111+
112+
WGPUShaderSourceWGSL wgsl_desc = {};
113+
wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL;
114+
wgsl_desc.code = {kIndexWGSL, WGPU_STRLEN};
115+
WGPUShaderModuleDescriptor shader_desc = {};
116+
shader_desc.nextInChain = &wgsl_desc.chain;
117+
WGPUShaderModule shader = wgpuDeviceCreateShaderModule(device, &shader_desc);
118+
119+
// self (read), out (read_write), index (read i32), params (uniform).
120+
WGPUBindGroupLayoutEntry entries[4] = {};
121+
entries[0].binding = 0;
122+
entries[0].visibility = WGPUShaderStage_Compute;
123+
entries[0].buffer.type = WGPUBufferBindingType_ReadOnlyStorage;
124+
entries[1].binding = 1;
125+
entries[1].visibility = WGPUShaderStage_Compute;
126+
entries[1].buffer.type = WGPUBufferBindingType_Storage;
127+
entries[2].binding = 2;
128+
entries[2].visibility = WGPUShaderStage_Compute;
129+
entries[2].buffer.type = WGPUBufferBindingType_ReadOnlyStorage;
130+
entries[3].binding = 3;
131+
entries[3].visibility = WGPUShaderStage_Compute;
132+
entries[3].buffer.type = WGPUBufferBindingType_Uniform;
133+
134+
WGPUBindGroupLayoutDescriptor bgl_desc = {};
135+
bgl_desc.entryCount = 4;
136+
bgl_desc.entries = entries;
137+
WGPUBindGroupLayout bgl = wgpuDeviceCreateBindGroupLayout(device, &bgl_desc);
138+
139+
WGPUPipelineLayoutDescriptor pl_desc = {};
140+
pl_desc.bindGroupLayoutCount = 1;
141+
pl_desc.bindGroupLayouts = &bgl;
142+
WGPUPipelineLayout pipeline_layout =
143+
wgpuDeviceCreatePipelineLayout(device, &pl_desc);
144+
145+
WGPUComputePipelineDescriptor pipeline_desc = {};
146+
pipeline_desc.layout = pipeline_layout;
147+
pipeline_desc.compute.module = shader;
148+
pipeline_desc.compute.entryPoint = {"main", WGPU_STRLEN};
149+
pipeline_desc.compute.constantCount = 1;
150+
pipeline_desc.compute.constants = &wg_size_constant;
151+
WGPUComputePipeline pipeline =
152+
wgpuDeviceCreateComputePipeline(device, &pipeline_desc);
153+
154+
WGPUBindGroupEntry bg_entries[4] = {};
155+
bg_entries[0].binding = 0;
156+
bg_entries[0].buffer = self_tensor.buffer;
157+
bg_entries[0].size = self_tensor.nbytes;
158+
bg_entries[1].binding = 1;
159+
bg_entries[1].buffer = out_tensor.buffer;
160+
bg_entries[1].size = out_tensor.nbytes;
161+
bg_entries[2].binding = 2;
162+
bg_entries[2].buffer = index_tensor.buffer;
163+
bg_entries[2].size = index_tensor.nbytes;
164+
bg_entries[3].binding = 3;
165+
bg_entries[3].buffer = uniform_buffer;
166+
bg_entries[3].size = sizeof(IndexParams);
167+
168+
WGPUBindGroupDescriptor bg_desc = {};
169+
bg_desc.layout = bgl;
170+
bg_desc.entryCount = 4;
171+
bg_desc.entries = bg_entries;
172+
WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);
173+
174+
graph.add_dispatch({pipeline, bind_group, workgroup_count});
175+
176+
wgpuShaderModuleRelease(shader);
177+
wgpuBindGroupLayoutRelease(bgl);
178+
wgpuPipelineLayoutRelease(pipeline_layout);
179+
// The bind group keeps the uniform buffer alive until release.
180+
wgpuBufferRelease(uniform_buffer);
181+
}
182+
183+
} // namespace
184+
185+
// Associates the op name aten.index.Tensor with index_impl via the
// registration macros from OperatorRegistry.h.
WEBGPU_REGISTER_OPERATORS {
186+
WEBGPU_REGISTER_OP(aten.index.Tensor, index_impl);
187+
}
188+
189+
} // namespace executorch::backends::webgpu
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
@group(0) @binding(0) var<storage, read> input: array<f32>;
2+
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
3+
@group(0) @binding(2) var<storage, read> index: array<i32>;
4+
5+
struct Params {
6+
numel: u32,
7+
}
8+
@group(0) @binding(3) var<uniform> params: Params;
9+
10+
override wg_size: u32 = 64;
11+
12+
@compute @workgroup_size(wg_size)
13+
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
14+
let out_bufi = gid.x;
15+
if (out_bufi >= params.numel) {
16+
return;
17+
}
18+
19+
// 1D-self gather out[i]=self[index[i]] (mirrors Vulkan index_tensor_buffer.glsl).
20+
let i = index[out_bufi];
21+
output[out_bufi] = input[u32(i)];
22+
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <cstdint>
12+
13+
namespace executorch::backends::webgpu {
14+
15+
// @generated from index.wgsl - DO NOT EDIT.
16+
// wgsl-sha256: daed48e60bfcf2b7420d277576d794137d3bff383aef4f68464c98c8a7235c8e
17+
inline constexpr const char* kIndexWGSL = R"(
18+
@group(0) @binding(0) var<storage, read> input: array<f32>;
19+
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
20+
@group(0) @binding(2) var<storage, read> index: array<i32>;
21+
22+
struct Params {
23+
numel: u32,
24+
}
25+
@group(0) @binding(3) var<uniform> params: Params;
26+
27+
override wg_size: u32 = 64;
28+
29+
@compute @workgroup_size(wg_size)
30+
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
31+
let out_bufi = gid.x;
32+
if (out_bufi >= params.numel) {
33+
return;
34+
}
35+
36+
// 1D-self gather out[i]=self[index[i]] (mirrors Vulkan index_tensor_buffer.glsl).
37+
let i = index[out_bufi];
38+
output[out_bufi] = input[u32(i)];
39+
}
40+
)";
41+
42+
inline constexpr uint32_t kIndexWorkgroupSizeX = 64;
43+
inline constexpr uint32_t kIndexWorkgroupSizeY = 1;
44+
inline constexpr uint32_t kIndexWorkgroupSizeZ = 1;
45+
46+
} // namespace executorch::backends::webgpu

0 commit comments

Comments
 (0)