Add MLX backend support for Gemma 4 31B#19524
- pack_mlx.py: converts Int4Tensor → IntxUnpackedToInt8Tensor at pack time (nibble unpack + scale transpose) so the default dispatch produces the dequantize_affine → linear pattern MLX expects. IntxUnpackedToInt8Tensor passes through unchanged. Embedding with incompatible per-axis group_size is regrouped to gs=128.
- export.py: add --backend mlx with single-method export (dynamic seq_len), sampler stripping, and MLXPartitioner lowering. No int4_dispatch import; MLX uses the standard dequantize_affine path.
- main.cpp: handle both CUDA (prefill+decode, on-device sampling) and MLX (single forward method, host-side argmax) via #ifdef.
- CMakeLists.txt / CMakePresets.json / Makefile: add gemma4_31b-mlx build target linking mlxdelegate.
- test_pack_mlx.py: 15 tests covering Int4→IntxUnpacked conversion correctness, passthrough, regrouping, error cases.
- test_mlx_pipeline.py: 4 e2e tests including export-to-pte.

Validated: same CUDA-quantized checkpoint packs for both backends, 100% op delegation to MLX, real 31B checkpoint packs at 4.0 GB RSS.

PR authored with Claude.
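To illustrate the "nibble unpack + scale transpose" step in pack_mlx.py, here is a minimal pure-Python sketch. It is not the code from this PR. The function names are hypothetical, and the layout (two signed 4-bit values per byte, low nibble first, two's complement) is an assumption; the actual Int4Tensor layout may differ.

```python
def unpack_int4(packed: bytes) -> list[int]:
    """Unpack two signed 4-bit values from each byte.

    Assumed layout (not confirmed by the PR): low nibble first,
    two's-complement encoding, so each value lands in [-8, 7].
    """
    out = []
    for byte in packed:
        for nibble in (byte & 0x0F, byte >> 4):
            # Bit patterns 8..15 represent the negative values -8..-1.
            out.append(nibble - 16 if nibble >= 8 else nibble)
    return out


def transpose(rows: list[list[float]]) -> list[list[float]]:
    """Swap the two axes of a 2-D scale table (hypothetical helper)."""
    return [list(col) for col in zip(*rows)]


if __name__ == "__main__":
    print(unpack_int4(bytes([0x21, 0xF8])))          # [1, 2, -8, -1]
    print(transpose([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]))
```

Each unpacked value fits in an int8, which is why the converted tensor can feed the standard dequantize_affine → linear pattern without a custom int4 dispatch.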
…LX packer

- Replace custom argmax_last_token with llm::logits_to_token for host-side sampling on non-CUDA builds, matching qwen3_5_moe runner. Supports temperature-controlled sampling (was greedy-only).
- Add --cuda_graph warning on non-CUDA builds.
- Support Int4Tensor embeddings in pack_embedding_for_mlx by converting to IntxUnpackedToInt8Tensor (same as linear path).
- Add divisibility guard in _regroup_intx.

Co-authored-by: Claude <noreply@anthropic.com>
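The regrouping and its divisibility guard can be sketched as follows. This is a hypothetical illustration, not the body of _regroup_intx: it assumes regrouping means splitting each coarse quantization group into finer groups that inherit the parent's scale, which only works when the old group size is a multiple of the new one.

```python
def regroup_scales(scales: list[float], old_gs: int, new_gs: int = 128) -> list[float]:
    """Re-express per-group scales at a finer group size.

    Hypothetical helper: each old group of `old_gs` elements is split into
    `old_gs // new_gs` groups, all sharing the original scale.
    """
    # Divisibility guard: a coarse group must split into whole finer groups.
    if old_gs <= 0 or new_gs <= 0 or old_gs % new_gs != 0:
        raise ValueError(
            f"group_size {old_gs} is not a positive multiple of {new_gs}"
        )
    repeat = old_gs // new_gs
    # Repeat every scale once per finer group carved out of its old group.
    return [s for s in scales for _ in range(repeat)]


if __name__ == "__main__":
    print(regroup_scales([0.5, 2.0], old_gs=256))  # [0.5, 0.5, 2.0, 2.0]
```

Because every finer group reuses its parent's scale, the dequantized values stay the same under this assumption; only the scale layout changes.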
… temp handling

- Always compute logits for the last position only (lm_head on x[:, -1, :]), avoiding the (1, T, 262144) matmul during prefill. Applies to both CUDA and MLX paths.
- Remove the temperature=None codepath from model.py forward and sampler.py. Temperature is now always required. MLX _strip_sampler_from_forward handles the no-sampler case independently.
- Add mlx_source_transformations.py: replaces generic PyTorch ops with mlx.rope, mlx.kv_cache_update, and mlx.custom_sdpa for optimized Metal kernels. Applied during MLX export before torch.export.
- Unify temperature clamping in main.cpp: compute temp_val once before the #ifdef, used by both CUDA (temp_tensor) and MLX (logits_to_token).
- Fix generate() default temperature to 0.8 (was 0.0, inconsistent with C++).
Adds Apple Silicon (MLX) backend for the Gemma 4 31B-IT model. The same quantized checkpoint works for both CUDA and MLX — backend-specific packing happens at load time.
Key changes: