Skip to content

Commit 3d20719

Browse files
committed
[ExecuTorch][WebGPU] 2D compute dispatch tests — prefill golden + fold unit test
Pull Request resolved: #20584 **Test coverage for the 2D dispatch fold, stacked above the cap-lift op.** **Problem**: The 2D fold is load-bearing index math — a wrong `{x, y}` means out-of-bounds writes or dropped threads — and the prefill shapes that exercise it previously threw at the 1D cap, so they were untested. **Solution**: A device-free unit test for the fold arithmetic, plus two single-shot prefill SDPA golden configs that fold each kernel family. - **Before**: no coverage for >65535-workgroup dispatch; `llama1b_prefill_512`/`_2048` shapes threw at the cap - **After**: `fold_workgroup_count_2d` unit-tested at the cap boundaries, and the two prefill shapes run as goldens **Implementation**: - `test/native/test_dispatch_2d.cpp` — device-free unit test for `utils::fold_workgroup_count_2d`: the 1D fast path, the 2D fold, the real Llama-1B QK counts at S=512 (`{65535, 3}`) and S=2048 (`{65535, 33}`), and the needs-3rd-dimension throw; asserts each `{x, y}` covers `[0, count)` - `llama1b_prefill_512` + `llama1b_prefill_2048` configs appended to the byte-mirrored `CONFIGS` (`test_sdpa.py`) and `kSdpaConfigs` (`test_webgpu_native.cpp`) - Registers `webgpu_dispatch_2d_test` in CMake + the native CI script **Constraints**: - The Python/C++ config entries byte-mirror each other (kept in sync) - `add` shares the element-form path with QK, so it is covered structurally; a dedicated >16M-element `add` fold case is omitted as disproportionate Co-authored-with: Claude Code. ghstack-source-id: 399812923 @exported-using-ghexport Differential Revision: [D109517683](https://our.internmc.facebook.com/intern/diff/D109517683/)
1 parent 0766cb9 commit 3d20719

5 files changed

Lines changed: 86 additions & 1 deletion

File tree

backends/webgpu/CMakeLists.txt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,14 @@ if(EXECUTORCH_BUILD_WEBGPU_TEST)
201201
webgpu_dynamic_shape_test test/native/test_dynamic_shape.cpp
202202
)
203203
target_link_libraries(webgpu_dynamic_shape_test PRIVATE GTest::gtest)
204+
205+
# Device-free fold unit test (gtest_main provides main; no device needed).
206+
add_webgpu_native_test(
207+
webgpu_dispatch_2d_test test/native/test_dispatch_2d.cpp
208+
)
209+
target_link_libraries(
210+
webgpu_dispatch_2d_test PRIVATE GTest::gtest GTest::gtest_main
211+
)
204212
endif()
205213
add_webgpu_native_test(webgpu_index_test test/native/test_index.cpp)
206214
endif()

backends/webgpu/scripts/test_webgpu_native_ci.sh

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ cmake \
143143
"${EXECUTORCH_ROOT}"
144144

145145
# ── Build + run every native test target that exists in this tree ────────────
146-
TARGETS=(webgpu_native_test webgpu_dispatch_order_test webgpu_scratch_buffer_test webgpu_update_cache_test webgpu_index_test)
146+
TARGETS=(webgpu_native_test webgpu_dispatch_order_test webgpu_scratch_buffer_test webgpu_update_cache_test webgpu_index_test webgpu_dispatch_2d_test)
147147
BIN_DIR="${BUILD_DIR}/backends/webgpu"
148148

149149
# Which targets are defined depends on which diffs are landed (native_test +
@@ -212,6 +212,8 @@ if [[ "${INDEX_OK}" == "1" && -x "${BIN_DIR}/webgpu_index_test" ]]; then
212212
"${BIN_DIR}/webgpu_index_test" "${INDEX_DIR}"
213213
fi
214214
[[ -x "${BIN_DIR}/webgpu_scratch_buffer_test" ]] && "${BIN_DIR}/webgpu_scratch_buffer_test"
215+
# Device-free: pure 2D workgroup-count fold unit test (no .pte, no GPU).
216+
[[ -x "${BIN_DIR}/webgpu_dispatch_2d_test" ]] && "${BIN_DIR}/webgpu_dispatch_2d_test"
215217

216218
echo "=== WebGPU native tests on Dawn: all run targets passed ==="
217219

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
// Device-free unit test for the pure 2D workgroup-count fold that lifts the
10+
// 65535 per-dim dispatch cap. Exercises the fold arithmetic only — no GPU.
11+
12+
#include <executorch/backends/webgpu/runtime/WebGPUUtils.h>
13+
14+
#include <gtest/gtest.h>
15+
16+
#include <cmath>
17+
#include <cstdint>
18+
19+
using executorch::backends::webgpu::utils::fold_workgroup_count_2d;
20+
using executorch::backends::webgpu::utils::WgCount;
21+
22+
namespace {
23+
24+
constexpr uint32_t kMax = 65535u;
25+
26+
// count <= max -> {count, 1}: the 1D fast path, byte-identical to the old path.
27+
TEST(DispatchFold, FastPath1D) {
28+
for (uint32_t count : {1u, kMax - 1u, kMax}) {
29+
const WgCount got = fold_workgroup_count_2d(count, kMax, "test");
30+
EXPECT_EQ(got.x, count);
31+
EXPECT_EQ(got.y, 1u);
32+
}
33+
}
34+
35+
// count > max -> near-square {x, y}: fits the per-dim cap, covers every
36+
// workgroup, and stays near-square so few invocations are inactive (launched -
37+
// count is O(sqrt(count)); a flat {max, div_up} split would idle up to ~half).
38+
TEST(DispatchFold, NearSquareFold) {
39+
// Includes prefill-scale QK counts (Hq*ceil(S/4)*ceil(ctx/4)/wg) that fold:
40+
// 131072 = S=2048 (32*512*512/64); 2097152 = large-S stress.
41+
for (uint32_t count :
42+
{kMax + 1u, 2u * kMax, 2u * kMax + 1u, 131072u, 2097152u}) {
43+
const WgCount got = fold_workgroup_count_2d(count, kMax, "test");
44+
const uint64_t launched = static_cast<uint64_t>(got.x) * got.y;
45+
const uint32_t root =
46+
static_cast<uint32_t>(std::ceil(std::sqrt(static_cast<double>(count))));
47+
EXPECT_LE(got.x, kMax) << "count=" << count;
48+
EXPECT_LE(got.y, kMax) << "count=" << count;
49+
EXPECT_GE(launched, count) << "count=" << count;
50+
EXPECT_LT(launched - count, 2ull * root)
51+
<< "count=" << count << " launched=" << launched;
52+
}
53+
}
54+
55+
// count > max^2 needs a 3rd dispatch dimension -> throws (out of scope).
56+
TEST(DispatchFold, ThrowsWhenNeeds3rdDimension) {
57+
EXPECT_ANY_THROW(fold_workgroup_count_2d(kMax * kMax + 1u, kMax, "test"));
58+
}
59+
60+
} // namespace

backends/webgpu/test/ops/test_sdpa.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ class SdpaConfig:
6161
SdpaConfig("llama1b_decode", 32, 8, 64, 1, 512, 127),
6262
# D=6 is not a multiple of 4: the WebGPU head_dim%4 guard must reject it at load.
6363
SdpaConfig("reject_d6", 4, 4, 6, 4, 16, 0),
64+
# 2D-dispatch cap (>65535 wg): S=512 folds QK; S=2048 folds QK+softmax+AV (cap+1).
65+
SdpaConfig("llama1b_prefill_512", 32, 8, 64, 512, 512, 0),
66+
SdpaConfig("llama1b_prefill_2048", 32, 8, 64, 2048, 2048, 0),
6467
]
6568

6669

backends/webgpu/test/test_webgpu_native.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -758,6 +758,18 @@ static const SdpaConfig kSdpaConfigs[] = {
758758
16.0f,
759759
/*required=*/false,
760760
/*expect_reject=*/true},
761+
// 2D-dispatch cap (>65535 wg): S=512 folds QK; S=2048 folds QK+softmax+AV
762+
// (cap+1).
763+
{"llama1b_prefill_512", 32, 8, 64, 512, 512, 0, 16.0f, /*required=*/true},
764+
{"llama1b_prefill_2048",
765+
32,
766+
8,
767+
64,
768+
2048,
769+
2048,
770+
0,
771+
16.0f,
772+
/*required=*/true},
761773
};
762774

763775
// Ramp denominator; mirror of test_sdpa.py::_RAMP_DENOM (keep in sync).

0 commit comments

Comments
 (0)