diff --git a/.github/workflows/cuda.yml b/.github/workflows/cuda.yml index eb7fc5a8939..e1eaba6b7c1 100644 --- a/.github/workflows/cuda.yml +++ b/.github/workflows/cuda.yml @@ -150,7 +150,7 @@ jobs: # Run Gemma 4 31B tests (quant unit tests + pipeline integration tests) pip install gguf - python -m pytest examples/models/gemma4_31b/quant/tests/ examples/models/gemma4_31b/tests/ -v -o "addopts=" + python -m pytest examples/models/gemma4_31b/quant/tests/ examples/models/gemma4_31b/tests/ --ignore=examples/models/gemma4_31b/tests/test_mlx_pipeline.py -v -o "addopts=" export-model-cuda-artifact: name: export-model-cuda-artifact diff --git a/.github/workflows/mlx.yml b/.github/workflows/mlx.yml index 4778d08fcdc..d429db16053 100644 --- a/.github/workflows/mlx.yml +++ b/.github/workflows/mlx.yml @@ -60,6 +60,7 @@ jobs: backends/mlx/test/test_passes.py \ backends/mlx/test/test_pattern_utils.py \ backends/mlx/test/test_partitioner.py \ + examples/models/gemma4_31b/tests/test_mlx_pipeline.py \ -v echo "::endgroup::" diff --git a/Makefile b/Makefile index ba61dddce44..9b7f24b2f83 100644 --- a/Makefile +++ b/Makefile @@ -91,7 +91,7 @@ # # ============================================================================== -.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral-mlx voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal voxtral_realtime-mlx voxtral_tts-cpu voxtral_tts-cuda whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal parakeet-mlx parakeet-vulkan dinov2-cuda dinov2-cuda-debug sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu gemma4_31b-cuda qwen3_5_moe-cuda qwen3_5_moe-metal clean help +.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral-mlx voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal voxtral_realtime-mlx voxtral_tts-cpu voxtral_tts-cuda whisper-cuda whisper-cuda-debug whisper-cpu 
whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal parakeet-mlx parakeet-vulkan dinov2-cuda dinov2-cuda-debug sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu gemma4_31b-cuda gemma4_31b-mlx qwen3_5_moe-cuda qwen3_5_moe-metal clean help help: @echo "This Makefile adds targets to build runners for various models on various backends. Run using \`make \`. Available targets:" @@ -127,6 +127,7 @@ help: @echo " gemma3-cuda - Build Gemma3 runner with CUDA backend" @echo " gemma3-cpu - Build Gemma3 runner with CPU backend" @echo " gemma4_31b-cuda - Build Gemma 4 31B runner with CUDA backend" + @echo " gemma4_31b-mlx - Build Gemma 4 31B runner with MLX backend" @echo " qwen3_5_moe-cuda - Build Qwen3.5 MoE runner with CUDA backend" @echo " qwen3_5_moe-metal - Build Qwen3.5 MoE runner with Metal backend" @echo " clean - Clean build artifacts" @@ -435,6 +436,15 @@ gemma4_31b-cuda: @echo "✓ Build complete!" @echo " Binary: cmake-out/examples/models/gemma4_31b/gemma4_31b_runner" +gemma4_31b-mlx: + @echo "==> Building and installing ExecuTorch with MLX..." + cmake --workflow --preset mlx-release + @echo "==> Building Gemma 4 31B runner with MLX..." + cd examples/models/gemma4_31b && cmake --workflow --preset gemma4-31b-mlx + @echo "" + @echo "✓ Build complete!" + @echo " Binary: cmake-out/examples/models/gemma4_31b/gemma4_31b_runner" + qwen3_5_moe-metal: @echo "==> Building and installing ExecuTorch with Metal..." 
cmake --workflow --preset llm-release-metal diff --git a/backends/mlx/custom_ops.py b/backends/mlx/custom_ops.py index d7d6288ba8f..c03db05d918 100644 --- a/backends/mlx/custom_ops.py +++ b/backends/mlx/custom_ops.py @@ -228,8 +228,16 @@ def rope( # final angles: [1, 1, T, half] angles = (pos_range * inv_freq) * float(scale) else: - # assume freqs is already per-position, just reshape to [1,1,T,half] - angles = freqs.to(torch.float32).view(1, 1, T, half) + if freqs.ndim == 1: + # 1D raw frequencies: compute angles = positions * (1/freqs) + inv_freq = (1.0 / freqs.to(torch.float32)).view(1, 1, 1, half) + pos_range = torch.arange( + pos, pos + T, device=x.device, dtype=torch.float32 + ).view(1, 1, T, 1) + angles = (pos_range * inv_freq) * float(scale) + else: + # 2D per-position angles: reshape to [1,1,T,half] + angles = freqs.to(torch.float32).view(1, 1, T, half) cos = angles.cos().to(x.dtype) # [1,1,T,half] sin = angles.sin().to(x.dtype) # [1,1,T,half] diff --git a/backends/mlx/runtime/MLXInterpreter.h b/backends/mlx/runtime/MLXInterpreter.h index 1f961459d22..fb6597d171e 100644 --- a/backends/mlx/runtime/MLXInterpreter.h +++ b/backends/mlx/runtime/MLXInterpreter.h @@ -242,6 +242,11 @@ inline void exec_rope(const RopeNode& n, ExecutionState& st, StreamOrDevice s) { freqs_arr = st.const_tensor_ref(*n.freqs); } + // MLX requires exactly one of base or freqs — when freqs is provided, + // base must be nullopt. + std::optional base = + freqs_arr ? std::nullopt : std::optional(n.base); + // MLX has two overloads: rope(..., int offset, ...) and rope(..., const // array& offset, ...) 
Call the appropriate one based on is_vid if (n.offset.is_vid) { @@ -250,14 +255,14 @@ inline void exec_rope(const RopeNode& n, ExecutionState& st, StreamOrDevice s) { st.set_tensor( n.out, fast::rope( - x, n.dims, n.traditional, n.base, n.scale, offset, freqs_arr, s)); + x, n.dims, n.traditional, base, n.scale, offset, freqs_arr, s)); } else { // Tensor offset from Tid const array& offset = st.const_tensor_ref(n.offset.tid); st.set_tensor( n.out, fast::rope( - x, n.dims, n.traditional, n.base, n.scale, offset, freqs_arr, s)); + x, n.dims, n.traditional, base, n.scale, offset, freqs_arr, s)); } } diff --git a/backends/mlx/test/test_ops.py b/backends/mlx/test/test_ops.py index afc45adcc93..4471610519e 100644 --- a/backends/mlx/test/test_ops.py +++ b/backends/mlx/test/test_ops.py @@ -1803,6 +1803,82 @@ def create_inputs(self) -> Tuple[torch.Tensor, ...]: return (q, k, pos_tensor) +class RopeCustomFreqsModel(nn.Module): + """Model that applies RoPE with custom 1D frequencies (partial rotary).""" + + def __init__(self, dims: int = 32, head_dim: int = 64): + super().__init__() + self.dims = dims + self.head_dim = head_dim + # Simulate proportional RoPE: compute freqs for rotary dims only + inv_freq = 1.0 / ( + 500000.0 ** (torch.arange(0, dims, 2, dtype=torch.float32) / head_dim) + ) + self.register_buffer("freqs", 1.0 / inv_freq, persistent=False) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + pos_tensor: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + pos = pos_tensor.item() + q_rot = torch.ops.mlx.rope(q, self.dims, pos, False, 0.0, 1.0, self.freqs) + k_rot = torch.ops.mlx.rope(k, self.dims, pos, False, 0.0, 1.0, self.freqs) + return q_rot, k_rot + + +@register_test +class RopeCustomFreqsTest(OpTestCase): + """Test RoPE with custom 1D frequencies (partial rotary, like Gemma 4).""" + + name = "rope_custom_freqs" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + batch_size: int = 1, + num_heads: int = 8, + seq_len: int = 4, + head_dim: 
int = 64, + dims: int = 32, + pos: int = 0, + ): + self.batch_size = batch_size + self.num_heads = num_heads + self.seq_len = seq_len + self.head_dim = head_dim + self.dims = dims + self.pos = pos + self.name = "rope_custom_freqs" + + @classmethod + def get_test_configs(cls) -> List["RopeCustomFreqsTest"]: + configs = [ + cls(), + cls(pos=10), + cls(head_dim=128, dims=64), + ] + for cfg in configs: + parts = ["rope_custom_freqs"] + if cfg.pos > 0: + parts.append(f"pos{cfg.pos}") + if cfg.head_dim != 64: + parts.append(f"hd{cfg.head_dim}") + cfg.name = "_".join(parts) + return configs + + def create_model(self) -> nn.Module: + return RopeCustomFreqsModel(dims=self.dims, head_dim=self.head_dim) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + q = torch.randn(self.batch_size, self.num_heads, self.seq_len, self.head_dim) + k = torch.randn(self.batch_size, self.num_heads, self.seq_len, self.head_dim) + pos_tensor = torch.tensor(self.pos, dtype=torch.int64) + return (q, k, pos_tensor) + + from executorch.backends.mlx.llm.cache import KVCache diff --git a/examples/models/gemma4_31b/CMakeLists.txt b/examples/models/gemma4_31b/CMakeLists.txt index 8d536a47fc5..52419eb95bc 100644 --- a/examples/models/gemma4_31b/CMakeLists.txt +++ b/examples/models/gemma4_31b/CMakeLists.txt @@ -42,14 +42,17 @@ list( extension_flat_tensor ) -# CUDA backend (the only supported backend for this example for now) +# Backend: CUDA or MLX (exactly one required) if(EXECUTORCH_BUILD_CUDA) find_package(CUDAToolkit REQUIRED) list(APPEND link_libraries aoti_cuda_backend) executorch_target_link_options_shared_lib(aoti_cuda_backend) add_compile_definitions(EXECUTORCH_BUILD_CUDA) +elseif(TARGET mlxdelegate) + list(APPEND link_libraries mlxdelegate mlx) + executorch_target_link_options_shared_lib(mlxdelegate) else() - message(FATAL_ERROR "Set EXECUTORCH_BUILD_CUDA=ON") + message(FATAL_ERROR "Set EXECUTORCH_BUILD_CUDA=ON or EXECUTORCH_BUILD_MLX=ON") endif() # Tokenizer (HuggingFace tokenizer.json) 
@@ -63,5 +66,11 @@ target_link_libraries(gemma4_31b_runner PUBLIC ${link_libraries}) if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug") target_link_options_gc_sections(gemma4_31b_runner) - target_link_options(gemma4_31b_runner PRIVATE "LINKER:-s") + if(NOT APPLE AND NOT MSVC) + target_link_options(gemma4_31b_runner PRIVATE "LINKER:-s") + endif() +endif() + +if(TARGET mlxdelegate) + executorch_target_copy_mlx_metallib(gemma4_31b_runner) endif() diff --git a/examples/models/gemma4_31b/CMakePresets.json b/examples/models/gemma4_31b/CMakePresets.json index 97ba7f4c57a..23a7d42e035 100644 --- a/examples/models/gemma4_31b/CMakePresets.json +++ b/examples/models/gemma4_31b/CMakePresets.json @@ -23,6 +23,17 @@ "string": "${hostSystemName}", "list": ["Linux", "Windows"] } + }, + { + "name": "gemma4-31b-mlx", + "displayName": "Gemma 4 31B runner (MLX)", + "inherits": ["gemma4-31b-base"], + "cacheVariables": {}, + "condition": { + "type": "equals", + "lhs": "${hostSystemName}", + "rhs": "Darwin" + } } ], "buildPresets": [ @@ -31,6 +42,12 @@ "displayName": "Build Gemma 4 31B runner (CUDA)", "configurePreset": "gemma4-31b-cuda", "targets": ["gemma4_31b_runner"] + }, + { + "name": "gemma4-31b-mlx", + "displayName": "Build Gemma 4 31B runner (MLX)", + "configurePreset": "gemma4-31b-mlx", + "targets": ["gemma4_31b_runner"] } ], "workflowPresets": [ @@ -47,6 +64,20 @@ "name": "gemma4-31b-cuda" } ] + }, + { + "name": "gemma4-31b-mlx", + "displayName": "Configure and build Gemma 4 31B runner (MLX)", + "steps": [ + { + "type": "configure", + "name": "gemma4-31b-mlx" + }, + { + "type": "build", + "name": "gemma4-31b-mlx" + } + ] } ] } diff --git a/examples/models/gemma4_31b/README.md b/examples/models/gemma4_31b/README.md index 94783c8f823..da4aa893079 100644 --- a/examples/models/gemma4_31b/README.md +++ b/examples/models/gemma4_31b/README.md @@ -1,7 +1,7 @@ # Gemma 4 31B-IT Text-only export of Google's Gemma 4 31B-IT to ExecuTorch with INT4/INT8 -weight quantization. 
Currently supports the CUDA backend. +weight quantization. Supports CUDA and MLX (Apple Silicon) backends. For architecture and design notes see [model.md](model.md). @@ -67,6 +67,8 @@ recipe. Writes `model.safetensors`, `config.json`, and `tokenizer.json` into ## Export to ExecuTorch +### CUDA + ```bash python examples/models/gemma4_31b/export.py \ --prequantized ./gemma4_31b_int4 \ @@ -75,7 +77,20 @@ python examples/models/gemma4_31b/export.py \ --backend cuda ``` -Writes `model.pte` and `model.ptd` into `--output-dir`. +### MLX (Apple Silicon) + +```bash +python examples/models/gemma4_31b/export.py \ + --prequantized ./gemma4_31b_int4 \ + --output-dir ./gemma4_31b_exports_mlx \ + --max-seq-len 4096 \ + --backend mlx +``` + +The same quantized checkpoint works for both backends. MLX exports a single +method with dynamic sequence length and host-side sampling. + +Writes `model.pte` (and optionally `model.ptd`) into `--output-dir`. ## Eager inference @@ -105,7 +120,8 @@ model produces sensible text. ## Build the runner ```bash -make gemma4_31b-cuda +make gemma4_31b-cuda # Linux — CUDA backend +make gemma4_31b-mlx # macOS — MLX backend (Apple Silicon) ``` The binary lands at `cmake-out/examples/models/gemma4_31b/gemma4_31b_runner`. diff --git a/examples/models/gemma4_31b/export.py b/examples/models/gemma4_31b/export.py index a96dba0d512..046e365947b 100644 --- a/examples/models/gemma4_31b/export.py +++ b/examples/models/gemma4_31b/export.py @@ -19,6 +19,8 @@ Backends: --backend cuda (default) CUDA via tinygemm INT4 + CudaPartitioner. + --backend mlx Apple Silicon via MLXPartitioner (single method, + dynamic seq_len, host-side sampling). 
""" import argparse @@ -98,12 +100,21 @@ def load_and_quantize( # Backend dispatch helpers +_SUPPORTED_BACKENDS = ("cuda", "mlx") + + def _get_packers(backend: str) -> dict: if backend == "cuda": from executorch.examples.models.gemma4_31b.quant import DEFAULT_CUDA_PACKERS return DEFAULT_CUDA_PACKERS - raise ValueError(f"Unsupported backend: {backend!r}. Supported: 'cuda'.") + if backend == "mlx": + from executorch.examples.models.gemma4_31b.quant import DEFAULT_MLX_PACKERS + + return DEFAULT_MLX_PACKERS + raise ValueError( + f"Unsupported backend: {backend!r}. Supported: {_SUPPORTED_BACKENDS}." + ) def _pack_for_backend(model: nn.Module, path: str, backend: str) -> None: @@ -111,8 +122,14 @@ def _pack_for_backend(model: nn.Module, path: str, backend: str) -> None: from executorch.examples.models.gemma4_31b.quant import load_and_pack_for_cuda load_and_pack_for_cuda(path, model) + elif backend == "mlx": + from executorch.examples.models.gemma4_31b.quant import load_and_pack_for_mlx + + load_and_pack_for_mlx(path, model) else: - raise ValueError(f"Unsupported backend: {backend!r}. Supported: 'cuda'.") + raise ValueError( + f"Unsupported backend: {backend!r}. Supported: {_SUPPORTED_BACKENDS}." + ) # --------------------------------------------------------------------------- @@ -128,8 +145,12 @@ def export_and_lower( """Export and lower the model to ExecuTorch for the given backend.""" if backend == "cuda": _export_cuda(model, config, output_dir) + elif backend == "mlx": + _export_mlx(model, config, output_dir) else: - raise ValueError(f"Unsupported backend: {backend!r}. Supported: 'cuda'.") + raise ValueError( + f"Unsupported backend: {backend!r}. Supported: {_SUPPORTED_BACKENDS}." 
+ ) def _export_cuda(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) -> None: @@ -258,6 +279,98 @@ def _export_cuda(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) - print("Done.") +def _export_mlx(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) -> None: + """Export to .pte via torch.export + MLX backend. + + Unlike CUDA (which exports separate decode/prefill methods with an + Int4Tensor dispatch override), MLX uses a single method with dynamic + sequence length. No int4_dispatch import — IntxUnpackedToInt8Tensor's + default dispatch produces the ``dequantize_affine → linear`` pattern + that MLX's QuantizedLinearHandler matches. + """ + import gc + + from executorch.backends.mlx import MLXPartitioner + from executorch.backends.mlx.passes import get_default_passes + + from executorch.examples.models.gemma4_31b.mlx_source_transformations import ( + mlx_source_transformations, + ) + from executorch.exir import ( + EdgeCompileConfig, + ExecutorchBackendConfig, + to_edge_transform_and_lower, + ) + from executorch.exir.passes import MemoryPlanningPass + from torch.export import Dim, export + + mlx_source_transformations(model, dtype=torch.bfloat16) + materialize_runtime_buffers(model, dtype=torch.bfloat16) + + max_prefill = min(config.max_seq_len - 1, config.sliding_window * 2) + seq_dim = Dim("seq_len", min=1, max=max_prefill) + + print(f"Exporting (T in [1, {max_prefill}])...") + with torch.no_grad(): + exported = export( + model, + ( + torch.tensor([[0, 1]], dtype=torch.long), + torch.tensor([0, 1], dtype=torch.long), + ), + dynamic_shapes=({1: seq_dim}, {0: seq_dim}), + strict=True, + ) + + del model + gc.collect() + + print("Lowering to ExecuTorch with MLX backend...") + et_prog = to_edge_transform_and_lower( + exported, + transform_passes=get_default_passes(), + partitioner=[MLXPartitioner()], + compile_config=EdgeCompileConfig( + _check_ir_validity=False, + _skip_dim_order=True, + ), + constant_methods={ + 
"get_max_seq_len": config.max_seq_len, + "get_vocab_size": config.vocab_size, + "get_n_layers": config.num_hidden_layers, + "get_max_prefill_chunk": max_prefill, + "use_kv_cache": True, + "use_sdpa_with_kv_cache": False, + "enable_dynamic_shape": True, + }, + ) + + del exported + gc.collect() + + et_program = et_prog.to_executorch( + config=ExecutorchBackendConfig( + extract_delegate_segments=True, + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), + ), + ) + + del et_prog + gc.collect() + + os.makedirs(output_dir, exist_ok=True) + pte_path = os.path.join(output_dir, "model.pte") + print(f"Saving to {pte_path}...") + with open(pte_path, "wb") as f: + et_program.write_to_file(f) + print(f" {os.path.getsize(pte_path) / 1024**2:.1f} MB") + + if et_program._tensor_data: + et_program.write_tensor_data_to_file(output_dir) + print(f" Saved tensor data (.ptd) to {output_dir}/") + print("Done.") + + # --------------------------------------------------------------------------- # CLI @@ -302,7 +415,7 @@ def main() -> None: parser.add_argument( "--backend", default="cuda", - choices=["cuda"], + choices=list(_SUPPORTED_BACKENDS), help="Target backend for export.", ) args = parser.parse_args() diff --git a/examples/models/gemma4_31b/inference.py b/examples/models/gemma4_31b/inference.py index e1563c04ff6..92654fca5f2 100644 --- a/examples/models/gemma4_31b/inference.py +++ b/examples/models/gemma4_31b/inference.py @@ -86,7 +86,7 @@ def generate( tokenizer, prompt: str, max_new_tokens: int = 128, - temperature: float = 0.0, + temperature: float = 0.8, eos_token_ids=None, bos_token_id: int = 2, ) -> str: diff --git a/examples/models/gemma4_31b/main.cpp b/examples/models/gemma4_31b/main.cpp index 3ddf64e410f..6cf65cc8246 100644 --- a/examples/models/gemma4_31b/main.cpp +++ b/examples/models/gemma4_31b/main.cpp @@ -6,18 +6,21 @@ * LICENSE file in the root directory of this source tree. */ -// Gemma 4 31B-IT runner for the CUDA ExecuTorch backend. 
-// -// Drives the prefill + decode methods produced by export.py. -// The exported model performs Gumbel-max sampling on-device and returns a -// single float token ID per call, so this runner only has to feed tokens -// in and decode them via the HuggingFace tokenizer. +// Gemma 4 31B-IT runner for ExecuTorch. Supports two backends: +// CUDA — exports ``prefill`` (T>=2, dynamic) + ``decode`` (T=1, static) +// methods sharing KV-cache buffers; on-device Gumbel-max sampling +// with temperature passed as a third input; returns a scalar +// float token id. +// MLX — exports a single ``forward`` method with dynamic seq_len; +// returns last-token logits; the runner samples on the host via +// ``llm::logits_to_token`` with the same temperature semantics. #include #include #include #include +#include #include #include #include @@ -82,6 +85,7 @@ using ::executorch::runtime::EValue; using SizesType = executorch::aten::SizesType; +// Read a sampled token ID from a scalar float output (CUDA path). static uint64_t read_token(const executorch::aten::Tensor& output) { const void* ptr = output.const_data_ptr(); float val = 0.0f; @@ -143,8 +147,7 @@ int main(int argc, char** argv) { return 1; } - // Module: share_memory_arenas=true so prefill and decode see the same - // KV-cache memory (we exported with share_mutable_buffers=True). 
+ // Module std::vector data_files; if (!FLAGS_data_path.empty()) { data_files.push_back(FLAGS_data_path); @@ -152,7 +155,7 @@ int main(int argc, char** argv) { auto module = std::make_unique( FLAGS_model_path, data_files, - Module::LoadMode::File, + Module::LoadMode::MmapUseMlockIgnoreErrors, /*event_tracer=*/nullptr, /*memory_allocator=*/nullptr, /*temp_allocator=*/nullptr, @@ -165,6 +168,19 @@ int main(int argc, char** argv) { return 1; } + int64_t max_prefill_chunk = (*metadata_result)[llm::kMaxSeqLen] - 1; + { + auto get_result = module->get("get_max_prefill_chunk"); + if (get_result.ok()) { + max_prefill_chunk = get_result->toScalar().to(); + } + } + + auto S = [](int64_t v) -> SizesType { return static_cast(v); }; + + float temp_val = + FLAGS_temperature <= 0.0 ? 1e-6f : static_cast(FLAGS_temperature); + #ifdef EXECUTORCH_BUILD_CUDA if (FLAGS_cuda_graph) { executorch::runtime::BackendOptions<2> cuda_opts; @@ -172,11 +188,6 @@ int main(int argc, char** argv) { executorch::runtime::set_option("CudaBackend", cuda_opts.view()); printf("CUDA graph enabled for decode method\n"); } - - // Cross-method per-FQN weight sharing: prefill + decode share the same - // weight tensors and (more importantly) the same KV-cache buffers, so - // without this flag we would allocate them twice. MUST be set before - // load_method. 
{ executorch::runtime::BackendOptions<1> backend_options; auto set_err = @@ -184,7 +195,7 @@ int main(int argc, char** argv) { if (set_err != Error::Ok) { ET_LOG( Error, - "Failed to construct weight_sharing_across_methods option: %d", + "Failed to set weight_sharing_across_methods: %d", static_cast(set_err)); return 1; } @@ -198,12 +209,6 @@ int main(int argc, char** argv) { return 1; } } -#else - if (FLAGS_cuda_graph) { - ET_LOG(Info, "--cuda_graph ignored on non-CUDA build"); - } -#endif - printf("Loading methods...\n"); if (module->load_method("prefill") != Error::Ok) { ET_LOG(Error, "Failed to load prefill method"); @@ -213,6 +218,19 @@ int main(int argc, char** argv) { ET_LOG(Error, "Failed to load decode method"); return 1; } + auto temp_tensor = + from_blob(&temp_val, {1}, executorch::aten::ScalarType::Float); +#else + if (FLAGS_cuda_graph) { + ET_LOG(Info, "--cuda_graph ignored on non-CUDA build"); + } + printf("Loading model...\n"); + if (module->load_method("forward") != Error::Ok) { + ET_LOG(Error, "Failed to load forward method"); + return 1; + } +#endif + stats.model_load_end_ms = llm::time_in_ms(); #ifdef EXECUTORCH_BUILD_CUDA @@ -222,8 +240,12 @@ int main(int argc, char** argv) { auto eos_ids = llm::get_eos_ids(tokenizer.get(), module.get()); eos_ids.insert(static_cast(FLAGS_eos_id)); + auto turn_ids = tokenizer->encode("", /*bos=*/0, /*eos=*/0); + if (turn_ids.ok() && turn_ids->size() == 1) { + eos_ids.insert(turn_ids.get()[0]); + } - // Read prompt from file or flag + // Read prompt std::string prompt_text = FLAGS_prompt; if (!FLAGS_prompt_file.empty()) { std::ifstream f(FLAGS_prompt_file); @@ -260,38 +282,15 @@ int main(int argc, char** argv) { stats.inference_start_ms = llm::time_in_ms(); - auto S = [](int64_t v) -> SizesType { return static_cast(v); }; - -#ifdef EXECUTORCH_BUILD_CUDA - // CUDA build: model fuses the sampler. Pass temperature as a third input. - float temp_val = - FLAGS_temperature <= 0.0 ? 
1e-6f : static_cast(FLAGS_temperature); - auto temp_tensor = - from_blob(&temp_val, {1}, executorch::aten::ScalarType::Float); -#endif - // --------------------------------------------------------------- // Prefill (chunked to respect ring-buffer KV cache limit) // --------------------------------------------------------------- - // Sliding layers use a ring buffer sized to 2×sliding_window. A single - // prefill call must not exceed this size, otherwise index_copy_ with - // wrapped indices produces non-deterministic results on CUDA. - int64_t max_prefill_chunk = (*metadata_result)[llm::kMaxSeqLen] - 1; - { - auto get_result = module->get("get_max_prefill_chunk"); - if (get_result.ok()) { - max_prefill_chunk = get_result->toScalar().to(); - } - } - uint64_t cur_token = 0; int64_t prefill_pos = 0; while (prefill_pos < num_prompt_tokens) { int64_t chunk_len = std::min(num_prompt_tokens - prefill_pos, max_prefill_chunk); - std::string run_method = (chunk_len == 1) ? "decode" : "prefill"; - std::vector token_data( prompt_tokens.begin() + prefill_pos, prompt_tokens.begin() + prefill_pos + chunk_len); @@ -306,39 +305,52 @@ int main(int argc, char** argv) { auto pos_tensor = from_blob( pos_data.data(), {S(chunk_len)}, executorch::aten::ScalarType::Long); - std::vector prefill_inputs; - prefill_inputs.push_back(EValue(tokens_tensor)); - prefill_inputs.push_back(EValue(pos_tensor)); + std::vector inputs; + inputs.push_back(EValue(tokens_tensor)); + inputs.push_back(EValue(pos_tensor)); + #ifdef EXECUTORCH_BUILD_CUDA - prefill_inputs.push_back(EValue(temp_tensor)); + inputs.push_back(EValue(temp_tensor)); + std::string method = (chunk_len == 1) ? 
"decode" : "prefill"; +#else + std::string method = "forward"; #endif - auto prefill_result = module->execute(run_method, prefill_inputs); - if (prefill_result.error() != Error::Ok) { - ET_LOG( - Error, "%s failed at pos %" PRId64, run_method.c_str(), prefill_pos); + auto result = module->execute(method, inputs); + if (result.error() != Error::Ok) { + ET_LOG(Error, "%s failed at pos %" PRId64, method.c_str(), prefill_pos); return 1; } - cur_token = read_token(prefill_result.get()[0].toTensor()); + +#ifdef EXECUTORCH_BUILD_CUDA + cur_token = read_token(result.get()[0].toTensor()); +#else + cur_token = static_cast( + llm::logits_to_token(result.get()[0].toTensor(), temp_val)); +#endif + prefill_pos += chunk_len; } stats.prompt_eval_end_ms = llm::time_in_ms(); - double prefill_ms = - static_cast(stats.prompt_eval_end_ms - stats.inference_start_ms); - printf( - "Prefill: %" PRId64 " tokens in %.1f ms (%.1f tok/s)\n", - num_prompt_tokens, - prefill_ms, - num_prompt_tokens * 1000.0 / prefill_ms); + // First generated token came from the last prefill chunk; TTFT is prefill. + stats.first_token_ms = stats.prompt_eval_end_ms; #ifdef EXECUTORCH_BUILD_CUDA - // Synchronize CUDA device to ensure prefill's writes to shared mutable - // buffers (KV cache) are visible to the decode method, which may run on - // a different CUDA stream. cudaDeviceSynchronize(); #endif + // Print the first generated token (from the last prefill chunk). + // Use the last prompt token as the streaming-decode prefix so any BPE + // partial-character handling stays correct. 
+ { + auto first_str = tokenizer->decode(prompt_tokens.back(), cur_token); + if (first_str.ok()) { + printf("%s", first_str->c_str()); + fflush(stdout); + } + } + // --------------------------------------------------------------- // Decode loop // --------------------------------------------------------------- @@ -351,29 +363,34 @@ int main(int argc, char** argv) { decode_pos_data.data(), {1}, executorch::aten::ScalarType::Long); uint64_t prev_token = cur_token; - for (int32_t step = 0; step < FLAGS_max_new_tokens; step++) { + bool hit_eos = eos_ids.find(cur_token) != eos_ids.end(); + for (int32_t step = 0; step < FLAGS_max_new_tokens && !hit_eos; step++) { decode_token_data[0] = static_cast(cur_token); decode_pos_data[0] = pos; - std::vector decode_inputs; - decode_inputs.push_back(EValue(decode_tokens)); - decode_inputs.push_back(EValue(decode_pos)); + std::vector inputs; + inputs.push_back(EValue(decode_tokens)); + inputs.push_back(EValue(decode_pos)); + #ifdef EXECUTORCH_BUILD_CUDA - decode_inputs.push_back(EValue(temp_tensor)); + inputs.push_back(EValue(temp_tensor)); + auto result = module->execute("decode", inputs); +#else + auto result = module->execute("forward", inputs); #endif - auto decode_result = module->execute("decode", decode_inputs); - if (decode_result.error() != Error::Ok) { + if (result.error() != Error::Ok) { ET_LOG(Error, "Decode step %d failed", step); return 1; } prev_token = cur_token; - cur_token = read_token(decode_result.get()[0].toTensor()); - - if (step == 0) { - stats.first_token_ms = llm::time_in_ms(); - } +#ifdef EXECUTORCH_BUILD_CUDA + cur_token = read_token(result.get()[0].toTensor()); +#else + cur_token = static_cast( + llm::logits_to_token(result.get()[0].toTensor(), temp_val)); +#endif pos++; auto decode_str = tokenizer->decode(prev_token, cur_token); @@ -382,25 +399,12 @@ int main(int argc, char** argv) { fflush(stdout); } - if (eos_ids.find(cur_token) != eos_ids.end()) { - printf("\n"); - break; - } + hit_eos = 
eos_ids.find(cur_token) != eos_ids.end(); } - - stats.inference_end_ms = llm::time_in_ms(); printf("\n"); - int64_t num_generated = pos - num_prompt_tokens; - stats.num_generated_tokens = num_generated; - double decode_ms = - static_cast(stats.inference_end_ms - stats.prompt_eval_end_ms); - printf( - "Decode: %" PRId64 " tokens in %.1f ms (%.1f tok/s)\n", - num_generated, - decode_ms, - num_generated * 1000.0 / decode_ms); - printf("Prompt tokens: %" PRId64 "\n", num_prompt_tokens); + stats.inference_end_ms = llm::time_in_ms(); + stats.num_generated_tokens = pos - num_prompt_tokens; #ifdef EXECUTORCH_BUILD_CUDA cudaMemGetInfo(&gpu_free_bytes, &gpu_total_bytes); diff --git a/examples/models/gemma4_31b/mlx_source_transformations.py b/examples/models/gemma4_31b/mlx_source_transformations.py new file mode 100644 index 00000000000..3a8ae4420e3 --- /dev/null +++ b/examples/models/gemma4_31b/mlx_source_transformations.py @@ -0,0 +1,192 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""MLX source transformations for Gemma 4 31B-IT. + +Replaces the generic PyTorch ops in the model with MLX custom ops that lower +to optimized Metal kernels: + +- ``torch.ops.mlx.rope`` for rotary position embeddings +- ``torch.ops.mlx.kv_cache_update`` for KV cache scatter (via MLX cache modules) +- ``torch.ops.mlx.custom_sdpa`` for scaled dot-product attention with GQA + +Applied at export time before ``torch.export`` — the model code in ``model.py`` +stays backend-agnostic. 
+""" + +import executorch.backends.mlx.custom_ops # noqa: F401 — registers mlx:: ops +import torch +import torch.nn as nn +from executorch.backends.mlx.llm.cache import ( + KVCache as MLXKVCache, + RingBufferKVCache as MLXRingKVCache, +) + + +def _replace_attention_forward(attn: nn.Module) -> None: + """Replace a Gemma4Attention's forward with one that uses MLX custom ops.""" + import types + + def _mlx_forward(self, x: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor: + B, T, _ = x.shape + start_pos = input_pos[0].item() + + q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim) + raw_k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim) + if self.k_eq_v: + raw_v = raw_k + else: + raw_v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim) + + q = self.q_norm(q) + k = self.k_norm(raw_k) + v = self.v_norm(raw_v) + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # RoPE via mlx::rope. + if self.is_sliding: + q = torch.ops.mlx.rope( + q, self.head_dim, start_pos, False, self.rope_theta, 1.0, None + ) + k = torch.ops.mlx.rope( + k, self.head_dim, start_pos, False, self.rope_theta, 1.0, None + ) + else: + # Full-attention layers use proportional partial RoPE: only + # rotary_dim out of head_dim dimensions are rotated. Pass + # dims=rotary_dim and the non-zero frequencies as 1D freqs. + # MLX computes inv_freq = 1/freqs internally. 
+ rotary_dim = int(self.head_dim * self.partial_rotary) + rotary_inv_freq = self.inv_freq[: rotary_dim // 2] + mlx_freqs = 1.0 / rotary_inv_freq + q = torch.ops.mlx.rope(q, rotary_dim, start_pos, False, 0.0, 1.0, mlx_freqs) + k = torch.ops.mlx.rope(k, rotary_dim, start_pos, False, 0.0, 1.0, mlx_freqs) + + k_cache, v_cache = self.kv_cache.update(start_pos, k, v) + + if self.is_sliding: + sdpa_mask = self.kv_cache.create_sliding_window_mask(start_pos, T) + y = torch.ops.mlx.custom_sdpa( + q, + k_cache, + v_cache, + start_pos=self.kv_cache.buffer_size - T, + attn_mask=sdpa_mask, + dropout_p=0.0, + is_causal=False, + scale=self.scaling, + ) + else: + y = torch.ops.mlx.custom_sdpa( + q, + k_cache, + v_cache, + start_pos=start_pos, + dropout_p=0.0, + is_causal=True, + scale=self.scaling, + ) + + y = y.transpose(1, 2).contiguous().view(B, T, self.n_heads * self.head_dim) + return self.o_proj(y) + + attn.forward = types.MethodType(_mlx_forward, attn) + + +def _replace_layer_forward(layer: nn.Module) -> None: + """Replace Gemma4DecoderLayer's forward to remove mask parameters.""" + import types + + def _mlx_layer_forward( + self, x: torch.Tensor, input_pos: torch.Tensor + ) -> torch.Tensor: + residual = x + h = self.input_layernorm(x) + h = self.self_attn(h, input_pos) + h = self.post_attention_layernorm(h) + x = residual + h + + residual = x + h = self.pre_feedforward_layernorm(x) + h = self.mlp(h) + h = self.post_feedforward_layernorm(h) + x = residual + h + + return x * self.layer_scalar + + layer.forward = types.MethodType(_mlx_layer_forward, layer) + + +def _replace_model_forward(model: nn.Module) -> None: + """Replace the top-level Gemma4_31B forward with a sampler-free, mask-free + ``(tokens, input_pos) → (B, V)`` variant. + + MLX samples on the host, so the on-device sampler and temperature input + are dropped. Each MLX attention builds its own mask via ``custom_sdpa``, + so ``_build_masks`` and the per-layer mask arguments are removed.
+ """ + import types + + def _mlx_model_forward( + self, tokens: torch.Tensor, input_pos: torch.Tensor + ) -> torch.Tensor: + x = self.embed_tokens(tokens) * self.embed_normalizer + for layer in self.layers: + x = layer(x, input_pos) + x = self.norm(x) + last = self.lm_head(x[:, -1, :]).float() + cap = self.logit_softcap.float() + return torch.tanh(last / cap) * cap + + model.forward = types.MethodType(_mlx_model_forward, model) + + +def mlx_source_transformations( + model: nn.Module, + dtype: torch.dtype = torch.bfloat16, +) -> None: + """Apply MLX source transformations to a Gemma 4 31B model in-place. + + Self-contained MLX adaptation. After calling this, the model has + signature ``(tokens, input_pos) → (B, V)`` logits — no temperature, + no sampler, no attention masks. + + - Replaces KV caches with MLX-optimized versions using ``mlx.kv_cache_update`` + - Rewrites attention forward to use ``mlx.rope`` and ``mlx.custom_sdpa`` + - Rewrites layer forward to drop mask parameters (each attention builds + its own mask via ``custom_sdpa``) + - Rewrites model forward to drop the sampler and ``_build_masks`` + """ + config = model.config + + for layer in model.layers: + attn = layer.self_attn + + if attn.is_sliding: + attn.kv_cache = MLXRingKVCache( + max_batch_size=1, + max_context_length=config.sliding_window, + n_heads=attn.n_kv_heads, + head_dim=attn.head_dim, + dtype=dtype, + ) + else: + attn.kv_cache = MLXKVCache( + max_batch_size=1, + max_context_length=config.max_seq_len, + n_heads=attn.n_kv_heads, + head_dim=attn.head_dim, + enable_dynamic_shape=True, + dtype=dtype, + ) + + _replace_attention_forward(attn) + _replace_layer_forward(layer) + + _replace_model_forward(model) diff --git a/examples/models/gemma4_31b/model.md b/examples/models/gemma4_31b/model.md index 51e420528f1..13207bdbb06 100644 --- a/examples/models/gemma4_31b/model.md +++ b/examples/models/gemma4_31b/model.md @@ -102,6 +102,8 @@ Decoder norms per layer: `input_layernorm`,
`post_attention_layernorm`, ## Methods exported (`export.py`) +### CUDA (`--backend cuda`) + | Method | Input | Output (sampled) | |-----------|------------------------------------------------------------|------------------| | `decode` | tokens `(1, 1)` + input_pos `(1,)` + temperature `(1,)` | `(1, 1)` float | @@ -113,6 +115,23 @@ Both methods share the same KV-cache buffers via sampling on-device and returns a single token ID per call so the C++ runner only has to feed tokens. +### MLX (`--backend mlx`) + +| Method | Input | Output | +|-----------|------------------------------------------|------------------| +| `forward` | tokens `(1, T)` + input_pos `(T,)`, T∈[1, min(max_seq_len-1, 2×sliding_window)] | `(1, V)` logits | + +Single method with dynamic sequence length. Only the last token's logits +are returned. The C++ runner samples on the host via `logits_to_token` +with temperature support. Int4Tensor weights are converted to +IntxUnpackedToInt8Tensor at pack time so the default `dequantize_affine → +linear` dispatch produces the pattern MLX's `QuantizedLinearHandler` fuses +into `QuantizedMatmulNode`. Source transforms (`mlx_source_transformations.py`) +replace generic PyTorch ops with `mlx.rope`, `mlx.kv_cache_update`, and +`mlx.custom_sdpa` for optimized Metal kernels. + +### Shared + Prefill length is capped to the ring-buffer KV cache size (`2 × sliding_window`) to avoid duplicate wrapped indices in `index_copy_`. The C++ runner chunks longer prompts automatically using @@ -130,9 +149,11 @@ Modules in `quant/`: `IntxUnpackedToInt8Tensor`) from fp weights. - **Serialization**: callers use torchao's safetensors integration (`torchao.prototype.safetensors`) directly — no wrapper module needed. -- **Pack** (`pack.py` + `pack_cuda.py`): `pack_model` groups weights by - parent module, `pack_one` handles single weights. Per-module packers - dispatch by module type (`nn.Linear`, `nn.Embedding`, extensible for MoE). 
+- **Pack** (`pack.py` + `pack_cuda.py` + `pack_mlx.py`): `pack_model` groups + weights by parent module, `pack_one` handles single weights. Per-module + packers dispatch by module type (`nn.Linear`, `nn.Embedding`). CUDA passes + Int4Tensor through (dispatch handled by `int4_dispatch.py`); MLX converts + Int4Tensor → IntxUnpackedToInt8Tensor and regroups per-axis embeddings. - **GGUF** (`gguf.py`): `unpack_gguf_tensor` / `iter_gguf_tensors` for loading community-quantized GGUF files (Q4_K, Q6_K). @@ -145,11 +166,12 @@ quantize_and_save.py export.py / inference.py | | quantize_weight() load (torchao safetensors) | | - Int4Tensor / IntxUnpacked Int4Tensor / IntxUnpacked (used directly) - | | - save (torchao safetensors) int4_dispatch routes to int4_plain_mm + Int4Tensor / IntxUnpacked pack for backend: | | - model.safetensors dp4a decode / dequant+cuBLAS prefill + save (torchao safetensors) CUDA: Int4Tensor passed through + | → int4_dispatch → dp4a / dequant+cuBLAS + model.safetensors MLX: Int4Tensor → IntxUnpacked(int4) + → dequantize_affine → QuantizedMatmulNode ``` `embed_tokens` and `lm_head` start tied; they are untied before @@ -181,14 +203,17 @@ These exist solely to make the model exportable / efficient under ExecuTorch: `2 × sliding_window`) saves memory for long sequences — positions wrap via modulo and the attention mask reconstructs which slots are valid. Full-attention layers use a flat `Gemma4KVCache` sized to `max_seq_len`. - Both use `index_copy_(dim=2, ...)` for trace-friendly updates. + CUDA uses `index_copy_` for trace-friendly updates; MLX source transforms + replace both caches with `mlx.kv_cache_update`-backed equivalents. - **On-the-fly RoPE**: stores only `inv_freq` per layer, computes cos/sin via `torch.outer(positions, inv_freq)` each forward. Saves memory vs precomputed `[max_seq_len, head_dim]` tables (sliding uses full RoPE, full uses proportional partial RoPE — head_dim and θ differ). 
-- **On-device Gumbel-max sampling** so the exported program emits a token - rather than a full logits tensor — keeps the runner GPU↔CPU traffic to a - single float per step. +- **Last-logits-only**: `lm_head` always runs on `x[:, -1, :]`, avoiding a + `(1, T, 262144)` matmul during prefill. +- **On-device Gumbel-max sampling** (CUDA) so the exported program emits a + token rather than logits — keeps GPU↔CPU traffic to a single float per + step. MLX samples on the host via `logits_to_token`. - **Final-logit softcap baked into the graph**, applied before sampling. - **Meta-device construction + assign-load** keeps peak memory small enough to load the 31B-parameter checkpoint on one machine. diff --git a/examples/models/gemma4_31b/model.py b/examples/models/gemma4_31b/model.py index f0aa2fac982..a690bd79230 100644 --- a/examples/models/gemma4_31b/model.py +++ b/examples/models/gemma4_31b/model.py @@ -470,20 +470,17 @@ def forward( self, tokens: torch.LongTensor, input_pos: torch.LongTensor, - temperature: Optional[torch.Tensor] = None, + temperature: torch.Tensor, ) -> torch.Tensor: """Run the model. Args: tokens: (B, T) token IDs. input_pos: (T,) absolute positions for RoPE / KV cache. - temperature: optional 1-D float tensor controlling on-device sampling. - When provided, returns sampled tokens (B, 1) via Gumbel-max; - when None (e.g. eager eval), returns full logits (B, T, V) with - soft-capping applied so callers see post-cap values. + temperature: 1-D float tensor for Gumbel-max sampling. Returns: - (B, 1) token IDs when sampling, else (B, T, V) float32 logits. + (B, 1) sampled token IDs as float. """ x = self.embed_tokens(tokens) * self.embed_normalizer @@ -492,13 +489,6 @@ def forward( x = layer(x, input_pos, sliding_mask, full_mask) x = self.norm(x) - - if temperature is None: - logits = self.lm_head(x).float() - cap = self.logit_softcap.float() - return torch.tanh(logits / cap) * cap - - # Decode-time fast path: only materialize logits for the last token. 
last = self.lm_head(x[:, -1, :]).float() cap = self.logit_softcap.float() last = torch.tanh(last / cap) * cap diff --git a/examples/models/gemma4_31b/quant/README.md b/examples/models/gemma4_31b/quant/README.md index 31b1c43d574..2eacced4387 100644 --- a/examples/models/gemma4_31b/quant/README.md +++ b/examples/models/gemma4_31b/quant/README.md @@ -9,7 +9,8 @@ Quantization framework: **recipe → quantize → pack**. | `recipe.py` | **Policy** — what to quantize, what precision, which layers | nothing | | `quantize.py` | **Computation** — produces torchao subclass tensors | recipe, torchao | | `pack.py` | **Packing dispatch** — `pack_model` (bulk) and `pack_one` (streaming) | — | -| `pack_cuda.py` | **CUDA packing** — converts Int4Tensor to tinygemm format | pack | +| `pack_cuda.py` | **CUDA packing** — passes Int4Tensor/IntxUnpacked through for CUDA dispatch | pack | +| `pack_mlx.py` | **MLX packing** — converts Int4Tensor → IntxUnpacked, regroups per-axis embeddings | pack | | `gguf.py` | **GGUF import** — unpacks Q4_K/Q6_K blocks to torchao subclasses | torchao | ## Data flow @@ -48,7 +49,6 @@ The format is compatible with torchao's `save_pretrained` / `load_pretrained`. ## TODO - `pack_metal.py` — Metal backend packer. -- `pack_mlx.py` — MLX backend packer. - `gguf.py` — extend with Q5_K, Q8_0 GGUF quant types. - Upstream `Int4TilePackedTo4dTensor.from_int4_tensor()` to torchao to replace the manual conversion in `pack_int4_for_cuda`. 
diff --git a/examples/models/gemma4_31b/quant/__init__.py b/examples/models/gemma4_31b/quant/__init__.py index 93efb69865f..7e9ab97a1bb 100644 --- a/examples/models/gemma4_31b/quant/__init__.py +++ b/examples/models/gemma4_31b/quant/__init__.py @@ -6,5 +6,6 @@ from .pack import ModulePackerFn, pack_model, pack_one # noqa: F401 from .pack_cuda import DEFAULT_CUDA_PACKERS, load_and_pack_for_cuda # noqa: F401 +from .pack_mlx import DEFAULT_MLX_PACKERS, load_and_pack_for_mlx # noqa: F401 from .quantize import dequantize_weight, quantize_model, quantize_weight # noqa: F401 from .recipe import QuantConfig, QuantRecipe, QuantRule # noqa: F401 diff --git a/examples/models/gemma4_31b/quant/pack_mlx.py b/examples/models/gemma4_31b/quant/pack_mlx.py new file mode 100644 index 00000000000..63aeca426a8 --- /dev/null +++ b/examples/models/gemma4_31b/quant/pack_mlx.py @@ -0,0 +1,198 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""MLX packer: convert quantized weights to MLX-compatible format. + +MLX's ``QuantizedLinearHandler`` matches ``dequantize_affine → linear`` +in the exported graph. ``IntxUnpackedToInt8Tensor`` produces this +pattern naturally, but ``Int4Tensor`` does not (its dispatch calls +CUDA-specific mslk kernels). So INT4 weights are converted to +``IntxUnpackedToInt8Tensor(target_dtype=torch.int4)`` at pack time. + +The backend-agnostic ``pack_model`` dispatcher lives in ``pack.py``. 
+""" + +import json + +import torch +import torch.nn as nn + +from .pack import ModulePackerFn, pack_model # noqa: F401 + +_MLX_SUPPORTED_GROUP_SIZES = (128, 64, 32) + + +# --------------------------------------------------------------------------- +# Int4Tensor → IntxUnpackedToInt8Tensor conversion + + +def _int4_to_intx_unpacked(w: torch.Tensor) -> torch.Tensor: + """Convert an ``Int4Tensor`` to ``IntxUnpackedToInt8Tensor``. + + Int4Tensor stores qdata as nibble-packed uint8 ``(N, K/2)`` with + scale/zero transposed to ``(K//gs, N)``. IntxUnpackedToInt8Tensor + stores qdata as int8 ``(N, K)`` with scale/zero as ``(N, K//gs)``. + """ + from torchao.quantization import IntxUnpackedToInt8Tensor + + # Unpack nibbles: packed = even | (odd << 4), unsigned [0, 15] + p = w.qdata.to(torch.uint8) + low = (p & 0x0F).to(torch.int8) + high = ((p >> 4) & 0x0F).to(torch.int8) + qdata = torch.stack([low, high], dim=-1).reshape(w.shape) + + # Shift unsigned [0, 15] → signed [-8, 7] + qdata = qdata - 8 + + gs = w.block_size[-1] + + # Transpose scale/zero from (K//gs, N) → (N, K//gs) + scale = w.scale.t().contiguous() + zero_point = (w.zero_point - 8).t().contiguous() + + return IntxUnpackedToInt8Tensor( + qdata=qdata, + scale=scale, + zero_point=zero_point, + target_dtype=torch.int4, + block_size=(1, gs), + dtype=scale.dtype, + activation_quantization=None, + ) + + +# --------------------------------------------------------------------------- +# Embedding group_size regrouping + + +def _mlx_group_size(gs: int, K: int) -> int: + """Find an MLX-compatible group_size for the given weight group_size. + + If ``gs`` is already in {32, 64, 128}, return it. Otherwise find the + largest supported group_size that divides ``gs`` so per-axis scales can + be repeated to fill finer groups. 
+ """ + if gs in _MLX_SUPPORTED_GROUP_SIZES: + return gs + for candidate in _MLX_SUPPORTED_GROUP_SIZES: + if gs % candidate == 0 and K % candidate == 0: + return candidate + raise ValueError( + f"MLX requires group_size in {set(_MLX_SUPPORTED_GROUP_SIZES)} " + f"(or a multiple thereof), got {gs}" + ) + + +def _regroup_intx(w: torch.Tensor, new_gs: int) -> torch.Tensor: + """Regroup an ``IntxUnpackedToInt8Tensor`` to a finer group_size.""" + from torchao.quantization import IntxUnpackedToInt8Tensor + + old_gs = w.block_size[-1] + if old_gs % new_gs != 0: + raise ValueError( + f"new group_size {new_gs} must evenly divide old group_size {old_gs}" + ) + repeat_factor = old_gs // new_gs + N = w.qdata.shape[0] + n_groups = w.qdata.shape[-1] // new_gs + + scale = w.scale.repeat_interleave(repeat_factor, dim=-1).reshape(N, n_groups) + zero_point = w.zero_point.repeat_interleave(repeat_factor, dim=-1).reshape( + N, n_groups + ) + + return IntxUnpackedToInt8Tensor( + qdata=w.qdata, + scale=scale, + zero_point=zero_point, + target_dtype=w.target_dtype, + block_size=(1, new_gs), + dtype=w.dtype, + activation_quantization=w.activation_quantization, + ) + + +# --------------------------------------------------------------------------- +# Per-module packer + + +def pack_for_mlx(module: nn.Module, weights: dict[str, torch.Tensor]) -> None: + """Pack a quantized weight for MLX. + + ``Int4Tensor`` is converted to ``IntxUnpackedToInt8Tensor`` so the + default dispatch produces the ``dequantize_affine → linear`` pattern + MLX expects. Regroups to a compatible group_size when needed (e.g. + per-axis group_size=5376 → group_size=128) since MLX's + ``parse_dequant_node`` only accepts group_size in {32, 64, 128}. 
+ """ + from torchao.quantization import IntxUnpackedToInt8Tensor + from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor + + w = weights["weight"] + if isinstance(w, Int4Tensor): + w = _int4_to_intx_unpacked(w) + if isinstance(w, IntxUnpackedToInt8Tensor): + gs = w.block_size[-1] + K = w.qdata.shape[-1] + target_gs = _mlx_group_size(gs, K) + if target_gs != gs: + w = _regroup_intx(w, target_gs) + module.weight = nn.Parameter(w, requires_grad=False) + + +DEFAULT_MLX_PACKERS: dict[type, ModulePackerFn] = { + nn.Linear: pack_for_mlx, + nn.Embedding: pack_for_mlx, +} + + +# --------------------------------------------------------------------------- +# Load + pack (I/O wrapper) + + +def load_and_pack_for_mlx( + path: str, + model: nn.Module, + packers: dict[type, ModulePackerFn] | None = None, +) -> None: + """Load a quantized safetensors file and pack for MLX. + + Streams one weight at a time via torchao's safetensors support. + """ + from safetensors import safe_open + from torchao.prototype.safetensors.safetensors_support import ( + unflatten_tensor_state_dict, + ) + + from .pack import pack_one + + _packers = packers or DEFAULT_MLX_PACKERS + with safe_open(path, framework="pt", device="cpu") as f: + metadata = f.metadata() + all_keys = list(f.keys()) + tensor_names = json.loads(metadata.get("tensor_names", "[]")) + + for name in tensor_names: + parts = name.rsplit(".", 1) + module_fqn = parts[0] if len(parts) > 1 else "" + weight_name = parts[-1] + prefix = ( + f"{module_fqn}._{weight_name}_" if module_fqn else f"_{weight_name}_" + ) + partial = {} + for key in all_keys: + if key.startswith(prefix) or key == name: + partial[key] = f.get_tensor(key) + result, _ = unflatten_tensor_state_dict(partial, metadata) + for fqn, value in result.items(): + pack_one(model, fqn, value, _packers) + + for fqn, p in model.named_parameters(): + if p.device.type == "meta": + raise RuntimeError( + f"Weight '{fqn}' not found in checkpoint " + 
f"(model/checkpoint version mismatch?)" + ) diff --git a/examples/models/gemma4_31b/quant/tests/test_pack_mlx.py b/examples/models/gemma4_31b/quant/tests/test_pack_mlx.py new file mode 100644 index 00000000000..ffb2e0e2dd3 --- /dev/null +++ b/examples/models/gemma4_31b/quant/tests/test_pack_mlx.py @@ -0,0 +1,219 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Unit tests for quant/pack_mlx.py. No CUDA or MLX hardware required.""" + +import unittest + +import torch +import torch.nn as nn + +from executorch.examples.models.gemma4_31b.quant.pack import pack_model +from executorch.examples.models.gemma4_31b.quant.pack_mlx import ( + _int4_to_intx_unpacked, + _mlx_group_size, + DEFAULT_MLX_PACKERS, + pack_for_mlx, +) +from executorch.examples.models.gemma4_31b.quant.quantize import ( + dequantize_weight, + quantize_weight, +) +from executorch.examples.models.gemma4_31b.quant.recipe import QuantConfig + + +class TestInt4ToIntxConversion(unittest.TestCase): + """Int4Tensor → IntxUnpackedToInt8Tensor conversion.""" + + def test_symmetric_dequant_matches(self): + """Converted weight dequantizes to same values as original.""" + torch.manual_seed(0) + weight = torch.randn(64, 128, dtype=torch.bfloat16) + config = QuantConfig(bits=4, group_size=32, symmetric=True, method="min_max") + int4_w = quantize_weight(weight, config) + intx_w = _int4_to_intx_unpacked(int4_w) + + int4_dense = dequantize_weight(int4_w, torch.float32) + intx_dense = dequantize_weight(intx_w, torch.float32) + self.assertTrue( + torch.allclose(int4_dense, intx_dense, atol=1e-5), + f"max diff: {(int4_dense - intx_dense).abs().max():.6g}", + ) + + def test_asymmetric_dequant_matches(self): + torch.manual_seed(0) + weight = torch.randn(64, 128, dtype=torch.bfloat16) + config = QuantConfig(bits=4, group_size=32, symmetric=False, 
method="min_max") + int4_w = quantize_weight(weight, config) + intx_w = _int4_to_intx_unpacked(int4_w) + + int4_dense = dequantize_weight(int4_w, torch.float32) + intx_dense = dequantize_weight(intx_w, torch.float32) + self.assertTrue( + torch.allclose(int4_dense, intx_dense, atol=1e-5), + f"max diff: {(int4_dense - intx_dense).abs().max():.6g}", + ) + + def test_output_type_and_shape(self): + from torchao.quantization import IntxUnpackedToInt8Tensor + + torch.manual_seed(0) + config = QuantConfig(bits=4, group_size=32, symmetric=True, method="min_max") + int4_w = quantize_weight(torch.randn(128, 256, dtype=torch.bfloat16), config) + intx_w = _int4_to_intx_unpacked(int4_w) + + self.assertIsInstance(intx_w, IntxUnpackedToInt8Tensor) + self.assertEqual(intx_w.shape, torch.Size([128, 256])) + self.assertEqual(intx_w.qdata.shape, torch.Size([128, 256])) + self.assertEqual(intx_w.target_dtype, torch.int4) + + def test_different_group_sizes(self): + torch.manual_seed(0) + for gs in (32, 64, 128): + with self.subTest(group_size=gs): + config = QuantConfig( + bits=4, group_size=gs, symmetric=True, method="min_max" + ) + int4_w = quantize_weight( + torch.randn(64, 256, dtype=torch.bfloat16), config + ) + intx_w = _int4_to_intx_unpacked(int4_w) + self.assertEqual(intx_w.shape, torch.Size([64, 256])) + + def test_matmul_approximates_original(self): + torch.manual_seed(0) + weight = torch.randn(256, 128, dtype=torch.bfloat16) + x = torch.randn(1, 128, dtype=torch.bfloat16) + original_out = torch.nn.functional.linear(x, weight) + + config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") + int4_w = quantize_weight(weight, config) + intx_w = _int4_to_intx_unpacked(int4_w) + packed_out = torch.nn.functional.linear(x, intx_w.dequantize()) + + rel_error = ( + packed_out.float() - original_out.float() + ).abs().mean() / original_out.float().abs().mean() + self.assertLess(rel_error.item(), 0.15) + + +class TestPackLinearForMlx(unittest.TestCase): + def 
test_int4_converts_to_intx(self): + from torchao.quantization import IntxUnpackedToInt8Tensor + + module = nn.Linear(128, 64, bias=False) + config = QuantConfig(bits=4, group_size=32, symmetric=True, method="min_max") + w = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), config) + pack_for_mlx(module, {"weight": w}) + + self.assertIsInstance(module.weight.data, IntxUnpackedToInt8Tensor) + self.assertEqual(module.weight.shape, torch.Size([64, 128])) + self.assertFalse(module.weight.requires_grad) + + def test_int8_passes_through(self): + from torchao.quantization import IntxUnpackedToInt8Tensor + + module = nn.Linear(128, 64, bias=False) + config = QuantConfig(bits=8, group_size=32, symmetric=True, method="min_max") + w = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), config) + self.assertIsInstance(w, IntxUnpackedToInt8Tensor) + pack_for_mlx(module, {"weight": w}) + + self.assertIsInstance(module.weight.data, IntxUnpackedToInt8Tensor) + self.assertEqual(module.weight.shape, torch.Size([64, 128])) + + def test_regroup_preserves_dequant(self): + """Linear with non-standard group_size regroups and dequantizes correctly.""" + torch.manual_seed(0) + weight = torch.randn(64, 256, dtype=torch.bfloat16) + config = QuantConfig(bits=8, group_size=256, symmetric=True, method="min_max") + w = quantize_weight(weight, config) + before = dequantize_weight(w, torch.float32) + + module = nn.Linear(256, 64, bias=False) + pack_for_mlx(module, {"weight": w}) + + self.assertEqual(module.weight.data.block_size, (1, 128)) + after = dequantize_weight(module.weight.data, torch.float32) + self.assertTrue( + torch.allclose(before, after, atol=1e-5), + f"max diff: {(before - after).abs().max():.6g}", + ) + + +class TestMlxGroupSize(unittest.TestCase): + def test_passthrough(self): + for gs in (32, 64, 128): + self.assertEqual(_mlx_group_size(gs, 256), gs) + + def test_regroup_5376(self): + self.assertEqual(_mlx_group_size(5376, 5376), 128) + + def 
test_regroup_256(self): + self.assertEqual(_mlx_group_size(256, 256), 128) + + def test_rejects_indivisible(self): + with self.assertRaises(ValueError): + _mlx_group_size(48, 48) + + +class TestPackEmbeddingForMlx(unittest.TestCase): + def test_compatible_passes_through(self): + module = nn.Embedding(100, 64) + config = QuantConfig(bits=8, group_size=32, symmetric=True, method="min_max") + w = quantize_weight(torch.randn(100, 64, dtype=torch.bfloat16), config) + pack_for_mlx(module, {"weight": w}) + self.assertEqual(module.weight.shape, torch.Size([100, 64])) + + def test_per_axis_regroups(self): + module = nn.Embedding(50, 256) + config = QuantConfig(bits=8, group_size=256, symmetric=True, method="min_max") + w = quantize_weight(torch.randn(50, 256, dtype=torch.bfloat16), config) + pack_for_mlx(module, {"weight": w}) + self.assertEqual(module.weight.shape, torch.Size([50, 256])) + self.assertEqual(module.weight.data.block_size, (1, 128)) + + def test_int4_converts_to_intx(self): + from torchao.quantization import IntxUnpackedToInt8Tensor + + module = nn.Embedding(100, 64) + config = QuantConfig(bits=4, group_size=32, symmetric=True, method="min_max") + w = quantize_weight(torch.randn(100, 64, dtype=torch.bfloat16), config) + pack_for_mlx(module, {"weight": w}) + self.assertIsInstance(module.weight.data, IntxUnpackedToInt8Tensor) + self.assertEqual(module.weight.shape, torch.Size([100, 64])) + + +class TestPackModelMlx(unittest.TestCase): + def test_mixed_precision(self): + q4 = QuantConfig(bits=4, group_size=32, symmetric=True, method="min_max") + q8 = QuantConfig(bits=8, group_size=32, symmetric=True, method="min_max") + w4 = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), q4) + w8 = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), q8) + + state_dict = { + "q_proj.weight": w4, + "v_proj.weight": w8, + "norm.weight": torch.randn(64, dtype=torch.bfloat16), + } + + with torch.device("meta"): + model = nn.ModuleDict( + { + "q_proj": 
nn.Linear(128, 64, bias=False), + "v_proj": nn.Linear(128, 64, bias=False), + "norm": nn.LayerNorm(64, bias=False), + } + ) + pack_model(model, state_dict, DEFAULT_MLX_PACKERS) + + self.assertEqual(model.q_proj.weight.shape, torch.Size([64, 128])) + self.assertEqual(model.v_proj.weight.shape, torch.Size([64, 128])) + self.assertEqual(model.norm.weight.shape, torch.Size([64])) + + +if __name__ == "__main__": + unittest.main() diff --git a/examples/models/gemma4_31b/sampler.py b/examples/models/gemma4_31b/sampler.py index 45e4e17887a..690344fd2e4 100644 --- a/examples/models/gemma4_31b/sampler.py +++ b/examples/models/gemma4_31b/sampler.py @@ -8,33 +8,26 @@ Mirrors ``examples/models/qwen3_5_moe/sampler.py``: a single-output sampler that lets one exported program be re-driven with different temperatures -without re-export. ``temperature=None`` is a no-op (returns logits). +without re-export. """ -from typing import Optional - import torch def sample( logits: torch.Tensor, - temperature: Optional[torch.Tensor] = None, + temperature: torch.Tensor, ) -> torch.Tensor: """Draw a single token per batch row using the Gumbel-max trick. Args: logits: ``[B, V]`` float32 logits (already soft-capped if applicable). temperature: 0-D or 1-D float tensor; clamped to >= 1e-6 so a 0 - temperature still works ("near-greedy"). When ``None`` the call - short-circuits and returns ``logits`` unchanged. + temperature still works ("near-greedy"). Returns: - ``[B, 1]`` float32 token IDs (``argmax(logits/T + gumbel_noise)``), - or the unmodified logits when ``temperature`` is ``None``. + ``[B, 1]`` float32 token IDs (``argmax(logits/T + gumbel_noise)``). 
""" - if temperature is None: - return logits - logits = logits / temperature.clamp(min=1e-6) noise = torch.rand_like(logits) gumbel = -torch.log(-torch.log(noise + 1e-20) + 1e-20) diff --git a/examples/models/gemma4_31b/tests/test_cuda_pipeline.py b/examples/models/gemma4_31b/tests/test_cuda_pipeline.py index 0ff28aac415..505d6f7bdc1 100644 --- a/examples/models/gemma4_31b/tests/test_cuda_pipeline.py +++ b/examples/models/gemma4_31b/tests/test_cuda_pipeline.py @@ -108,30 +108,27 @@ def test_chunked_prefill_matches_sequential(self): torch.manual_seed(0) prompt = torch.randint(0, config.vocab_size, (1, prompt_len), device="cuda") + temp = torch.tensor([1e-6], dtype=torch.float32, device="cuda") + with torch.no_grad(): for i in range(prompt_len): tok = prompt[:, i : i + 1] pos = torch.tensor([i], dtype=torch.long, device="cuda") - logits_seq = model_seq(tok, pos, None) + token_seq = model_seq(tok, pos, temp) with torch.no_grad(): chunk1 = prompt[:, :buf_size] pos1 = torch.arange(buf_size, dtype=torch.long, device="cuda") - model_chunk(chunk1, pos1, None) + model_chunk(chunk1, pos1, temp) chunk2 = prompt[:, buf_size:] pos2 = torch.arange(buf_size, prompt_len, dtype=torch.long, device="cuda") - logits_chunk = model_chunk(chunk2, pos2, None) - - max_diff = (logits_seq[0, -1].float() - logits_chunk[0, -1].float()).abs().max() - self.assertTrue( - torch.allclose( - logits_seq[0, -1].float(), - logits_chunk[0, -1].float(), - atol=1e-2, - rtol=1e-3, - ), - f"Chunked prefill diverged: max_diff={max_diff:.4g}", + token_chunk = model_chunk(chunk2, pos2, temp) + + self.assertEqual( + int(token_seq.item()), + int(token_chunk.item()), + "Chunked prefill produced different token than sequential", ) diff --git a/examples/models/gemma4_31b/tests/test_mlx_pipeline.py b/examples/models/gemma4_31b/tests/test_mlx_pipeline.py new file mode 100644 index 00000000000..0e62ab88e4b --- /dev/null +++ b/examples/models/gemma4_31b/tests/test_mlx_pipeline.py @@ -0,0 +1,248 @@ +# Copyright (c) 
Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""End-to-end MLX backend tests for the Gemma 4 31B-IT pipeline. + +Tests quantize → save → load → pack-for-MLX on a tiny model. +No CUDA or MLX hardware required. + +Usage: + python -m pytest examples/models/gemma4_31b/tests/test_mlx_pipeline.py -v +""" + +import json +import os +import tempfile +import unittest + +import torch +import torch.nn as nn + +from executorch.examples.models.gemma4_31b.model import Gemma4_31B +from executorch.examples.models.gemma4_31b.quant import ( + DEFAULT_MLX_PACKERS, + pack_model, + QuantConfig, + quantize_model, + QuantRecipe, + QuantRule, +) +from executorch.examples.models.gemma4_31b.tests.test_pipeline import ( + build_random_tiny_model, + config_dict, + save_checkpoint, + TINY_CONFIG, +) + +_INT4 = QuantConfig(bits=4, group_size=32, symmetric=True, method="min_max") +_INT8 = QuantConfig(bits=8, group_size=32, symmetric=True, method="min_max") +_INT8_PER_AXIS = QuantConfig( + bits=8, group_size=TINY_CONFIG.hidden_size, symmetric=True, method="min_max" +) +_EDGE_LAYERS = set(range(3)) + +TINY_SENSITIVE_RECIPE = QuantRecipe( + rules=[ + QuantRule(r"embed_tokens\.weight", _INT8_PER_AXIS), + QuantRule(r".*norm\.weight", None), + QuantRule(r".*\.(v_proj|down_proj)\.weight", _INT8, layers=_EDGE_LAYERS), + QuantRule(r".*\.weight", _INT4), + ] +) + + +class TestMlxPipeline(unittest.TestCase): + """End-to-end: quantize → pack for MLX → forward.""" + + def test_pack_for_mlx(self): + """Quantize with sensitive recipe, pack for MLX, no meta weights.""" + model = build_random_tiny_model() + model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) + state_dict = quantize_model(model, TINY_SENSITIVE_RECIPE) + + with torch.device("meta"): + model = Gemma4_31B(TINY_CONFIG) + model.lm_head.weight = 
nn.Parameter(model.embed_tokens.weight.clone()) + pack_model(model, state_dict, DEFAULT_MLX_PACKERS) + + for fqn, p in model.named_parameters(): + self.assertNotEqual(p.device.type, "meta", f"Weight '{fqn}' still on meta") + + def test_forward_after_pack(self): + """Model produces valid output after MLX packing.""" + model = build_random_tiny_model() + model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) + state_dict = quantize_model(model, TINY_SENSITIVE_RECIPE) + + with torch.device("meta"): + model = Gemma4_31B(TINY_CONFIG) + model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) + pack_model(model, state_dict, DEFAULT_MLX_PACKERS) + model.eval() + + from executorch.examples.models.gemma4_31b.model import ( + materialize_runtime_buffers, + ) + + materialize_runtime_buffers(model, dtype=torch.bfloat16) + + tokens = torch.randint(0, TINY_CONFIG.vocab_size, (1, 1)) + input_pos = torch.tensor([0], dtype=torch.long) + temp = torch.tensor([1e-6], dtype=torch.float32) + + with torch.no_grad(): + out = model(tokens, input_pos, temp) + + self.assertEqual(out.shape, torch.Size([1, 1])) + self.assertFalse(torch.isnan(out).any()) + + def test_multi_token_forward(self): + model = build_random_tiny_model() + model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) + state_dict = quantize_model(model, TINY_SENSITIVE_RECIPE) + + with torch.device("meta"): + model = Gemma4_31B(TINY_CONFIG) + model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) + pack_model(model, state_dict, DEFAULT_MLX_PACKERS) + model.eval() + + from executorch.examples.models.gemma4_31b.model import ( + materialize_runtime_buffers, + ) + + materialize_runtime_buffers(model, dtype=torch.bfloat16) + + seq_len = 4 + tokens = torch.randint(0, TINY_CONFIG.vocab_size, (1, seq_len)) + input_pos = torch.arange(seq_len, dtype=torch.long) + temp = torch.tensor([1e-6], dtype=torch.float32) + + with torch.no_grad(): + out = model(tokens, input_pos, temp) + 
+ + self.assertEqual(out.shape, torch.Size([1, 1])) + self.assertFalse(torch.isnan(out).any()) + + def test_source_transforms_forward(self): + """Model produces valid output after MLX source transforms.""" + model = build_random_tiny_model() + model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) + state_dict = quantize_model(model, TINY_SENSITIVE_RECIPE) + + with torch.device("meta"): + model = Gemma4_31B(TINY_CONFIG) + model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) + pack_model(model, state_dict, DEFAULT_MLX_PACKERS) + model.eval() + + from executorch.examples.models.gemma4_31b.mlx_source_transformations import ( + mlx_source_transformations, + ) + from executorch.examples.models.gemma4_31b.model import ( + materialize_runtime_buffers, + ) + + mlx_source_transformations(model, dtype=torch.bfloat16) + materialize_runtime_buffers(model, dtype=torch.bfloat16) + + # After source transforms: signature is (tokens, input_pos) → (B, V) + # Single-token decode + tokens = torch.randint(0, TINY_CONFIG.vocab_size, (1, 1)) + input_pos = torch.tensor([0], dtype=torch.long) + with torch.no_grad(): + out = model(tokens, input_pos) + self.assertEqual(out.shape, torch.Size([1, TINY_CONFIG.vocab_size])) + self.assertFalse(torch.isnan(out).any()) + self.assertFalse(torch.isinf(out).any()) + + # Multi-token prefill + seq_len = 4 + tokens = torch.randint(0, TINY_CONFIG.vocab_size, (1, seq_len)) + input_pos = torch.arange(seq_len, dtype=torch.long) + with torch.no_grad(): + out = model(tokens, input_pos) + self.assertEqual(out.shape, torch.Size([1, TINY_CONFIG.vocab_size])) + self.assertFalse(torch.isnan(out).any()) + + def test_source_transforms_use_mlx_ops(self): + """Verify the traced graph contains the expected MLX custom ops.
+ + Each attention layer should produce: + - 2× ``mlx.rope`` (q and k) + - 2× ``mlx.kv_cache_update`` (k and v) + - 1× ``mlx.custom_sdpa`` + """ + from executorch.examples.models.gemma4_31b.mlx_source_transformations import ( + mlx_source_transformations, + ) + from executorch.examples.models.gemma4_31b.model import ( + materialize_runtime_buffers, + ) + from torch.export import Dim, export + + model = build_random_tiny_model() + model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) + state_dict = quantize_model(model, TINY_SENSITIVE_RECIPE) + + with torch.device("meta"): + model = Gemma4_31B(TINY_CONFIG) + model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) + pack_model(model, state_dict, DEFAULT_MLX_PACKERS) + model.eval() + + mlx_source_transformations(model, dtype=torch.bfloat16) + materialize_runtime_buffers(model, dtype=torch.bfloat16) + + # Trace with dynamic seq_len matching the MLX export shape. + seq_dim = Dim("seq", min=1, max=8) + ep = export( + model, + (torch.tensor([[1, 2]]), torch.tensor([0, 1])), + dynamic_shapes=({1: seq_dim}, {0: seq_dim}), + strict=True, + ) + + op_counts = {"rope": 0, "kv_cache_update": 0, "custom_sdpa": 0} + for node in ep.graph.nodes: + if node.op != "call_function": + continue + name = str(node.target) + for op in op_counts: + if f"mlx.{op}" in name: + op_counts[op] += 1 + + n_layers = TINY_CONFIG.num_hidden_layers + self.assertEqual(op_counts["rope"], 2 * n_layers, f"got {op_counts}") + self.assertEqual(op_counts["kv_cache_update"], 2 * n_layers, f"got {op_counts}") + self.assertEqual(op_counts["custom_sdpa"], n_layers, f"got {op_counts}") + + def test_export_to_pte(self): + """Full export: quantize → pack → export with MLXPartitioner.""" + try: + from executorch.backends.mlx import MLXPartitioner # noqa: F401 + except ImportError: + self.skipTest("MLX backend not available") + + from executorch.examples.models.gemma4_31b.export import ( + export_and_lower, + load_prequantized_model, + ) 
+ + with tempfile.TemporaryDirectory() as ckpt_dir, tempfile.TemporaryDirectory() as out_dir: + save_checkpoint(ckpt_dir) + with open(os.path.join(ckpt_dir, "config.json"), "w") as f: + json.dump(config_dict(), f) + + model, config = load_prequantized_model( + ckpt_dir, max_seq_len=TINY_CONFIG.max_seq_len, backend="mlx" + ) + export_and_lower(model, config, out_dir, backend="mlx") + self.assertTrue(os.path.exists(os.path.join(out_dir, "model.pte"))) + + +if __name__ == "__main__": + unittest.main()