[ExecuTorch][WebGPU] 2D-fold mul + permute dispatch (lift 65535 1D cap)#20651
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20651
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New FailureAs of commit 91b2d97 with merge base 124b677 ( NEW FAILURE - The following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This PR needs a
|
|
@claude Review thoroughly. Check for stale code / bugs. |
|
Claude finished @JulianCloudNTH's task in 2m 47s —— View job Review: WebGPU 2D-fold for
|
1055559
into
gh/JulianCloudNTH/83/base
Pull Request resolved: #20651 **Lift the 65535 workgroup-per-dim cap for `mul` and `permute` so they run at any numel.** `mul.Tensor` and `permute` still used `compute_1d_workgroup_count`, which throws once `numel / wg_size > 65535` — hit by a realistic Llama-3.2-1B LoRA layer (`mul` over `[2048, 8192]` = 262k workgroups; `permute` of `[2048, 2048]` = 65536). `add`/`sub`/`div`/`fill`/`sdpa` already use the 2D fold; this brings `mul` + `permute` in line. Key changes: - `mul/BinaryOp.cpp`, `permute/Permute.cpp` — `compute_1d_workgroup_count` → `compute_2d_workgroup_count` (returns `utils::WgCount`); dispatch + resize hook now set both `workgroup_count_x` and `workgroup_count_y`. - `binary_mul.wgsl`, `permute.wgsl` — `main` takes `@builtin(num_workgroups)`; flat index `gid.x + gid.y * (num_workgroups.x * wg_size)` (regenerated `*_wgsl.h`). Mirrors the landed `add` op fold (`runtime/ops/add/{BinaryOp.cpp,binary_add.wgsl}`). Co-authored-with: Claude Code. ghstack-source-id: 399812930 @exported-using-ghexport Differential Revision: [D110149677](https://our.internmc.facebook.com/intern/diff/D110149677/)
Pull Request resolved: #20651 **Lift the 65535 workgroup-per-dim cap for `mul` and `permute` so they run at any numel.** `mul.Tensor` and `permute` still used `compute_1d_workgroup_count`, which throws once `numel / wg_size > 65535` — hit by a realistic Llama-3.2-1B LoRA layer (`mul` over `[2048, 8192]` = 262k workgroups; `permute` of `[2048, 2048]` = 65536). `add`/`sub`/`div`/`fill`/`sdpa` already use the 2D fold; this brings `mul` + `permute` in line. Key changes: - `mul/BinaryOp.cpp`, `permute/Permute.cpp` — `compute_1d_workgroup_count` → `compute_2d_workgroup_count` (returns `utils::WgCount`); dispatch + resize hook now set both `workgroup_count_x` and `workgroup_count_y`. - `binary_mul.wgsl`, `permute.wgsl` — `main` takes `@builtin(num_workgroups)`; flat index `gid.x + gid.y * (num_workgroups.x * wg_size)` (regenerated `*_wgsl.h`). Mirrors the landed `add` op fold (`runtime/ops/add/{BinaryOp.cpp,binary_add.wgsl}`). Co-authored-with: Claude Code. ghstack-source-id: 399812930 @exported-using-ghexport Differential Revision: [D110149677](https://our.internmc.facebook.com/intern/diff/D110149677/)
Pull Request resolved: #20651 **Lift the 65535 workgroup-per-dim cap for `mul` and `permute` so they run at any numel.** `mul.Tensor` and `permute` still used `compute_1d_workgroup_count`, which throws once `numel / wg_size > 65535` — hit by a realistic Llama-3.2-1B LoRA layer (`mul` over `[2048, 8192]` = 262k workgroups; `permute` of `[2048, 2048]` = 65536). `add`/`sub`/`div`/`fill`/`sdpa` already use the 2D fold; this brings `mul` + `permute` in line. Key changes: - `mul/BinaryOp.cpp`, `permute/Permute.cpp` — `compute_1d_workgroup_count` → `compute_2d_workgroup_count` (returns `utils::WgCount`); dispatch + resize hook now set both `workgroup_count_x` and `workgroup_count_y`. - `binary_mul.wgsl`, `permute.wgsl` — `main` takes `@builtin(num_workgroups)`; flat index `gid.x + gid.y * (num_workgroups.x * wg_size)` (regenerated `*_wgsl.h`). Mirrors the landed `add` op fold (`runtime/ops/add/{BinaryOp.cpp,binary_add.wgsl}`). Co-authored-with: Claude Code. ghstack-source-id: 399812930 @exported-using-ghexport Differential Revision: [D110149677](https://our.internmc.facebook.com/intern/diff/D110149677/)
Pull Request resolved: #20651 **Lift the 65535 workgroup-per-dim cap for `mul` and `permute` so they run at any numel.** `mul.Tensor` and `permute` still used `compute_1d_workgroup_count`, which throws once `numel / wg_size > 65535` — hit by a realistic Llama-3.2-1B LoRA layer (`mul` over `[2048, 8192]` = 262k workgroups; `permute` of `[2048, 2048]` = 65536). `add`/`sub`/`div`/`fill`/`sdpa` already use the 2D fold; this brings `mul` + `permute` in line. Key changes: - `mul/BinaryOp.cpp`, `permute/Permute.cpp` — `compute_1d_workgroup_count` → `compute_2d_workgroup_count` (returns `utils::WgCount`); dispatch + resize hook now set both `workgroup_count_x` and `workgroup_count_y`. - `binary_mul.wgsl`, `permute.wgsl` — `main` takes `@builtin(num_workgroups)`; flat index `gid.x + gid.y * (num_workgroups.x * wg_size)` (regenerated `*_wgsl.h`). Mirrors the landed `add` op fold (`runtime/ops/add/{BinaryOp.cpp,binary_add.wgsl}`). Co-authored-with: Claude Code. ghstack-source-id: 399812930 @exported-using-ghexport Differential Revision: [D110149677](https://our.internmc.facebook.com/intern/diff/D110149677/)
Pull Request resolved: #20651 **Lift the 65535 workgroup-per-dim cap for `mul` and `permute` so they run at any numel.** `mul.Tensor` and `permute` still used `compute_1d_workgroup_count`, which throws once `numel / wg_size > 65535` — hit by a realistic Llama-3.2-1B LoRA layer (`mul` over `[2048, 8192]` = 262k workgroups; `permute` of `[2048, 2048]` = 65536). `add`/`sub`/`div`/`fill`/`sdpa` already use the 2D fold; this brings `mul` + `permute` in line. Key changes: - `mul/BinaryOp.cpp`, `permute/Permute.cpp` — `compute_1d_workgroup_count` → `compute_2d_workgroup_count` (returns `utils::WgCount`); dispatch + resize hook now set both `workgroup_count_x` and `workgroup_count_y`. - `binary_mul.wgsl`, `permute.wgsl` — `main` takes `@builtin(num_workgroups)`; flat index `gid.x + gid.y * (num_workgroups.x * wg_size)` (regenerated `*_wgsl.h`). Mirrors the landed `add` op fold (`runtime/ops/add/{BinaryOp.cpp,binary_add.wgsl}`). Co-authored-with: Claude Code. ghstack-source-id: 399812930 @exported-using-ghexport Differential Revision: [D110149677](https://our.internmc.facebook.com/intern/diff/D110149677/)
Pull Request resolved: #20651 **Lift the 65535 workgroup-per-dim cap for `mul` and `permute` so they run at any numel.** `mul.Tensor` and `permute` still used `compute_1d_workgroup_count`, which throws once `numel / wg_size > 65535` — hit by a realistic Llama-3.2-1B LoRA layer (`mul` over `[2048, 8192]` = 262k workgroups; `permute` of `[2048, 2048]` = 65536). `add`/`sub`/`div`/`fill`/`sdpa` already use the 2D fold; this brings `mul` + `permute` in line. Key changes: - `mul/BinaryOp.cpp`, `permute/Permute.cpp` — `compute_1d_workgroup_count` → `compute_2d_workgroup_count` (returns `utils::WgCount`); dispatch + resize hook now set both `workgroup_count_x` and `workgroup_count_y`. - `binary_mul.wgsl`, `permute.wgsl` — `main` takes `@builtin(num_workgroups)`; flat index `gid.x + gid.y * (num_workgroups.x * wg_size)` (regenerated `*_wgsl.h`). Mirrors the landed `add` op fold (`runtime/ops/add/{BinaryOp.cpp,binary_add.wgsl}`). Co-authored-with: Claude Code. ghstack-source-id: 399812930 @exported-using-ghexport Differential Revision: [D110149677](https://our.internmc.facebook.com/intern/diff/D110149677/)
Pull Request resolved: #20651 **Lift the 65535 workgroup-per-dim cap for `mul` and `permute` so they run at any numel.** `mul.Tensor` and `permute` still used `compute_1d_workgroup_count`, which throws once `numel / wg_size > 65535` — hit by a realistic Llama-3.2-1B LoRA layer (`mul` over `[2048, 8192]` = 262k workgroups; `permute` of `[2048, 2048]` = 65536). `add`/`sub`/`div`/`fill`/`sdpa` already use the 2D fold; this brings `mul` + `permute` in line. Key changes: - `mul/BinaryOp.cpp`, `permute/Permute.cpp` — `compute_1d_workgroup_count` → `compute_2d_workgroup_count` (returns `utils::WgCount`); dispatch + resize hook now set both `workgroup_count_x` and `workgroup_count_y`. - `binary_mul.wgsl`, `permute.wgsl` — `main` takes `@builtin(num_workgroups)`; flat index `gid.x + gid.y * (num_workgroups.x * wg_size)` (regenerated `*_wgsl.h`). Mirrors the landed `add` op fold (`runtime/ops/add/{BinaryOp.cpp,binary_add.wgsl}`). Co-authored-with: Claude Code. ghstack-source-id: 399812930 @exported-using-ghexport Differential Revision: [D110149677](https://our.internmc.facebook.com/intern/diff/D110149677/)
Pull Request resolved: #20651 **Lift the 65535 workgroup-per-dim cap for `mul` and `permute` so they run at any numel.** `mul.Tensor` and `permute` still used `compute_1d_workgroup_count`, which throws once `numel / wg_size > 65535` — hit by a realistic Llama-3.2-1B LoRA layer (`mul` over `[2048, 8192]` = 262k workgroups; `permute` of `[2048, 2048]` = 65536). `add`/`sub`/`div`/`fill`/`sdpa` already use the 2D fold; this brings `mul` + `permute` in line. Key changes: - `mul/BinaryOp.cpp`, `permute/Permute.cpp` — `compute_1d_workgroup_count` → `compute_2d_workgroup_count` (returns `utils::WgCount`); dispatch + resize hook now set both `workgroup_count_x` and `workgroup_count_y`. - `binary_mul.wgsl`, `permute.wgsl` — `main` takes `@builtin(num_workgroups)`; flat index `gid.x + gid.y * (num_workgroups.x * wg_size)` (regenerated `*_wgsl.h`). Mirrors the landed `add` op fold (`runtime/ops/add/{BinaryOp.cpp,binary_add.wgsl}`). Co-authored-with: Claude Code. ghstack-source-id: 399812930 @exported-using-ghexport Differential Revision: [D110149677](https://our.internmc.facebook.com/intern/diff/D110149677/)
Pull Request resolved: #20651 **Lift the 65535 workgroup-per-dim cap for `mul` and `permute` so they run at any numel.** `mul.Tensor` and `permute` still used `compute_1d_workgroup_count`, which throws once `numel / wg_size > 65535` — hit by a realistic Llama-3.2-1B LoRA layer (`mul` over `[2048, 8192]` = 262k workgroups; `permute` of `[2048, 2048]` = 65536). `add`/`sub`/`div`/`fill`/`sdpa` already use the 2D fold; this brings `mul` + `permute` in line. Key changes: - `mul/BinaryOp.cpp`, `permute/Permute.cpp` — `compute_1d_workgroup_count` → `compute_2d_workgroup_count` (returns `utils::WgCount`); dispatch + resize hook now set both `workgroup_count_x` and `workgroup_count_y`. - `binary_mul.wgsl`, `permute.wgsl` — `main` takes `@builtin(num_workgroups)`; flat index `gid.x + gid.y * (num_workgroups.x * wg_size)` (regenerated `*_wgsl.h`). Mirrors the landed `add` op fold (`runtime/ops/add/{BinaryOp.cpp,binary_add.wgsl}`). Co-authored-with: Claude Code. ghstack-source-id: 399812930 @exported-using-ghexport Differential Revision: [D110149677](https://our.internmc.facebook.com/intern/diff/D110149677/)
Pull Request resolved: #20651 **Lift the 65535 workgroup-per-dim cap for `mul` and `permute` so they run at any numel.** `mul.Tensor` and `permute` still used `compute_1d_workgroup_count`, which throws once `numel / wg_size > 65535` — hit by a realistic Llama-3.2-1B LoRA layer (`mul` over `[2048, 8192]` = 262k workgroups; `permute` of `[2048, 2048]` = 65536). `add`/`sub`/`div`/`fill`/`sdpa` already use the 2D fold; this brings `mul` + `permute` in line. Key changes: - `mul/BinaryOp.cpp`, `permute/Permute.cpp` — `compute_1d_workgroup_count` → `compute_2d_workgroup_count` (returns `utils::WgCount`); dispatch + resize hook now set both `workgroup_count_x` and `workgroup_count_y`. - `binary_mul.wgsl`, `permute.wgsl` — `main` takes `@builtin(num_workgroups)`; flat index `gid.x + gid.y * (num_workgroups.x * wg_size)` (regenerated `*_wgsl.h`). Mirrors the landed `add` op fold (`runtime/ops/add/{BinaryOp.cpp,binary_add.wgsl}`). Co-authored-with: Claude Code. ghstack-source-id: 399812930 @exported-using-ghexport Differential Revision: [D110149677](https://our.internmc.facebook.com/intern/diff/D110149677/)
Pull Request resolved: #20651 **Lift the 65535 workgroup-per-dim cap for `mul` and `permute` so they run at any numel.** `mul.Tensor` and `permute` still used `compute_1d_workgroup_count`, which throws once `numel / wg_size > 65535` — hit by a realistic Llama-3.2-1B LoRA layer (`mul` over `[2048, 8192]` = 262k workgroups; `permute` of `[2048, 2048]` = 65536). `add`/`sub`/`div`/`fill`/`sdpa` already use the 2D fold; this brings `mul` + `permute` in line. Key changes: - `mul/BinaryOp.cpp`, `permute/Permute.cpp` — `compute_1d_workgroup_count` → `compute_2d_workgroup_count` (returns `utils::WgCount`); dispatch + resize hook now set both `workgroup_count_x` and `workgroup_count_y`. - `binary_mul.wgsl`, `permute.wgsl` — `main` takes `@builtin(num_workgroups)`; flat index `gid.x + gid.y * (num_workgroups.x * wg_size)` (regenerated `*_wgsl.h`). Mirrors the landed `add` op fold (`runtime/ops/add/{BinaryOp.cpp,binary_add.wgsl}`). Co-authored-with: Claude Code. ghstack-source-id: 399812930 @exported-using-ghexport Differential Revision: [D110149677](https://our.internmc.facebook.com/intern/diff/D110149677/)
Pull Request resolved: #20651 **Lift the 65535 workgroup-per-dim cap for `mul` and `permute` so they run at any numel.** `mul.Tensor` and `permute` still used `compute_1d_workgroup_count`, which throws once `numel / wg_size > 65535` — hit by a realistic Llama-3.2-1B LoRA layer (`mul` over `[2048, 8192]` = 262k workgroups; `permute` of `[2048, 2048]` = 65536). `add`/`sub`/`div`/`fill`/`sdpa` already use the 2D fold; this brings `mul` + `permute` in line. Key changes: - `mul/BinaryOp.cpp`, `permute/Permute.cpp` — `compute_1d_workgroup_count` → `compute_2d_workgroup_count` (returns `utils::WgCount`); dispatch + resize hook now set both `workgroup_count_x` and `workgroup_count_y`. - `binary_mul.wgsl`, `permute.wgsl` — `main` takes `@builtin(num_workgroups)`; flat index `gid.x + gid.y * (num_workgroups.x * wg_size)` (regenerated `*_wgsl.h`). Mirrors the landed `add` op fold (`runtime/ops/add/{BinaryOp.cpp,binary_add.wgsl}`). Co-authored-with: Claude Code. ghstack-source-id: 399812930 @exported-using-ghexport Differential Revision: [D110149677](https://our.internmc.facebook.com/intern/diff/D110149677/)
Stack from ghstack (oldest at bottom):
Lift the 65535 workgroup-per-dim cap for
mulandpermuteso they run at any numel.mul.Tensorandpermutestill usedcompute_1d_workgroup_count, which throws oncenumel / wg_size > 65535— hit by a realistic Llama-3.2-1B LoRA layer (mulover[2048, 8192]= 262k workgroups;permuteof[2048, 2048]= 65536).add/sub/div/fill/sdpaalready use the 2D fold; this bringsmul+permutein line.Key changes:
mul/BinaryOp.cpp,permute/Permute.cpp—compute_1d_workgroup_count→compute_2d_workgroup_count(returnsutils::WgCount); dispatch + resize hook now set bothworkgroup_count_xandworkgroup_count_y.binary_mul.wgsl,permute.wgsl—maintakes@builtin(num_workgroups); flat indexgid.x + gid.y * (num_workgroups.x * wg_size)(regenerated*_wgsl.h).Mirrors the landed
addop fold (runtime/ops/add/{BinaryOp.cpp,binary_add.wgsl}).Co-authored-with: Claude Code.
@exported-using-ghexport
Differential Revision: D110149677
Differential Revision: D110149677