Skip to content

Add MLX backend support for Gemma 4 31B#19524

Draft
mergennachin wants to merge 3 commits into
mainfrom
gemma4_mlx
Draft

Add MLX backend support for Gemma 4 31B#19524
mergennachin wants to merge 3 commits into
mainfrom
gemma4_mlx

Conversation

@mergennachin
Copy link
Copy Markdown
Contributor

@mergennachin mergennachin commented May 12, 2026

Adds Apple Silicon (MLX) backend for the Gemma 4 31B-IT model. The same quantized checkpoint works for both CUDA and MLX — backend-specific packing happens at load time.

Key changes:

  • MLX packer converts Int4Tensor → IntxUnpackedToInt8Tensor for MLX's quantized linear fusion
  • Source transforms replace PyTorch ops with mlx.rope, mlx.kv_cache_update, mlx.custom_sdpa for optimized Metal kernels
  • Single-method export with dynamic seq_len and host-side sampling
  • C++ runner supports both backends via #ifdef, using shared logits_to_token for MLX sampling
  • Last-logits-only optimization: lm_head always runs on last position only, removing the full-logits codepath entirely

@pytorch-bot
Copy link
Copy Markdown

pytorch-bot Bot commented May 12, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19524

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

❌ 3 Unclassified Failures

As of commit 6423b4b with merge base 3ceb89c (image):

UNCLASSIFIED FAILURES - DrCI could not classify the following jobs because the workflow did not run on the merge base. The failures may be pre-existing on trunk or introduced by this PR:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label May 12, 2026
@github-actions
Copy link
Copy Markdown

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Comment thread examples/models/gemma4_31b/main.cpp Outdated
Comment thread examples/models/gemma4_31b/quant/pack_mlx.py Outdated
Comment thread examples/models/gemma4_31b/quant/pack_mlx.py
mergennachin and others added 2 commits May 18, 2026 08:28
- pack_mlx.py: converts Int4Tensor → IntxUnpackedToInt8Tensor at pack
  time (nibble unpack + scale transpose) so the default dispatch
  produces the dequantize_affine → linear pattern MLX expects.
  IntxUnpackedToInt8Tensor passes through unchanged. Embedding with
  incompatible per-axis group_size is regrouped to gs=128.
- export.py: add --backend mlx with single-method export (dynamic
  seq_len), sampler stripping, and MLXPartitioner lowering. No
  int4_dispatch import — MLX uses the standard dequantize_affine path.
- main.cpp: handle both CUDA (prefill+decode, on-device sampling) and
  MLX (single forward method, host-side argmax) via #ifdef.
- CMakeLists.txt / CMakePresets.json / Makefile: add gemma4_31b-mlx
  build target linking mlxdelegate.
- test_pack_mlx.py: 15 tests covering Int4→IntxUnpacked conversion
  correctness, passthrough, regrouping, error cases.
- test_mlx_pipeline.py: 4 e2e tests including export-to-pte.

Validated: same CUDA-quantized checkpoint packs for both backends,
100% op delegation to MLX, real 31B checkpoint packs at 4.0 GB RSS.

PR authored with Claude.
…LX packer

- Replace custom argmax_last_token with llm::logits_to_token for
  host-side sampling on non-CUDA builds, matching qwen3_5_moe runner.
  Supports temperature-controlled sampling (was greedy-only).
- Add --cuda_graph warning on non-CUDA builds.
- Support Int4Tensor embeddings in pack_embedding_for_mlx by converting
  to IntxUnpackedToInt8Tensor (same as linear path).
- Add divisibility guard in _regroup_intx.

Co-authored-by: Claude <noreply@anthropic.com>
… temp handling

- Always compute logits for the last position only (lm_head on x[:, -1, :]),
  avoiding the (1, T, 262144) matmul during prefill. Applies to both CUDA
  and MLX paths.
- Remove the temperature=None codepath from model.py forward and sampler.py.
  Temperature is now always required. MLX _strip_sampler_from_forward handles
  the no-sampler case independently.
- Add mlx_source_transformations.py: replaces generic PyTorch ops with
  mlx.rope, mlx.kv_cache_update, and mlx.custom_sdpa for optimized Metal
  kernels. Applied during MLX export before torch.export.
- Unify temperature clamping in main.cpp: compute temp_val once before the
  #ifdef, used by both CUDA (temp_tensor) and MLX (logits_to_token).
- Fix generate() default temperature to 0.8 (was 0.0, inconsistent with C++).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants