[ExecuTorch][WebGPU] 2D-fold mul + permute dispatch (lift 65535 1D cap)
Pull Request resolved: #20651
**Lift the 65535 workgroup-per-dim cap for `mul` and `permute` so they run at any numel.**
`mul.Tensor` and `permute` still used `compute_1d_workgroup_count`, which throws once `numel / wg_size > 65535`. A realistic Llama-3.2-1B LoRA layer hits this: `mul` over `[2048, 8192]` needs 262,144 workgroups, and `permute` of `[2048, 2048]` needs 65,536, one over the cap. `add`/`sub`/`div`/`fill`/`sdpa` already use the 2D fold; this brings `mul` + `permute` in line.
Key changes:
- `mul/BinaryOp.cpp`, `permute/Permute.cpp` — `compute_1d_workgroup_count` → `compute_2d_workgroup_count` (returns `utils::WgCount`); dispatch + resize hook now set both `workgroup_count_x` and `workgroup_count_y`.
- `binary_mul.wgsl`, `permute.wgsl` — `main` takes `@builtin(num_workgroups)`; flat index `gid.x + gid.y * (num_workgroups.x * wg_size)` (regenerated `*_wgsl.h`).
Mirrors the landed `add` op fold (`runtime/ops/add/{BinaryOp.cpp,binary_add.wgsl}`).
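For readers unfamiliar with the fold, here is a minimal C++ sketch of the idea, not the actual `compute_2d_workgroup_count` implementation. The names `WgCount2D`, `fold_workgroups_2d`, and `flat_index` are hypothetical, as is the particular way the count is split across `x` and `y`; the workgroup size of 64 is an assumption inferred from the workgroup counts quoted above. The only parts taken from this change are the 65535 per-dimension cap and the flat-index formula.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for utils::WgCount; the real struct may differ.
struct WgCount2D {
  uint32_t x;
  uint32_t y;
};

// Per-dimension dispatch cap this change works around.
constexpr uint32_t kMaxWgPerDim = 65535;

// Fold a 1D workgroup count into (x, y) so that neither dimension exceeds
// the cap and x * y covers every workgroup. One possible split, assumed here.
constexpr WgCount2D fold_workgroups_2d(uint64_t numel, uint32_t wg_size) {
  assert(wg_size > 0);
  const uint64_t total = (numel + wg_size - 1) / wg_size;  // ceil-div
  if (total <= kMaxWgPerDim) {
    return {static_cast<uint32_t>(total), 1u};
  }
  const uint64_t y = (total + kMaxWgPerDim - 1) / kMaxWgPerDim;
  const uint64_t x = (total + y - 1) / y;
  return {static_cast<uint32_t>(x), static_cast<uint32_t>(y)};
}

// Host-side mirror of the shader's flat index:
//   gid.x + gid.y * (num_workgroups.x * wg_size)
constexpr uint64_t flat_index(uint32_t gid_x, uint32_t gid_y,
                              uint32_t num_wg_x, uint32_t wg_size) {
  return gid_x + static_cast<uint64_t>(gid_y) *
                     (static_cast<uint64_t>(num_wg_x) * wg_size);
}

// Shapes from the description, assuming wg_size == 64.
static_assert(fold_workgroups_2d(2048ull * 8192ull, 64).x <= kMaxWgPerDim,
              "mul over [2048, 8192] fits after the fold");
static_assert(fold_workgroups_2d(2048ull * 2048ull, 64).y == 2,
              "permute of [2048, 2048] needs a second row");
```

Because `x * y` can overshoot the required workgroup count (the `mul` shape above dispatches one workgroup more than it needs under this split), a shader using this scheme would still need a bounds check of the flat index against `numel` before touching memory.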
Co-authored-with: Claude Code.
ghstack-source-id: 399812930
@exported-using-ghexport
Differential Revision: [D110149677](https://our.internmc.facebook.com/intern/diff/D110149677/)