Skip to content

[ExecuTorch][WebGPU] 2D compute dispatch — lift the 65535 per-dim cap (prefill path)#20722

Merged
5 commits merged into
gh/JulianCloudNTH/74/origfrom
gh/JulianCloudNTH/75/orig
Jul 4, 2026
Merged

[ExecuTorch][WebGPU] 2D compute dispatch — lift the 65535 per-dim cap (prefill path)#20722
5 commits merged into
gh/JulianCloudNTH/74/origfrom
gh/JulianCloudNTH/75/orig

Conversation

@pytorchbot

Copy link
Copy Markdown
Collaborator

This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: #20583 by @JulianCloudNTH
^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/75/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/75/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/74/orig
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/75/orig

@diff-train-skip-merge

… (prefill path)

Pull Request resolved: #20583

**Lift the 65535 workgroup-per-dim dispatch cap so single-shot SDPA prefill runs at any sequence length.**

**Problem**: The WebGPU backend is 1D-dispatch-only and throws when a kernel's workgroup count exceeds the device per-dim limit (`maxComputeWorkgroupsPerDimension`, spec floor 65535). SDPA prefill QK exceeds it around S~362 (softmax/AV at S=2048), blocking single-shot / long-context prefill.

**Solution**: Fold a >limit 1D workgroup count into 2D; the shader reconstructs the linear index from `@builtin(num_workgroups)`.
- **Before**: `compute_1d_workgroup_count` throws if `count > limit`; dispatch `(count, 1, 1)`.
- **After**: `compute_2d_workgroup_count` returns `{count, 1}` (fast path) or a near-square `{x, y}` (`x = ceil(sqrt(count))` clamped to `limit`, `y = div_up(count, x)`); dispatch `(x, y, 1)`. A flat `{limit, div_up(count, limit)}` split would idle up to ~half the launched workgroups when `count` just exceeds `limit`; the near-square split holds the waste to `O(sqrt(count))` (e.g. 65536 -> `{256, 256}`, 0 inactive).

**Implementation**:
- `WgCount` + pure `fold_workgroup_count_2d` + `compute_2d_workgroup_count` in `WebGPUUtils.h` (device-free, unit-testable; `queried_max_workgroups` factored out of the 1D path)
- `WebGPUDispatch.workgroup_count_y` (default 1, declared last so existing aggregate inits are unchanged); both `dispatchWorkgroups` calls + the profiling record pass `(x, y, 1)`
- Per-kernel in-shader reconstruction: thread-form `idx = gid.x + gid.y*(num_workgroups.x*wg_size)` (QK/AV/add); row-form `row_idx = wid.x + wid.y*num_workgroups.x` (softmax — keeps a `valid` predicate, not an early return, so `workgroupBarrier()`s stay uniform)
- `Sdpa.cpp`: QK/softmax/AV counts via the 2D helper; the dynamic-`input_pos` resize hook recomputes both x and y for QK
- Reference: ET-Vulkan dispatches over natural N-D extents (never folds a flat count nor guards the per-dim limit) and MLX `get_2d_grid_dims` packs whole tensor dims; for our flattened scalar count the near-square split is the correct no-shape-info analog (a pack-to-limit split would reproduce the idle-half waste)

**Constraints**:
- `y=1` fast path keeps every non-folded dispatch byte-identical to the prior 1D path
- Scope = prefill path only; `rms_norm`/`embedding`/`lm_head`/`update_cache` are row/token-indexed and never hit the cap, so they keep the 1D path
- Throws if a 3rd dispatch dimension would be needed — unreachable for real prefill (the `uint32` element guard fires first at S~11585)

Co-authored-with: Claude Code.
ghstack-source-id: 399812920
@exported-using-ghexport

Differential Revision: [D109517684](https://our.internmc.facebook.com/intern/diff/D109517684/)
@pytorch-bot

pytorch-bot Bot commented Jul 4, 2026

Copy link
Copy Markdown

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20722

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jul 4, 2026
JCNTH added 4 commits July 4, 2026 10:33
…d unit test

Pull Request resolved: #20584

**Test coverage for the 2D dispatch fold, stacked above the cap-lift op.**

**Problem**: The 2D fold is load-bearing index math — a wrong `{x, y}` means out-of-bounds writes or dropped threads — and the prefill shapes that exercise it previously threw at the 1D cap, so they were untested.

**Solution**: A device-free unit test for the fold arithmetic, plus two single-shot prefill SDPA golden configs that fold each kernel family.
- **Before**: no coverage for >65535-workgroup dispatch; `llama1b_prefill_512`/`_2048` shapes threw at the cap
- **After**: `fold_workgroup_count_2d` unit-tested at the cap boundaries, and the two prefill shapes run as goldens

**Implementation**:
- `test/native/test_dispatch_2d.cpp` — device-free unit test for `utils::fold_workgroup_count_2d`: the 1D fast path, the 2D fold, the real Llama-1B QK counts at S=512 (`{65535, 3}`) and S=2048 (`{65535, 33}`), and the needs-3rd-dimension throw; asserts each `{x, y}` covers `[0, count)`
- `llama1b_prefill_512` + `llama1b_prefill_2048` configs appended to the byte-mirrored `CONFIGS` (`test_sdpa.py`) and `kSdpaConfigs` (`test_webgpu_native.cpp`)
- Registers `webgpu_dispatch_2d_test` in CMake + the native CI script

**Constraints**:
- The Python/C++ config entries byte-mirror each other (kept in sync)
- `add` shares the element-form path with QK, so it is covered structurally; a dedicated >16M-element `add` fold case is omitted as disproportionate

Co-authored-with: Claude Code.
ghstack-source-id: 399812923
@exported-using-ghexport

Differential Revision: [D109517683](https://our.internmc.facebook.com/intern/diff/D109517683/)
Pull Request resolved: #20651

**Lift the 65535 workgroup-per-dim cap for `mul` and `permute` so they run at any numel.**

`mul.Tensor` and `permute` still used `compute_1d_workgroup_count`, which throws once `numel / wg_size > 65535` — hit by a realistic Llama-3.2-1B LoRA layer (`mul` over `[2048, 8192]` = 262k workgroups; `permute` of `[2048, 2048]` = 65536). `add`/`sub`/`div`/`fill`/`sdpa` already use the 2D fold; this brings `mul` + `permute` in line.

Key changes:
- `mul/BinaryOp.cpp`, `permute/Permute.cpp` — `compute_1d_workgroup_count` → `compute_2d_workgroup_count` (returns `utils::WgCount`); dispatch + resize hook now set both `workgroup_count_x` and `workgroup_count_y`.
- `binary_mul.wgsl`, `permute.wgsl` — `main` takes `@builtin(num_workgroups)`; flat index `gid.x + gid.y * (num_workgroups.x * wg_size)` (regenerated `*_wgsl.h`).

Mirrors the landed `add` op fold (`runtime/ops/add/{BinaryOp.cpp,binary_add.wgsl}`).

Co-authored-with: Claude Code.
ghstack-source-id: 399812930
@exported-using-ghexport

Differential Revision: [D110149677](https://our.internmc.facebook.com/intern/diff/D110149677/)
…scripten Dawn

Pull Request resolved: #20652

**Key the `timedWaitAny` instance setup to the actual Dawn API instead of `__EMSCRIPTEN__`, so native-rig Dawn and emscripten/emdawnwebgpu use the modern `requiredFeatures` path and only the vendored Dawn uses the legacy `capabilities.*` path.**

The instance-descriptor setup was guarded by `#if defined(__EMSCRIPTEN__)`, which routed emscripten (emdawnwebgpu, emcc 4.0.19+) through the legacy `capabilities.*` API that no longer exists there. The guard now keys off the API actually present.

Key changes:
- `WebGPUDevice.cpp` — `#if defined(__EMSCRIPTEN__)` → `#if defined(WEBGPU_DAWN_INSTANCE_CAPABILITIES)`. The legacy `instance_desc.capabilities.*` path is taken only by the buck-vendored Dawn (which defines the macro); native cmake Dawn and emscripten leave it undefined and take the `requiredFeatures` / `WGPUInstanceFeatureName_TimedWaitAny` path.

Co-authored-with: Claude Code.
ghstack-source-id: 399812934
@exported-using-ghexport

Differential Revision: [D110149678](https://our.internmc.facebook.com/intern/diff/D110149678/)
Pull Request resolved: #20706

Convert the remaining hand-rolled `int main()` + printf/`bool ok` native tests to GTest so the whole `backends/webgpu/test/` suite is uniform, filterable via `--gtest_filter`, and self-reporting (extends the GTest conversion already applied to `test_dynamic_shape`). The five converted files are a harness-only change — every test case, tensor shape, tolerance, artifact filename, and skip condition is preserved 1:1, only the pass/fail reporting mechanism changes — and this diff additionally wires the already-GTest `webgpu_dynamic_shape_test` into the CI runner so the dynamic-shape suite actually executes.

Key changes:
- `test/test_webgpu_native.cpp`, `test/native/test_dispatch_order.cpp`, `test/native/test_index.cpp`, `test/native/test_scratch_buffer.cpp`, `test/native/test_update_cache.cpp` — `main`+`printf`/`bool ok` accumulator → `TEST()` cases using `EXPECT_*`/`ASSERT_*`; each keeps a custom `main()` that brings up the WebGPU device once then `RUN_ALL_TESTS()` (device-absent still SKIPs by returning 0). `test_index`/`test_webgpu_native` use inclusive `EXPECT_LE(err, tol)` to match the original `err > tol` fail gate exactly.
- `CMakeLists.txt` — move every native-test target into the `if(TARGET GTest::gtest)` block, linking `GTest::gtest`.
- `scripts/test_webgpu_native_ci.sh` — add `-DEXECUTORCH_BUILD_TESTS=ON` to the native-test configure so the now-gtest-gated targets are defined, and wire `webgpu_dynamic_shape_test` into the runner: export its `.pte`s + goldens via `export_dynamic_shape_cases`, add it to the built/run target list behind the same `--target help` probe, and run it guarded (mirroring the `index` test).
- `test/test_build_webgpu.sh` — add `-DEXECUTORCH_BUILD_TESTS=ON` so the local build script (which builds the now-gtest-gated targets unconditionally) still finds them.
ghstack-source-id: 399812941
@exported-using-ghexport

Differential Revision: [D110536636](https://our.internmc.facebook.com/intern/diff/D110536636/)
@ghost ghost requested review from kirklandsign and larryliu0820 as code owners July 4, 2026 17:33
@ghost ghost merged commit f47f1ff into gh/JulianCloudNTH/74/orig Jul 4, 2026
14 of 15 checks passed
@ghost ghost deleted the gh/JulianCloudNTH/75/orig branch July 4, 2026 17:33
@ghost ghost temporarily deployed to cadence July 4, 2026 17:33 — with GitHub Actions Inactive
@ghost ghost temporarily deployed to cadence July 4, 2026 17:33 — with GitHub Actions Inactive
@ghost ghost temporarily deployed to cadence July 4, 2026 18:02 — with GitHub Actions Inactive
This pull request was closed.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants