Skip to content

Commit 31eb102

Browse files
[ExecuTorch][WebGPU] Add select_copy op (aten.select_copy.int)
Pull Request resolved: #20362

Adds `aten.select_copy.int` to the WebGPU delegate as a gather: it picks a fixed index along one dim, producing an output of rank (input rank - 1).

Composition (single dispatch):
- `select/Select.cpp` — reads `[self, dim, index, out]` (static `Int` via `read_scalar`; throws on dynamic `SymInt`), normalizes and bounds-checks dim/index, builds 2 `TensorMeta` UBOs plus a `SelectParams{dim,index}`, applies the fp32 guard, dispatches 1D over `numel`, and releases the uniforms after the bind group is created.
- `select/select.wgsl` — seeds the input offset with `index * in.strides[dim]`, delinearizes the output index, maps each out dim to its in dim (shifted past the selected dim), and relinearizes on the input strides.

ghstack-source-id: 397026510
@exported-using-ghexport
Differential Revision: [D108793166](https://our.internmc.facebook.com/intern/diff/D108793166/)
1 parent 3d36ada commit 31eb102

4 files changed

Lines changed: 291 additions & 0 deletions

File tree

backends/webgpu/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ set(WEBGPU_SRCS
4343
runtime/ops/rope/RotaryEmbedding.cpp
4444
runtime/ops/prepack/Prepack.cpp
4545
runtime/ops/view_copy/ViewCopy.cpp
46+
runtime/ops/select/Select.cpp
4647
)
4748

4849
add_library(webgpu_backend ${WEBGPU_SRCS})
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/webgpu/runtime/WebGPUGraph.h>
10+
#include <executorch/backends/webgpu/runtime/WebGPUUtils.h>
11+
#include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h>
12+
#include <executorch/backends/webgpu/runtime/ops/TensorMeta.h>
13+
#include <executorch/backends/webgpu/runtime/ops/select/select_wgsl.h>
14+
15+
#include <webgpu/webgpu.h>
16+
17+
#include <cstdint>
18+
#include <stdexcept>
19+
#include <string>
20+
#include <vector>
21+
22+
namespace executorch::backends::webgpu {
23+
24+
namespace {
25+
26+
struct SelectParams {
27+
uint32_t dim;
28+
uint32_t index;
29+
uint32_t _pad[2];
30+
};
31+
32+
// dim/index are required Ints (SymInt throws); no Null default unlike slice.
33+
int64_t read_scalar(WebGPUGraph& graph, int id, const char* what) {
34+
if (graph.get_value_type(id) == WebGPUGraph::ValueType::Int) {
35+
return graph.get_int(id);
36+
}
37+
throw std::runtime_error(std::string("select: dynamic/unsupported ") + what);
38+
}
39+
40+
// Lowers aten.select_copy.int to one compute dispatch that gathers the slice
// at a fixed `index` along `dim`.
//
// args is [self, dim, index, out]:
//   - self / out: graph tensor ids; both must have a non-null buffer.
//   - dim / index: graph value ids that must hold static Ints (read_scalar
//     throws otherwise). Negative values are wrapped once by the input rank /
//     the size of the selected dim and then bounds-checked.
//
// Throws std::runtime_error when a buffer is null, dim or index is out of
// range, the output is not the input shape with `dim` removed, or either
// operand's byte size is not numel * sizeof(float). std::out_of_range
// propagates from args.at() when fewer than four args are supplied.
//
// All validation runs before any GPU object is created, so a rejected node
// does not leave shader, pipeline, or buffer handles behind.
void select_impl(WebGPUGraph& graph, const std::vector<int>& args) {
  // args: [self, dim, index, out]; output rank = in rank - 1.
  const int in_id = args.at(0);
  const int out_id = args.at(3);

  WGPUDevice device = graph.device();
  const auto& in_tensor = graph.get_tensor(in_id);
  const auto& out_tensor = graph.get_tensor(out_id);
  if (in_tensor.buffer == nullptr || out_tensor.buffer == nullptr) {
    throw std::runtime_error("select: null buffer binding");
  }

  // Normalize a negative dim against the input rank, then bounds-check it.
  const int in_ndim = static_cast<int>(in_tensor.dims.size());
  int64_t dim = read_scalar(graph, args.at(1), "dim");
  if (dim < 0) {
    dim += in_ndim;
  }
  if (dim < 0 || dim >= in_ndim) {
    throw std::runtime_error("select: dim out of range");
  }

  // Normalize a negative index against the size of the selected dim.
  const int64_t in_size = in_tensor.dims[dim];
  int64_t index = read_scalar(graph, args.at(2), "index");
  if (index < 0) {
    index += in_size;
  }
  if (index < 0 || index >= in_size) {
    throw std::runtime_error("select: index out of range");
  }

  // The shader maps output dim `od` to input dim `od` (od < dim) or `od + 1`
  // (od >= dim) and delinearizes with the output strides, so the output must
  // be exactly the input shape with `dim` removed. Anything else would make
  // the kernel gather the wrong elements without any error.
  const int out_ndim = static_cast<int>(out_tensor.dims.size());
  if (out_ndim != in_ndim - 1) {
    throw std::runtime_error("select: output rank must be input rank - 1");
  }
  for (int od = 0; od < out_ndim; ++od) {
    const int id = (od < dim) ? od : od + 1;
    if (out_tensor.dims[od] != in_tensor.dims[id]) {
      throw std::runtime_error(
          "select: output shape does not match input shape with dim removed");
    }
  }

  TensorMeta out_meta;
  TensorMeta in_meta;
  fill_tensor_meta(out_tensor, &out_meta);
  fill_tensor_meta(in_tensor, &in_meta);
  // fp32 guard: the shader binds both buffers as array<f32>, so each operand
  // must occupy exactly four bytes per element.
  if (out_tensor.nbytes !=
          static_cast<size_t>(out_meta.numel) * sizeof(float) ||
      in_tensor.nbytes != static_cast<size_t>(in_meta.numel) * sizeof(float)) {
    throw std::runtime_error("select: non-fp32 operand (nbytes != numel * 4)");
  }

  // dim and index are non-negative and in range here, so the narrowing casts
  // cannot wrap a negative value.
  SelectParams params = {};
  params.dim = static_cast<uint32_t>(dim);
  params.index = static_cast<uint32_t>(index);

  // One invocation per output element, dispatched along X only.
  uint32_t wg_size = utils::clamp_workgroup_size(device, kSelectWorkgroupSizeX);
  uint32_t workgroup_count = utils::compute_1d_workgroup_count(
      device, out_meta.numel, wg_size, "select");

  // Feeds the shader's `override wg_size` pipeline constant.
  WGPUConstantEntry wg_size_constant = {};
  wg_size_constant.key = {"wg_size", WGPU_STRLEN};
  wg_size_constant.value = static_cast<double>(wg_size);

  WGPUBuffer out_meta_buf =
      utils::make_uniform(device, &out_meta, sizeof(TensorMeta));
  WGPUBuffer in_meta_buf =
      utils::make_uniform(device, &in_meta, sizeof(TensorMeta));
  WGPUBuffer params_buf =
      utils::make_uniform(device, &params, sizeof(SelectParams));
  graph.add_uniform_buffer_bytes(2 * sizeof(TensorMeta) + sizeof(SelectParams));

  WGPUShaderSourceWGSL wgsl_desc = {};
  wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL;
  wgsl_desc.code = {kSelectWGSL, WGPU_STRLEN};
  WGPUShaderModuleDescriptor shader_desc = {};
  shader_desc.nextInChain = &wgsl_desc.chain;
  WGPUShaderModule shader = wgpuDeviceCreateShaderModule(device, &shader_desc);

  // Bind group: in, out (rw), out_meta, in_meta, params (3 uniforms).
  WGPUBindGroupLayoutEntry entries[5] = {};
  entries[0].binding = 0;
  entries[0].visibility = WGPUShaderStage_Compute;
  entries[0].buffer.type = WGPUBufferBindingType_ReadOnlyStorage;
  entries[1].binding = 1;
  entries[1].visibility = WGPUShaderStage_Compute;
  entries[1].buffer.type = WGPUBufferBindingType_Storage;
  entries[2].binding = 2;
  entries[2].visibility = WGPUShaderStage_Compute;
  entries[2].buffer.type = WGPUBufferBindingType_Uniform;
  entries[3].binding = 3;
  entries[3].visibility = WGPUShaderStage_Compute;
  entries[3].buffer.type = WGPUBufferBindingType_Uniform;
  entries[4].binding = 4;
  entries[4].visibility = WGPUShaderStage_Compute;
  entries[4].buffer.type = WGPUBufferBindingType_Uniform;

  WGPUBindGroupLayoutDescriptor bgl_desc = {};
  bgl_desc.entryCount = 5;
  bgl_desc.entries = entries;
  WGPUBindGroupLayout bgl = wgpuDeviceCreateBindGroupLayout(device, &bgl_desc);

  WGPUPipelineLayoutDescriptor pl_desc = {};
  pl_desc.bindGroupLayoutCount = 1;
  pl_desc.bindGroupLayouts = &bgl;
  WGPUPipelineLayout pipeline_layout =
      wgpuDeviceCreatePipelineLayout(device, &pl_desc);

  WGPUComputePipelineDescriptor pipeline_desc = {};
  pipeline_desc.layout = pipeline_layout;
  pipeline_desc.compute.module = shader;
  pipeline_desc.compute.entryPoint = {"main", WGPU_STRLEN};
  pipeline_desc.compute.constantCount = 1;
  pipeline_desc.compute.constants = &wg_size_constant;
  WGPUComputePipeline pipeline =
      wgpuDeviceCreateComputePipeline(device, &pipeline_desc);

  WGPUBindGroupEntry bg_entries[5] = {};
  bg_entries[0].binding = 0;
  bg_entries[0].buffer = in_tensor.buffer;
  bg_entries[0].size = in_tensor.nbytes;
  bg_entries[1].binding = 1;
  bg_entries[1].buffer = out_tensor.buffer;
  bg_entries[1].size = out_tensor.nbytes;
  bg_entries[2].binding = 2;
  bg_entries[2].buffer = out_meta_buf;
  bg_entries[2].size = sizeof(TensorMeta);
  bg_entries[3].binding = 3;
  bg_entries[3].buffer = in_meta_buf;
  bg_entries[3].size = sizeof(TensorMeta);
  bg_entries[4].binding = 4;
  bg_entries[4].buffer = params_buf;
  bg_entries[4].size = sizeof(SelectParams);

  WGPUBindGroupDescriptor bg_desc = {};
  bg_desc.layout = bgl;
  bg_desc.entryCount = 5;
  bg_desc.entries = bg_entries;
  WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);

  // NOTE(review): pipeline and bind_group are not released here; presumably
  // the graph takes ownership of them through add_dispatch — confirm.
  graph.add_dispatch({pipeline, bind_group, workgroup_count});

  wgpuShaderModuleRelease(shader);
  wgpuBindGroupLayoutRelease(bgl);
  wgpuPipelineLayoutRelease(pipeline_layout);
  // Drop our refs; the bind group keeps the uniforms alive until release.
  wgpuBufferRelease(out_meta_buf);
  wgpuBufferRelease(in_meta_buf);
  wgpuBufferRelease(params_buf);
}
177+
178+
} // namespace
179+
180+
// Associates the op name aten.select_copy.int with select_impl.
// NOTE(review): the macros are defined outside this file (presumably in
// OperatorRegistry.h); when and how registration runs is not visible here.
WEBGPU_REGISTER_OPERATORS {
181+
WEBGPU_REGISTER_OP(aten.select_copy.int, select_impl);
182+
}
183+
184+
} // namespace executorch::backends::webgpu
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
@group(0) @binding(0) var<storage, read> input: array<f32>;
2+
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
3+
4+
struct TensorMeta {
5+
ndim: u32,
6+
numel: u32,
7+
sizes: vec4<u32>,
8+
strides: vec4<u32>,
9+
}
10+
@group(0) @binding(2) var<uniform> out_meta: TensorMeta;
11+
@group(0) @binding(3) var<uniform> in_meta: TensorMeta;
12+
13+
struct Params {
14+
dim: u32,
15+
index: u32,
16+
}
17+
@group(0) @binding(4) var<uniform> params: Params;
18+
19+
override wg_size: u32 = 64u;
20+
21+
@compute @workgroup_size(wg_size, 1, 1)
22+
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
23+
let out_bufi = gid.x;
24+
if (out_bufi >= out_meta.numel) {
25+
return;
26+
}
27+
28+
// Gather: out dim od -> in dim (od if od < dim else od+1); sel dim = index.
29+
var rem = out_bufi;
30+
var in_bufi: u32 = params.index * in_meta.strides[params.dim];
31+
for (var od: u32 = 0u; od < out_meta.ndim; od = od + 1u) {
32+
let coord = rem / out_meta.strides[od];
33+
rem = rem % out_meta.strides[od];
34+
var id = od;
35+
if (od >= params.dim) {
36+
id = od + 1u;
37+
}
38+
in_bufi = in_bufi + coord * in_meta.strides[id];
39+
}
40+
output[out_bufi] = input[in_bufi];
41+
}
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <cstdint>
12+
13+
namespace executorch::backends::webgpu {
14+
15+
// @generated from select.wgsl - DO NOT EDIT.
16+
// wgsl-sha256: 200cf5a8190045aa0562e782f01c1cfaf9681f30f679f5112ccc3d347a0ed8df
17+
inline constexpr const char* kSelectWGSL = R"(
18+
@group(0) @binding(0) var<storage, read> input: array<f32>;
19+
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
20+
21+
struct TensorMeta {
22+
ndim: u32,
23+
numel: u32,
24+
sizes: vec4<u32>,
25+
strides: vec4<u32>,
26+
}
27+
@group(0) @binding(2) var<uniform> out_meta: TensorMeta;
28+
@group(0) @binding(3) var<uniform> in_meta: TensorMeta;
29+
30+
struct Params {
31+
dim: u32,
32+
index: u32,
33+
}
34+
@group(0) @binding(4) var<uniform> params: Params;
35+
36+
override wg_size: u32 = 64u;
37+
38+
@compute @workgroup_size(wg_size, 1, 1)
39+
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
40+
let out_bufi = gid.x;
41+
if (out_bufi >= out_meta.numel) {
42+
return;
43+
}
44+
45+
// Gather: out dim od -> in dim (od if od < dim else od+1); sel dim = index.
46+
var rem = out_bufi;
47+
var in_bufi: u32 = params.index * in_meta.strides[params.dim];
48+
for (var od: u32 = 0u; od < out_meta.ndim; od = od + 1u) {
49+
let coord = rem / out_meta.strides[od];
50+
rem = rem % out_meta.strides[od];
51+
var id = od;
52+
if (od >= params.dim) {
53+
id = od + 1u;
54+
}
55+
in_bufi = in_bufi + coord * in_meta.strides[id];
56+
}
57+
output[out_bufi] = input[in_bufi];
58+
}
59+
)";
60+
61+
inline constexpr uint32_t kSelectWorkgroupSizeX = 64;
62+
inline constexpr uint32_t kSelectWorkgroupSizeY = 1;
63+
inline constexpr uint32_t kSelectWorkgroupSizeZ = 1;
64+
65+
} // namespace executorch::backends::webgpu

0 commit comments

Comments
 (0)