2 changes: 1 addition & 1 deletion .github/workflows/cuda.yml
@@ -150,7 +150,7 @@ jobs:

# Run Gemma 4 31B tests (quant unit tests + pipeline integration tests)
pip install gguf
-python -m pytest examples/models/gemma4_31b/quant/tests/ examples/models/gemma4_31b/tests/ -v -o "addopts="
+python -m pytest examples/models/gemma4_31b/quant/tests/ examples/models/gemma4_31b/tests/ --ignore=examples/models/gemma4_31b/tests/test_mlx_pipeline.py -v -o "addopts="

export-model-cuda-artifact:
name: export-model-cuda-artifact
1 change: 1 addition & 0 deletions .github/workflows/mlx.yml
@@ -60,6 +60,7 @@ jobs:
backends/mlx/test/test_passes.py \
backends/mlx/test/test_pattern_utils.py \
backends/mlx/test/test_partitioner.py \
examples/models/gemma4_31b/tests/test_mlx_pipeline.py \
-v
echo "::endgroup::"

12 changes: 11 additions & 1 deletion Makefile
@@ -91,7 +91,7 @@
#
# ==============================================================================

-.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral-mlx voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal voxtral_realtime-mlx voxtral_tts-cpu voxtral_tts-cuda whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal parakeet-mlx parakeet-vulkan dinov2-cuda dinov2-cuda-debug sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu gemma4_31b-cuda qwen3_5_moe-cuda qwen3_5_moe-metal clean help
+.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral-mlx voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal voxtral_realtime-mlx voxtral_tts-cpu voxtral_tts-cuda whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal parakeet-mlx parakeet-vulkan dinov2-cuda dinov2-cuda-debug sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu gemma4_31b-cuda gemma4_31b-mlx qwen3_5_moe-cuda qwen3_5_moe-metal clean help

help:
@echo "This Makefile adds targets to build runners for various models on various backends. Run using \`make <target>\`. Available targets:"
@@ -127,6 +127,7 @@ help:
@echo " gemma3-cuda - Build Gemma3 runner with CUDA backend"
@echo " gemma3-cpu - Build Gemma3 runner with CPU backend"
@echo " gemma4_31b-cuda - Build Gemma 4 31B runner with CUDA backend"
@echo " gemma4_31b-mlx - Build Gemma 4 31B runner with MLX backend"
@echo " qwen3_5_moe-cuda - Build Qwen3.5 MoE runner with CUDA backend"
@echo " qwen3_5_moe-metal - Build Qwen3.5 MoE runner with Metal backend"
@echo " clean - Clean build artifacts"
@@ -435,6 +436,15 @@ gemma4_31b-cuda:
@echo "✓ Build complete!"
@echo " Binary: cmake-out/examples/models/gemma4_31b/gemma4_31b_runner"

gemma4_31b-mlx:
@echo "==> Building and installing ExecuTorch with MLX..."
cmake --workflow --preset mlx-release
@echo "==> Building Gemma 4 31B runner with MLX..."
cd examples/models/gemma4_31b && cmake --workflow --preset gemma4-31b-mlx
@echo ""
@echo "✓ Build complete!"
@echo " Binary: cmake-out/examples/models/gemma4_31b/gemma4_31b_runner"

qwen3_5_moe-metal:
@echo "==> Building and installing ExecuTorch with Metal..."
cmake --workflow --preset llm-release-metal
12 changes: 10 additions & 2 deletions backends/mlx/custom_ops.py
@@ -228,8 +228,16 @@ def rope(
# final angles: [1, 1, T, half]
angles = (pos_range * inv_freq) * float(scale)
else:
-# assume freqs is already per-position, just reshape to [1,1,T,half]
-angles = freqs.to(torch.float32).view(1, 1, T, half)
+if freqs.ndim == 1:
+    # 1D raw frequencies: compute angles = positions * (1/freqs)
+    inv_freq = (1.0 / freqs.to(torch.float32)).view(1, 1, 1, half)
+    pos_range = torch.arange(
+        pos, pos + T, device=x.device, dtype=torch.float32
+    ).view(1, 1, T, 1)
+    angles = (pos_range * inv_freq) * float(scale)
+else:
+    # 2D per-position angles: reshape to [1,1,T,half]
+    angles = freqs.to(torch.float32).view(1, 1, T, half)

cos = angles.cos().to(x.dtype) # [1,1,T,half]
sin = angles.sin().to(x.dtype) # [1,1,T,half]
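
The two `freqs` branches in this hunk can be sketched in plain Python without torch. This is only an illustration of the angle computation, not the op itself: `rope_angles` is a hypothetical helper, plain lists stand in for tensors, and the frequency values are built the way the new `RopeCustomFreqsModel` test in this PR builds them.

```python
import math

def rope_angles(freqs, pos, seq_len, scale=1.0):
    """Return angles[t][i] for seq_len positions (plain lists, no torch)."""
    is_1d = not isinstance(freqs[0], (list, tuple))
    if is_1d:
        # 1D raw frequencies: angle = position * (1 / freq) * scale
        return [[(pos + t) * (1.0 / f) * scale for f in freqs]
                for t in range(seq_len)]
    # 2D per-position angles: one row per position, used as-is (no scaling)
    if len(freqs) != seq_len:
        raise ValueError("expected one row of angles per position")
    return [list(row) for row in freqs]

# Frequencies as in the new test: base ** (i / head_dim), rotary dims only
head_dim, dims, base = 64, 32, 500000.0
freqs_1d = [base ** (i / head_dim) for i in range(0, dims, 2)]

angles = rope_angles(freqs_1d, pos=10, seq_len=4)
cos = [[math.cos(a) for a in row] for row in angles]
sin = [[math.sin(a) for a in row] for row in angles]
```

With `dims=32` this gives 16 angles per position, which corresponds to the `half` dimension in the `[1,1,T,half]` shape comments.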
9 changes: 7 additions & 2 deletions backends/mlx/runtime/MLXInterpreter.h
@@ -242,6 +242,11 @@ inline void exec_rope(const RopeNode& n, ExecutionState& st, StreamOrDevice s) {
freqs_arr = st.const_tensor_ref(*n.freqs);
}

// MLX requires exactly one of base or freqs — when freqs is provided,
// base must be nullopt.
std::optional<float> base =
freqs_arr ? std::nullopt : std::optional<float>(n.base);

// MLX has two overloads: rope(..., int offset, ...) and rope(..., const
// array& offset, ...) Call the appropriate one based on is_vid
if (n.offset.is_vid) {
@@ -250,14 +255,14 @@ inline void exec_rope(const RopeNode& n, ExecutionState& st, StreamOrDevice s) {
st.set_tensor(
n.out,
fast::rope(
-x, n.dims, n.traditional, n.base, n.scale, offset, freqs_arr, s));
+x, n.dims, n.traditional, base, n.scale, offset, freqs_arr, s));
} else {
// Tensor offset from Tid
const array& offset = st.const_tensor_ref(n.offset.tid);
st.set_tensor(
n.out,
fast::rope(
-x, n.dims, n.traditional, n.base, n.scale, offset, freqs_arr, s));
+x, n.dims, n.traditional, base, n.scale, offset, freqs_arr, s));
}
}
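
The "exactly one of base or freqs" rule that the new `std::optional<float> base` line enforces can be restated as a small Python sketch. `select_rope_base` is a hypothetical name used only for illustration; it mirrors the C++ ternary and does not call MLX.

```python
from typing import Optional, Sequence

def select_rope_base(base: float, freqs: Optional[Sequence[float]]) -> Optional[float]:
    """Mirror of the C++ ternary: drop `base` whenever `freqs` is supplied."""
    return None if freqs is not None else base

# No custom frequencies: the node's base is forwarded unchanged
base_only = select_rope_base(500000.0, None)
# Custom frequencies present: base is suppressed so only one source is set
with_freqs = select_rope_base(0.0, [1.0, 2.0])
```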

76 changes: 76 additions & 0 deletions backends/mlx/test/test_ops.py
@@ -1803,6 +1803,82 @@ def create_inputs(self) -> Tuple[torch.Tensor, ...]:
return (q, k, pos_tensor)


class RopeCustomFreqsModel(nn.Module):
"""Model that applies RoPE with custom 1D frequencies (partial rotary)."""

def __init__(self, dims: int = 32, head_dim: int = 64):
super().__init__()
self.dims = dims
self.head_dim = head_dim
# Simulate proportional RoPE: compute freqs for rotary dims only
inv_freq = 1.0 / (
500000.0 ** (torch.arange(0, dims, 2, dtype=torch.float32) / head_dim)
)
self.register_buffer("freqs", 1.0 / inv_freq, persistent=False)

def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
pos_tensor: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
pos = pos_tensor.item()
q_rot = torch.ops.mlx.rope(q, self.dims, pos, False, 0.0, 1.0, self.freqs)
k_rot = torch.ops.mlx.rope(k, self.dims, pos, False, 0.0, 1.0, self.freqs)
return q_rot, k_rot


@register_test
class RopeCustomFreqsTest(OpTestCase):
"""Test RoPE with custom 1D frequencies (partial rotary, like Gemma 4)."""

name = "rope_custom_freqs"
rtol = 1e-4
atol = 1e-4

def __init__(
self,
batch_size: int = 1,
num_heads: int = 8,
seq_len: int = 4,
head_dim: int = 64,
dims: int = 32,
pos: int = 0,
):
self.batch_size = batch_size
self.num_heads = num_heads
self.seq_len = seq_len
self.head_dim = head_dim
self.dims = dims
self.pos = pos
self.name = "rope_custom_freqs"

@classmethod
def get_test_configs(cls) -> List["RopeCustomFreqsTest"]:
configs = [
cls(),
cls(pos=10),
cls(head_dim=128, dims=64),
]
for cfg in configs:
parts = ["rope_custom_freqs"]
if cfg.pos > 0:
parts.append(f"pos{cfg.pos}")
if cfg.head_dim != 64:
parts.append(f"hd{cfg.head_dim}")
cfg.name = "_".join(parts)
return configs

def create_model(self) -> nn.Module:
return RopeCustomFreqsModel(dims=self.dims, head_dim=self.head_dim)

def create_inputs(self) -> Tuple[torch.Tensor, ...]:
q = torch.randn(self.batch_size, self.num_heads, self.seq_len, self.head_dim)
k = torch.randn(self.batch_size, self.num_heads, self.seq_len, self.head_dim)
pos_tensor = torch.tensor(self.pos, dtype=torch.int64)
return (q, k, pos_tensor)
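
For reference, the naming logic in `get_test_configs` can be pulled out into a standalone sketch to see which test names the three configs produce. `config_name` is a hypothetical helper written for this illustration; the real code sets `cfg.name` on each instance.

```python
def config_name(pos: int = 0, head_dim: int = 64) -> str:
    """Build a test name from the non-default parameters, as get_test_configs does."""
    parts = ["rope_custom_freqs"]
    if pos > 0:
        parts.append(f"pos{pos}")
    if head_dim != 64:
        parts.append(f"hd{head_dim}")
    return "_".join(parts)

# The three configs from the diff: defaults, pos=10, and head_dim=128
names = [config_name(), config_name(pos=10), config_name(head_dim=128)]
print(names)
# ['rope_custom_freqs', 'rope_custom_freqs_pos10', 'rope_custom_freqs_hd128']
```

Note that `dims` never appears in the name, so two configs that differ only in `dims` would collide.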


from executorch.backends.mlx.llm.cache import KVCache


15 changes: 12 additions & 3 deletions examples/models/gemma4_31b/CMakeLists.txt
@@ -42,14 +42,17 @@ list(
extension_flat_tensor
)

-# CUDA backend (the only supported backend for this example for now)
+# Backend: CUDA or MLX (exactly one required)
if(EXECUTORCH_BUILD_CUDA)
find_package(CUDAToolkit REQUIRED)
list(APPEND link_libraries aoti_cuda_backend)
executorch_target_link_options_shared_lib(aoti_cuda_backend)
add_compile_definitions(EXECUTORCH_BUILD_CUDA)
elseif(TARGET mlxdelegate)
list(APPEND link_libraries mlxdelegate mlx)
executorch_target_link_options_shared_lib(mlxdelegate)
else()
-message(FATAL_ERROR "Set EXECUTORCH_BUILD_CUDA=ON")
+message(FATAL_ERROR "Set EXECUTORCH_BUILD_CUDA=ON or EXECUTORCH_BUILD_MLX=ON")
endif()

# Tokenizer (HuggingFace tokenizer.json)
@@ -63,5 +66,11 @@ target_link_libraries(gemma4_31b_runner PUBLIC ${link_libraries})

if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug")
target_link_options_gc_sections(gemma4_31b_runner)
-target_link_options(gemma4_31b_runner PRIVATE "LINKER:-s")
+if(NOT APPLE AND NOT MSVC)
+  target_link_options(gemma4_31b_runner PRIVATE "LINKER:-s")
+endif()
endif()

if(TARGET mlxdelegate)
executorch_target_copy_mlx_metallib(gemma4_31b_runner)
endif()
31 changes: 31 additions & 0 deletions examples/models/gemma4_31b/CMakePresets.json
@@ -23,6 +23,17 @@
"string": "${hostSystemName}",
"list": ["Linux", "Windows"]
}
},
{
"name": "gemma4-31b-mlx",
"displayName": "Gemma 4 31B runner (MLX)",
"inherits": ["gemma4-31b-base"],
"cacheVariables": {},
"condition": {
"type": "equals",
"lhs": "${hostSystemName}",
"rhs": "Darwin"
}
}
],
"buildPresets": [
@@ -31,6 +42,12 @@
"displayName": "Build Gemma 4 31B runner (CUDA)",
"configurePreset": "gemma4-31b-cuda",
"targets": ["gemma4_31b_runner"]
},
{
"name": "gemma4-31b-mlx",
"displayName": "Build Gemma 4 31B runner (MLX)",
"configurePreset": "gemma4-31b-mlx",
"targets": ["gemma4_31b_runner"]
}
],
"workflowPresets": [
@@ -47,6 +64,20 @@
"name": "gemma4-31b-cuda"
}
]
},
{
"name": "gemma4-31b-mlx",
"displayName": "Configure and build Gemma 4 31B runner (MLX)",
"steps": [
{
"type": "configure",
"name": "gemma4-31b-mlx"
},
{
"type": "build",
"name": "gemma4-31b-mlx"
}
]
}
]
}
22 changes: 19 additions & 3 deletions examples/models/gemma4_31b/README.md
@@ -1,7 +1,7 @@
# Gemma 4 31B-IT

Text-only export of Google's Gemma 4 31B-IT to ExecuTorch with INT4/INT8
-weight quantization. Currently supports the CUDA backend.
+weight quantization. Supports CUDA and MLX (Apple Silicon) backends.

For architecture and design notes see [model.md](model.md).

@@ -67,6 +67,8 @@ recipe. Writes `model.safetensors`, `config.json`, and `tokenizer.json` into

## Export to ExecuTorch

### CUDA

```bash
python examples/models/gemma4_31b/export.py \
--prequantized ./gemma4_31b_int4 \
@@ -75,7 +77,20 @@ python examples/models/gemma4_31b/export.py \
--backend cuda
```

-Writes `model.pte` and `model.ptd` into `--output-dir`.
### MLX (Apple Silicon)

```bash
python examples/models/gemma4_31b/export.py \
--prequantized ./gemma4_31b_int4 \
--output-dir ./gemma4_31b_exports_mlx \
--max-seq-len 4096 \
--backend mlx
```

The same quantized checkpoint works for both backends. MLX exports a single
method with dynamic sequence length and host-side sampling.

+Writes `model.pte` (and optionally `model.ptd`) into `--output-dir`.
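
The README mentions host-side sampling for the MLX export but does not show the sampler. As a rough sketch, assuming the exported method returns a vector of next-token logits, choosing a token on the host could look like the following. `sample_next_token` and its parameters are hypothetical names, not the runner's API.

```python
import math
import random
from typing import Optional, Sequence

def sample_next_token(
    logits: Sequence[float],
    temperature: float = 0.0,
    rng: Optional[random.Random] = None,
) -> int:
    """Pick a token id from raw logits on the host (hypothetical helper)."""
    if temperature <= 0.0:
        # Greedy decoding: index of the largest logit
        return max(range(len(logits)), key=lambda i: logits[i])
    rng = rng or random.Random()
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    # Unnormalized softmax weights; subtracting the peak avoids overflow
    weights = [math.exp(x - peak) for x in scaled]
    return rng.choices(range(len(weights)), weights=weights, k=1)[0]

greedy_id = sample_next_token([0.1, 2.5, -1.0])
sampled_id = sample_next_token([0.1, 2.5, -1.0], temperature=0.8, rng=random.Random(0))
```

Keeping sampling outside the exported graph means the temperature or strategy can change without re-exporting the model, which is one plausible reason for the single-method design.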

## Eager inference

@@ -105,7 +120,8 @@ model produces sensible text.
## Build the runner

```bash
-make gemma4_31b-cuda
make gemma4_31b-cuda # Linux — CUDA backend
make gemma4_31b-mlx # macOS — MLX backend (Apple Silicon)
```

The binary lands at `cmake-out/examples/models/gemma4_31b/gemma4_31b_runner`.