Skip to content

Commit 11567a3

Browse files
[ExecuTorch][WebGPU] Add permute_copy + IntList graph support (aten.permute_copy.default) (#20396)
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.15.0) (oldest at bottom):
* #20465
* #20464
* #20463
* #20435
* #20399
* #20398
* #20397
* __->__ #20396
* #20395
* #20394
* #20393
* #20392
* #20391
* #20390
* #20363
* #20362
* #20361
* #20360
* #20359

Adds `aten.permute_copy.default` (a coordinate-reorder gather) to the WebGPU delegate, and the `IntList` graph value type it needs to read its `dims` argument.

Composition:
- `runtime/WebGPUGraph.{h,cpp}` — adds `ValueType::IntList` backed by `std::vector<std::vector<int64_t>> int_lists_` + `get_int_list(int)`; `build()` deserializes `vkgraph::GraphTypes::IntList` via `value_as_IntList()->items()` (int64, matching the FlatBuffer `[long]`); mirrors the existing scalar value plumbing.
- `runtime/ops/permute/Permute.cpp` — reads the permutation via `get_int_list`, normalizes negative dims, validates that it is a permutation of `[0, ndim)`, builds two `TensorMeta` UBOs + a `PermuteParams{perm: vec4<u32>}` uniform, guards fp32 + rank≤4, dispatches over `compute_1d_workgroup_count(out.numel)` with `override wg_size`; releases its references to all uniform buffers after creating the bind group.
- `runtime/ops/permute/permute.wgsl` — delinearizes the output index over the contiguous output strides, then accumulates the input offset using `in_meta.strides[perm[d]]` per dim (mirrors Vulkan `permute_buffer.glsl`).
- Registers both `aten.permute_copy.default` and `aten.permute.default` to the same handler.

@exported-using-ghexport

Differential Revision: [D108793162](https://our.internmc.facebook.com/intern/diff/D108793162/)
1 parent a67f12e commit 11567a3

6 files changed

Lines changed: 311 additions & 1 deletion

File tree

backends/webgpu/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ set(WEBGPU_SRCS
4848
runtime/ops/squeeze/Squeeze.cpp
4949
runtime/ops/unsqueeze/Unsqueeze.cpp
5050
runtime/ops/slice/Slice.cpp
51+
runtime/ops/permute/Permute.cpp
5152
)
5253

5354
add_library(webgpu_backend ${WEBGPU_SRCS})

backends/webgpu/runtime/WebGPUGraph.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@ void WebGPUGraph::build(
245245
tensors_.resize(num_vals);
246246
tensor_mem_obj_ids_.resize(num_vals, -1);
247247
ints_.resize(num_vals, 0);
248+
int_lists_.resize(num_vals);
248249
doubles_.resize(num_vals, 0.0);
249250
bools_.resize(num_vals, false);
250251
value_lists_.resize(num_vals);
@@ -375,6 +376,14 @@ void WebGPUGraph::build(
375376
ints_[i] = val->value_as_Int()->int_val();
376377
break;
377378
}
379+
case vkgraph::GraphTypes::IntList: {
380+
value_types_[i] = ValueType::IntList;
381+
const auto* items = val->value_as_IntList()->items();
382+
if (items) {
383+
int_lists_[i].assign(items->cbegin(), items->cend());
384+
}
385+
break;
386+
}
378387
case vkgraph::GraphTypes::Double: {
379388
value_types_[i] = ValueType::Double;
380389
doubles_[i] = val->value_as_Double()->double_val();

backends/webgpu/runtime/WebGPUGraph.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,11 @@ class WebGPUGraph {
131131
int64_t get_int(int id) const {
132132
return ints_[id];
133133
}
134+
// Int values of a serialized IntList (e.g. permute dims). int64 (FlatBuffer
135+
// [long]) to match the schema and the get_int convention.
//
// Returns a reference into graph-owned storage (int_lists_), so it is only
// valid while the graph is alive. `id` is used as a raw index with no bounds
// check. Slots whose value is not an IntList (or whose serialized item list
// was null) hold an empty vector, so callers should check get_value_type(id)
// first when they need to tell "empty list" from "not a list".
136+
const std::vector<int64_t>& get_int_list(int id) const {
137+
return int_lists_[id];
138+
}
134139
bool get_bool(int id) const {
135140
return bools_[id];
136141
}
@@ -258,7 +263,8 @@ class WebGPUGraph {
258263
Null,
259264
String,
260265
SymInt,
261-
ValueList
266+
ValueList,
267+
IntList
262268
};
263269

264270
ValueType get_value_type(int id) const {
@@ -275,6 +281,7 @@ class WebGPUGraph {
275281
std::vector<ValueType> value_types_;
276282
std::vector<WebGPUTensor> tensors_;
277283
std::vector<int64_t> ints_;
284+
std::vector<std::vector<int64_t>> int_lists_;
278285
std::vector<double> doubles_;
279286
std::vector<bool> bools_;
280287
std::vector<std::vector<int>> value_lists_;
Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/webgpu/runtime/WebGPUGraph.h>
10+
#include <executorch/backends/webgpu/runtime/WebGPUUtils.h>
11+
#include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h>
12+
#include <executorch/backends/webgpu/runtime/ops/TensorMeta.h>
13+
#include <executorch/backends/webgpu/runtime/ops/permute/permute_wgsl.h>
14+
15+
#include <webgpu/webgpu.h>
16+
17+
#include <cstdint>
18+
#include <cstring>
19+
#include <stdexcept>
20+
#include <vector>
21+
22+
namespace executorch::backends::webgpu {
23+
24+
namespace {
25+
26+
// Host-side mirror of the WGSL `Params` uniform that permute_impl binds at
// binding 4. perm[d] is the input dimension that output dimension d reads
// from; unused trailing slots (d >= rank) are filled with the identity.
struct PermuteParams {
27+
uint32_t perm[kTensorMetaMaxNdim];
28+
};
29+
// Layout guard: the shader declares `perm` as a single vec4<u32>, so the
// bytes uploaded from this struct must be exactly 16.
static_assert(
30+
sizeof(PermuteParams) == 16,
31+
"PermuteParams must match the WGSL Params vec4<u32> (16 bytes)");
32+
33+
// permute: out coord d -> in coord perm[d] (Vulkan permute_buffer.glsl, NCHW).
34+
void permute_impl(WebGPUGraph& graph, const std::vector<int>& args) {
35+
// args: [self, dims, out]; out is the last value-id.
36+
const int in_id = args.at(0);
37+
const int dims_id = args.at(1);
38+
const int out_id = args.at(args.size() - 1);
39+
40+
if (graph.get_value_type(in_id) != WebGPUGraph::ValueType::Tensor ||
41+
graph.get_value_type(out_id) != WebGPUGraph::ValueType::Tensor) {
42+
throw std::runtime_error("permute: in/out arg is not a tensor");
43+
}
44+
if (graph.get_value_type(dims_id) != WebGPUGraph::ValueType::IntList) {
45+
throw std::runtime_error("permute: dims arg is not an IntList");
46+
}
47+
48+
WGPUDevice device = graph.device();
49+
const auto& in_tensor = graph.get_tensor(in_id);
50+
const auto& out_tensor = graph.get_tensor(out_id);
51+
const int ndim = static_cast<int>(in_tensor.dims.size());
52+
53+
const std::vector<int64_t>& dims = graph.get_int_list(dims_id);
54+
if (static_cast<int>(dims.size()) != ndim ||
55+
static_cast<int>(out_tensor.dims.size()) != ndim) {
56+
throw std::runtime_error("permute: perm length != input/output rank");
57+
}
58+
59+
// Normalize negative dims and verify perm is a permutation of [0, ndim).
60+
uint32_t perm[kTensorMetaMaxNdim];
61+
bool seen[kTensorMetaMaxNdim] = {};
62+
if (ndim > static_cast<int>(kTensorMetaMaxNdim)) {
63+
throw std::runtime_error("permute: tensor rank exceeds 4 (MAX_NDIM)");
64+
}
65+
for (int d = 0; d < ndim; d++) {
66+
int64_t p = dims[d];
67+
if (p < 0) {
68+
p += ndim;
69+
}
70+
if (p < 0 || p >= ndim || seen[p]) {
71+
throw std::runtime_error("permute: dims is not a valid permutation");
72+
}
73+
seen[p] = true;
74+
perm[d] = static_cast<uint32_t>(p);
75+
}
76+
for (int d = ndim; d < static_cast<int>(kTensorMetaMaxNdim); d++) {
77+
perm[d] = static_cast<uint32_t>(d);
78+
}
79+
80+
TensorMeta out_meta;
81+
TensorMeta in_meta;
82+
fill_tensor_meta(out_tensor, &out_meta);
83+
fill_tensor_meta(in_tensor, &in_meta);
84+
if (out_tensor.nbytes !=
85+
static_cast<size_t>(out_meta.numel) * sizeof(float) ||
86+
in_tensor.nbytes != static_cast<size_t>(in_meta.numel) * sizeof(float)) {
87+
throw std::runtime_error("permute: non-fp32 operand (nbytes != numel * 4)");
88+
}
89+
90+
PermuteParams params = {};
91+
std::memcpy(params.perm, perm, sizeof(perm));
92+
93+
uint32_t wg_size =
94+
utils::clamp_workgroup_size(device, kPermuteWorkgroupSizeX);
95+
uint32_t workgroup_count = utils::compute_1d_workgroup_count(
96+
device, out_meta.numel, wg_size, "permute");
97+
98+
WGPUConstantEntry wg_size_constant = {};
99+
wg_size_constant.key = {"wg_size", WGPU_STRLEN};
100+
wg_size_constant.value = static_cast<double>(wg_size);
101+
102+
WGPUBuffer out_meta_buf =
103+
utils::make_uniform(device, &out_meta, sizeof(TensorMeta));
104+
WGPUBuffer in_meta_buf =
105+
utils::make_uniform(device, &in_meta, sizeof(TensorMeta));
106+
WGPUBuffer params_buf =
107+
utils::make_uniform(device, &params, sizeof(PermuteParams));
108+
graph.add_uniform_buffer_bytes(
109+
2 * sizeof(TensorMeta) + sizeof(PermuteParams));
110+
111+
WGPUShaderSourceWGSL wgsl_desc = {};
112+
wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL;
113+
wgsl_desc.code = {kPermuteWGSL, WGPU_STRLEN};
114+
WGPUShaderModuleDescriptor shader_desc = {};
115+
shader_desc.nextInChain = &wgsl_desc.chain;
116+
WGPUShaderModule shader = wgpuDeviceCreateShaderModule(device, &shader_desc);
117+
118+
// Bind group: in, out (rw), out_meta, in_meta, params (3 uniforms).
119+
WGPUBindGroupLayoutEntry entries[5] = {};
120+
entries[0].binding = 0;
121+
entries[0].visibility = WGPUShaderStage_Compute;
122+
entries[0].buffer.type = WGPUBufferBindingType_ReadOnlyStorage;
123+
entries[1].binding = 1;
124+
entries[1].visibility = WGPUShaderStage_Compute;
125+
entries[1].buffer.type = WGPUBufferBindingType_Storage;
126+
entries[2].binding = 2;
127+
entries[2].visibility = WGPUShaderStage_Compute;
128+
entries[2].buffer.type = WGPUBufferBindingType_Uniform;
129+
entries[3].binding = 3;
130+
entries[3].visibility = WGPUShaderStage_Compute;
131+
entries[3].buffer.type = WGPUBufferBindingType_Uniform;
132+
entries[4].binding = 4;
133+
entries[4].visibility = WGPUShaderStage_Compute;
134+
entries[4].buffer.type = WGPUBufferBindingType_Uniform;
135+
136+
WGPUBindGroupLayoutDescriptor bgl_desc = {};
137+
bgl_desc.entryCount = 5;
138+
bgl_desc.entries = entries;
139+
WGPUBindGroupLayout bgl = wgpuDeviceCreateBindGroupLayout(device, &bgl_desc);
140+
141+
WGPUPipelineLayoutDescriptor pl_desc = {};
142+
pl_desc.bindGroupLayoutCount = 1;
143+
pl_desc.bindGroupLayouts = &bgl;
144+
WGPUPipelineLayout pipeline_layout =
145+
wgpuDeviceCreatePipelineLayout(device, &pl_desc);
146+
147+
WGPUComputePipelineDescriptor pipeline_desc = {};
148+
pipeline_desc.layout = pipeline_layout;
149+
pipeline_desc.compute.module = shader;
150+
pipeline_desc.compute.entryPoint = {"main", WGPU_STRLEN};
151+
pipeline_desc.compute.constantCount = 1;
152+
pipeline_desc.compute.constants = &wg_size_constant;
153+
WGPUComputePipeline pipeline =
154+
wgpuDeviceCreateComputePipeline(device, &pipeline_desc);
155+
156+
WGPUBindGroupEntry bg_entries[5] = {};
157+
bg_entries[0].binding = 0;
158+
bg_entries[0].buffer = in_tensor.buffer;
159+
bg_entries[0].size = in_tensor.nbytes;
160+
bg_entries[1].binding = 1;
161+
bg_entries[1].buffer = out_tensor.buffer;
162+
bg_entries[1].size = out_tensor.nbytes;
163+
bg_entries[2].binding = 2;
164+
bg_entries[2].buffer = out_meta_buf;
165+
bg_entries[2].size = sizeof(TensorMeta);
166+
bg_entries[3].binding = 3;
167+
bg_entries[3].buffer = in_meta_buf;
168+
bg_entries[3].size = sizeof(TensorMeta);
169+
bg_entries[4].binding = 4;
170+
bg_entries[4].buffer = params_buf;
171+
bg_entries[4].size = sizeof(PermuteParams);
172+
173+
WGPUBindGroupDescriptor bg_desc = {};
174+
bg_desc.layout = bgl;
175+
bg_desc.entryCount = 5;
176+
bg_desc.entries = bg_entries;
177+
WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);
178+
179+
graph.add_dispatch({pipeline, bind_group, workgroup_count});
180+
181+
wgpuShaderModuleRelease(shader);
182+
wgpuBindGroupLayoutRelease(bgl);
183+
wgpuPipelineLayoutRelease(pipeline_layout);
184+
// Drop our refs; the bind group keeps the uniforms alive until release.
185+
wgpuBufferRelease(out_meta_buf);
186+
wgpuBufferRelease(in_meta_buf);
187+
wgpuBufferRelease(params_buf);
188+
}
189+
190+
} // namespace
191+
192+
// Operator registration: both the functional `permute_copy` op and plain
// `permute` are routed to the same handler. permute_impl always writes into
// the separate `out` tensor buffer, so `permute` is executed as a copy here
// rather than as a strided view.
WEBGPU_REGISTER_OPERATORS {
193+
WEBGPU_REGISTER_OP(aten.permute_copy.default, permute_impl);
194+
WEBGPU_REGISTER_OP(aten.permute.default, permute_impl);
195+
}
196+
197+
} // namespace executorch::backends::webgpu
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
@group(0) @binding(0) var<storage, read> input: array<f32>;
2+
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
3+
4+
struct TensorMeta {
5+
ndim: u32,
6+
numel: u32,
7+
sizes: vec4<u32>,
8+
strides: vec4<u32>,
9+
}
10+
@group(0) @binding(2) var<uniform> out_meta: TensorMeta;
11+
@group(0) @binding(3) var<uniform> in_meta: TensorMeta;
12+
13+
struct Params {
14+
perm: vec4<u32>,
15+
}
16+
@group(0) @binding(4) var<uniform> params: Params;
17+
18+
override wg_size: u32 = 64u;
19+
20+
@compute @workgroup_size(wg_size, 1, 1)
21+
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
22+
let out_bufi = gid.x;
23+
if (out_bufi >= out_meta.numel) {
24+
return;
25+
}
26+
27+
// Gather: out coord d -> in coord perm[d] (Vulkan permute_buffer.glsl).
28+
var rem = out_bufi;
29+
var in_bufi: u32 = 0u;
30+
for (var d: u32 = 0u; d < out_meta.ndim; d = d + 1u) {
31+
let coord = rem / out_meta.strides[d];
32+
rem = rem % out_meta.strides[d];
33+
in_bufi = in_bufi + coord * in_meta.strides[params.perm[d]];
34+
}
35+
output[out_bufi] = input[in_bufi];
36+
}
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <cstdint>
12+
13+
namespace executorch::backends::webgpu {
14+
15+
// @generated from permute.wgsl - DO NOT EDIT.
16+
// wgsl-sha256: d34f59730cda7317589b6ed5691a1ccab8666b9c94e17ac2cb3658b036300197
17+
inline constexpr const char* kPermuteWGSL = R"(
18+
@group(0) @binding(0) var<storage, read> input: array<f32>;
19+
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
20+
21+
struct TensorMeta {
22+
ndim: u32,
23+
numel: u32,
24+
sizes: vec4<u32>,
25+
strides: vec4<u32>,
26+
}
27+
@group(0) @binding(2) var<uniform> out_meta: TensorMeta;
28+
@group(0) @binding(3) var<uniform> in_meta: TensorMeta;
29+
30+
struct Params {
31+
perm: vec4<u32>,
32+
}
33+
@group(0) @binding(4) var<uniform> params: Params;
34+
35+
override wg_size: u32 = 64u;
36+
37+
@compute @workgroup_size(wg_size, 1, 1)
38+
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
39+
let out_bufi = gid.x;
40+
if (out_bufi >= out_meta.numel) {
41+
return;
42+
}
43+
44+
// Gather: out coord d -> in coord perm[d] (Vulkan permute_buffer.glsl).
45+
var rem = out_bufi;
46+
var in_bufi: u32 = 0u;
47+
for (var d: u32 = 0u; d < out_meta.ndim; d = d + 1u) {
48+
let coord = rem / out_meta.strides[d];
49+
rem = rem % out_meta.strides[d];
50+
in_bufi = in_bufi + coord * in_meta.strides[params.perm[d]];
51+
}
52+
output[out_bufi] = input[in_bufi];
53+
}
54+
)";
55+
56+
inline constexpr uint32_t kPermuteWorkgroupSizeX = 64;
57+
inline constexpr uint32_t kPermuteWorkgroupSizeY = 1;
58+
inline constexpr uint32_t kPermuteWorkgroupSizeZ = 1;
59+
60+
} // namespace executorch::backends::webgpu

0 commit comments

Comments
 (0)