A complete PyTorch reimplementation in Rust, purpose-built as the tensor backend for the Magpie compiler. torch-magpie provides GPU-accelerated deep learning on Apple Silicon through MLX and Metal, with a handle-based FFI surface that the Magpie compiler links against at compile time.
Modern deep learning frameworks like PyTorch rely on eager, per-operation GPU dispatch, which introduces overhead for every tensor operation. torch-magpie takes a different approach: it fuses entire transformer layers into lazy MLX computation graphs that are submitted as single Metal command buffers, eliminating intermediate tensor allocations and per-op dispatch costs. Combined with GPU-resident KV-caches and memory-efficient tiled attention (via MLX's scaled dot-product attention kernel), torch-magpie achieves 1.4--2.9x speedups over PyTorch MPS on BERT training and GPT-2 inference, reaches parity on Llama-3.2-1B decode throughput, and uniquely supports 100k-token prefill sequences that PyTorch MPS cannot run at all.
Benchmarked on Apple M1 Pro (10-core CPU, 16-core GPU), PyTorch 2.8.0 MPS vs torch-magpie MLX:
| Model | PyTorch (MPS) | torch-magpie (MLX) | Speedup |
|---|---|---|---|
| BERT Training (6L/768H, batch=4) | 155.9 ms/step | 110.8 ms/step | 1.41x |
| GPT-2 Decode (6L/768H, greedy) | 128.6 tok/s | 294.5 tok/s | 2.29x |
| GPT-2 Prefill (128 tokens) | 19.4 ms | 6.8 ms | 2.85x |
| Llama-3.2-1B Decode (real weights) | 29.1 tok/s | 28.5 tok/s | ~parity |
| Llama-3.2-1B Prefill (32 tok) | 81.1 ms | 51.1 ms | 1.59x |
| Llama-3.2-1B Prefill (100k tok) | FAILED | 316 tok/s | torch-magpie only |
The 100k-token prefill failure on PyTorch MPS is due to its SDPA implementation materializing the full O(S²) attention matrix (~1,192 GiB at that sequence length), while torch-magpie uses MLX's tiled, O(S)-memory attention kernel.
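One plausible accounting of that figure, assuming fp32 attention scores and Llama-3.2-1B's 32 attention heads, per layer:

```
attention scores per layer = S² × heads × sizeof(f32)
                           = 100,000² × 32 × 4 B
                           = 1.28 × 10¹² B ≈ 1,192 GiB
```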
torch-magpie is organized as a Rust workspace of 24 crates in 4 dependency layers:

```
Layer 3   torch_ffi_lib ────── Magpie FFI exports (staticlib, 68 extern "C" functions)
          torch_onnx, torch_ext, torch_backward_gen

Layer 2   torch_nn (170+ modules)    torch_optim (13 optimizers)
          torch_serialize            torch_linalg, torch_fft, torch_special
          torch_distributions       torch_sparse, torch_data
          torch_profiler            torch_tensorboard, torch_amp, torch_quant

Layer 1   torch_autograd ───── reverse-mode autodiff engine
          torch_cpu ────────── CPU kernels (arithmetic, math, reduction, shape, indexing)
          torch_mlx ────────── Apple MLX + Metal gap-fill kernels
          torch_cuda ───────── CUDA stubs    torch_rocm ── ROCm stubs

Layer 0   torch_core ───────── Tensor, Storage, DType, Device, dispatch, allocator, RNG
          torch_shims ──────── FFI bindings to vendor libraries
```
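Layer 0 is the only layer every other crate sees. As a rough, hypothetical sketch of the core types it exposes (the real torch_core definitions will differ in detail):

```rust
// Hypothetical shapes of the Layer-0 core types -- illustrative only.
use std::sync::Arc;

#[derive(Clone, Copy, PartialEq, Debug)]
pub enum Device { Cpu, Mlx, Cuda, Rocm }

#[derive(Clone, Copy, PartialEq, Debug)]
pub enum DType { F32, F16, BF16, I64 /* ... 14 dtypes in total */ }

pub struct Storage {
    bytes: Vec<u8>,   // on the MLX device this would wrap a GPU buffer instead
    device: Device,
}

pub struct Tensor {
    storage: Arc<Storage>, // shared so views can alias one allocation
    dtype: DType,
    shape: Vec<usize>,
    strides: Vec<usize>,   // strided layout, as in PyTorch
}

impl Tensor {
    pub fn device(&self) -> Device { self.storage.device }
}
```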
Every tensor operation follows a three-tier dispatch chain that prefers the GPU at each step (a sketch follows the list):
- MLX -- lazy evaluation on the Apple Silicon GPU via the MLX library. Operations are deferred and batched into a single Metal command buffer.
- Metal -- custom Metal Shading Language kernels for operations MLX does not cover (batched matmul, elementwise ops, softmax).
- CPU -- pure-Rust implementation as the final fallback.
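A minimal sketch of that chain, reusing the hypothetical Layer-0 types above. The backend entry points (`mlx_try_add`, `metal_try_add`, `cpu_add`) are illustrative names, not the actual torch_core/torch_mlx API:

```rust
// Illustrative backend entry points. A backend returns None when it
// cannot handle the op, so dispatch falls through to the next tier.
fn mlx_try_add(_a: &Tensor, _b: &Tensor) -> Option<Tensor> { /* record lazy MLX op */ None }
fn metal_try_add(_a: &Tensor, _b: &Tensor) -> Option<Tensor> { /* custom MSL kernel */ None }
fn cpu_add(_a: &Tensor, _b: &Tensor) -> Tensor { /* pure-Rust kernel */ unimplemented!() }

fn dispatch_add(a: &Tensor, b: &Tensor) -> Tensor {
    if a.device() == Device::Mlx && b.device() == Device::Mlx {
        // Tier 1: defer into the lazy MLX graph (no GPU work submitted yet).
        if let Some(out) = mlx_try_add(a, b) {
            return out;
        }
        // Tier 2: gap-fill with a hand-written Metal kernel.
        if let Some(out) = metal_try_add(a, b) {
            return out;
        }
    }
    // Tier 3: eager CPU fallback.
    cpu_add(a, b)
}
```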
- Handle-based FFI: All cross-boundary types use `i64` opaque handles (pointer-as-integer). Tensor handles are `Box`-managed; module handles use Magpie's ARC runtime (TypeId 2007).
- Fused transformer layers: BERT encoder, GPT-2 decoder, and Llama decoder each have a fused path that operates entirely on `MlxArrayHandle` without intermediate `Tensor` allocations.
- GPU-resident KV-cache: Key/value caches are stored as `MlxArrayHandle` on the GPU. No CPU round-trip during autoregressive decode.
- Lazy evaluation + batch eval: MLX operations are deferred; `batch_eval` submits all pending work as one Metal command buffer at the end of each forward pass (see the sketch below).
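The last three ideas combine into a decode step along these lines. `MlxArrayHandle` and `batch_eval` are the names this README uses; the surrounding types and `forward_fused` are hypothetical:

```rust
// Assumed import: the README names MlxArrayHandle and batch_eval, but
// their exact paths and signatures are not shown here.
use torch_mlx::{batch_eval, MlxArrayHandle};

/// Hypothetical per-layer state: K/V caches live on the GPU as MLX arrays.
struct KvCache { keys: MlxArrayHandle, values: MlxArrayHandle }

/// Hypothetical fused layer holding its weights as MlxArrayHandles.
struct FusedDecoderLayer { /* q/k/v/o projections, MLP weights, norms */ }

impl FusedDecoderLayer {
    /// Records lazy MLX ops only (norm, QKV projection, cache append,
    /// tiled SDPA, output projection, MLP); nothing executes yet.
    fn forward_fused(&self, hidden: MlxArrayHandle, cache: &mut KvCache) -> MlxArrayHandle {
        todo!("sketch only")
    }
}

/// One autoregressive decode step: everything stays on the GPU, and the
/// whole forward pass is submitted as a single Metal command buffer.
fn decode_step(
    layers: &[FusedDecoderLayer],
    caches: &mut [KvCache],
    mut hidden: MlxArrayHandle, // embedding of the newly sampled token
) -> MlxArrayHandle {
    for (layer, cache) in layers.iter().zip(caches.iter_mut()) {
        hidden = layer.forward_fused(hidden, cache);
    }
    batch_eval(&[&hidden]); // flush all deferred work in one submission
    hidden
}
```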
| Category | Count |
|---|---|
| FFI-exported functions | 68 |
| Tensor operations | 1,200+ |
| Neural network modules | 170+ |
| Optimizers | 13 |
| LR schedulers | 15 |
| Data types | 14 |
| Device backends | 4 (CPU, CUDA, ROCm, MLX) |
| Serialization formats | 3 (Safetensors, GGUF, PyTorch .pt) |
See DOCUMENTATION.md for the complete API reference mapping every torch-magpie endpoint to its Magpie IR calling convention, Rust implementation, and PyTorch equivalent.
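For a flavor of the calling convention, here is a hedged sketch of what an export might look like. `tm_tensor_add` and `tm_tensor_free` are hypothetical names, not necessarily among the 68 real exports, and `Tensor::add` stands in for whatever op the export wraps:

```rust
#[no_mangle]
pub extern "C" fn tm_tensor_add(a: i64, b: i64) -> i64 {
    // A handle is a Box pointer smuggled through i64: reborrow the
    // caller's tensors without taking ownership.
    let (a, b) = unsafe { (&*(a as *const Tensor), &*(b as *const Tensor)) };
    // The result becomes a fresh Box-managed handle owned by the caller.
    Box::into_raw(Box::new(a.add(b))) as i64
}

#[no_mangle]
pub extern "C" fn tm_tensor_free(handle: i64) {
    // Reclaim a Box-managed tensor handle once Magpie releases it.
    drop(unsafe { Box::from_raw(handle as *mut Tensor) });
}
```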
```bash
# Build the FFI static library (links against the Magpie runtime)
cargo build --release -p torch_ffi_lib

# Run benchmarks (requires the Magpie compiler for .mp files)
cd benchmark && ./run_llama_benchmark.sh
```

Prerequisites:

- Rust stable toolchain
- macOS arm64 (Apple Silicon) for MLX/Metal acceleration
- Magpie compiler (for compiling `.mp` benchmark files)
- Llama weights at `LLAMA_MODEL_PATH` (for Llama benchmarks)
```
crates/
  torch_core/        Tensor, Storage, DType, Device, dispatch
  torch_autograd/    Reverse-mode autodiff engine
  torch_cpu/         CPU kernels (arithmetic, math, reduction, shape, indexing)
  torch_mlx/         MLX integration + Metal gap-fill kernels
  torch_nn/          Neural network modules (170+)
  torch_optim/       Optimizers (13) and LR schedulers (15)
  torch_serialize/   Safetensors, GGUF, PyTorch .pt I/O
  torch_ffi_lib/     Magpie FFI exports (staticlib)
  ... + 16 more crates (linalg, fft, special, distributions, etc.)
benchmark/
  magpie/            Magpie SSA IR benchmark files (.mp)
  pytorch/           PyTorch reference benchmarks
  report/            Benchmark results (JSON + markdown)
```
- DOCUMENTATION.md -- Complete API reference (1,449 lines). Every endpoint listed with Magpie IR, Rust, and PyTorch signatures.
- benchmark/report/BENCHMARK_REPORT.md -- Detailed benchmark methodology and results.