Skip to content

Commit 7be6fcd

Browse files
committed
[ExecuTorch][WebGPU] select_copy op test suite (cases.py op-test framework)
Pull Request resolved: #20363 Registers `aten.select_copy.int` in the `cases.py` op-test framework: a `_select_suite` of 4 configs (leading/middle/last dim + negative index) that `generate_op_tests` exports and compares to a torch golden on Dawn. Also adds `test/ops/select/test_select.py` (`SelectModule` + `CONFIGS` + an export-delegation/eager smoke test) and the `aten.select_copy.int` partitioner-allowlist entry in `tester.py`. ghstack-source-id: 397026513 @exported-using-ghexport Differential Revision: [D108793161](https://our.internmc.facebook.com/intern/diff/D108793161/)
1 parent 75e624e commit 7be6fcd

3 files changed

Lines changed: 77 additions & 0 deletions

File tree

backends/webgpu/test/op_tests/cases.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@
4040
CONFIGS as _MUL_CONFIGS,
4141
MulModule,
4242
)
43+
from executorch.backends.webgpu.test.ops.test_select import (
44+
CONFIGS as _SELECT_CONFIGS,
45+
SelectModule,
46+
)
4347
from executorch.backends.webgpu.test.ops.test_view_copy import (
4448
CONFIGS as _VIEW_CONFIGS,
4549
ViewModule,
@@ -144,3 +148,8 @@ def _fn_config_suite(module_cls, configs) -> WebGPUTestSuite:
144148
@register_op_test("view_copy")
145149
def _view_copy_suite() -> WebGPUTestSuite:
146150
return _fn_config_suite(ViewModule, _VIEW_CONFIGS)
151+
152+
153+
@register_op_test("select")
154+
def _select_suite() -> WebGPUTestSuite:
155+
return _fn_config_suite(SelectModule, _SELECT_CONFIGS)
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""`aten.select_copy.int` module + configs for the WebGPU op-test framework.
8+
9+
`SelectModule` + `CONFIGS` are imported by `cases.py` to drive the declarative
10+
op-test suite. `SelectTest` is the export-delegation smoke test.
11+
Configs cover the leading, middle, and last dim plus a negative index (output rank =
12+
input rank - 1).
13+
"""
14+
15+
import unittest
16+
17+
import torch
18+
19+
from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
20+
from executorch.exir import to_edge_transform_and_lower
21+
22+
# name -> (input_shape, select_fn)
23+
CONFIGS = {
24+
"dim0": ((3, 8, 4), lambda x: x[1]),
25+
"middle": ((3, 8, 4), lambda x: x[:, 2]),
26+
"last": ((3, 8, 4), lambda x: x[..., 3]),
27+
"neg_idx": ((3, 8, 4), lambda x: x[:, -1]),
28+
}
29+
30+
31+
class SelectModule(torch.nn.Module):
32+
def __init__(self, fn):
33+
super().__init__()
34+
self.fn = fn
35+
36+
def forward(self, x: torch.Tensor) -> torch.Tensor:
37+
return self.fn(x)
38+
39+
40+
def _det_input(shape):
41+
g = torch.Generator().manual_seed(0)
42+
return torch.randn(*shape, generator=g, dtype=torch.float32)
43+
44+
45+
def _export(fn, x: torch.Tensor):
46+
ep = torch.export.export(SelectModule(fn).eval(), (x,))
47+
return to_edge_transform_and_lower(
48+
ep, partitioner=[VulkanPartitioner()]
49+
).to_executorch()
50+
51+
52+
def _delegated(et) -> bool:
53+
return any(
54+
d.id == "VulkanBackend"
55+
for plan in et.executorch_program.execution_plan
56+
for d in plan.delegates
57+
)
58+
59+
60+
class SelectTest(unittest.TestCase):
61+
def test_export_delegates(self) -> None:
62+
for name, (shape, fn) in CONFIGS.items():
63+
with self.subTest(name=name):
64+
et = _export(fn, _det_input(shape))
65+
self.assertTrue(
66+
_delegated(et), f"Expected a VulkanBackend delegate (select {name})"
67+
)

backends/webgpu/test/tester.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
exir_ops.edge.et_vk.rms_norm.default,
2424
exir_ops.edge.aten.mul.Tensor,
2525
exir_ops.edge.aten.view_copy.default,
26+
exir_ops.edge.aten.select_copy.int,
2627
]
2728

2829

0 commit comments

Comments
 (0)