torch-magpie

A complete PyTorch reimplementation in Rust, purpose-built as the tensor backend for the Magpie compiler. torch-magpie provides GPU-accelerated deep learning on Apple Silicon through MLX and Metal, with a handle-based FFI surface that the Magpie compiler links against at compile time.

Abstract

Modern deep learning frameworks like PyTorch rely on eager, per-operation GPU dispatch, which introduces overhead for every tensor operation. torch-magpie takes a different approach: it fuses entire transformer layers into lazy MLX computation graphs that are submitted as single Metal command buffers, eliminating intermediate tensor allocations and per-op dispatch costs. Combined with GPU-resident KV-caches and memory-efficient tiled attention (via MLX's scaled dot-product attention kernel), torch-magpie achieves 1.4--2.9x speedups over PyTorch MPS on BERT training and GPT-2 inference, reaches parity on Llama-3.2-1B decode throughput, and uniquely supports 100k-token prefill sequences that PyTorch MPS cannot run at all.

Performance

Benchmarked on Apple M1 Pro (10-core CPU, 16-core GPU), PyTorch 2.8.0 MPS vs torch-magpie MLX:

Model                                 PyTorch (MPS)    torch-magpie (MLX)   Speedup
BERT Training (6L/768H, batch=4)      155.9 ms/step    110.8 ms/step        1.41x
GPT-2 Decode (6L/768H, greedy)        128.6 tok/s      294.5 tok/s          2.29x
GPT-2 Prefill (128 tokens)            19.4 ms          6.8 ms               2.85x
Llama-3.2-1B Decode (real weights)    29.1 tok/s       28.5 tok/s           ~parity
Llama-3.2-1B Prefill (32 tok)         81.1 ms          51.1 ms              1.59x
Llama-3.2-1B Prefill (100k tok)       FAILED           316 tok/s            torch-magpie only

The 100k-token prefill failure on PyTorch MPS is due to its SDPA implementation materializing the full O(S^2) attention matrix (~1,192 GiB), while torch-magpie uses MLX's tiled O(S)-memory attention kernel.
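
As a back-of-the-envelope check on that figure, the sketch below recomputes the size of the full score matrix; the 32-head and fp32-score assumptions are ours, not taken from the repository.

// Back-of-the-envelope check of the ~1,192 GiB figure above, assuming 32 attention
// heads and fp32 scores (our assumptions, not taken from the repository).
fn main() {
    let seq = 100_000_f64;          // prefill length in tokens
    let heads = 32_f64;             // assumed attention head count
    let bytes_per_score = 4_f64;    // fp32 score matrix
    let gib = seq * seq * heads * bytes_per_score / 1024_f64.powi(3);
    println!("full S x S attention scores: {:.0} GiB", gib); // ~1192 GiB
    // A tiled kernel processes the scores block by block, so memory stays O(S).
}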

Architecture

torch-magpie is organized as a Rust workspace of 24 crates in 4 dependency layers:

Layer 3  torch_ffi_lib ─── Magpie FFI exports (staticlib, 68 extern "C" functions)
         torch_onnx, torch_ext, torch_backward_gen

Layer 2  torch_nn (170+ modules)   torch_optim (13 optimizers)
         torch_serialize            torch_linalg, torch_fft, torch_special
         torch_distributions        torch_sparse, torch_data
         torch_profiler             torch_tensorboard, torch_amp, torch_quant

Layer 1  torch_autograd ── reverse-mode autodiff engine
         torch_cpu ─────── CPU kernels (arithmetic, math, reduction, shape, indexing)
         torch_mlx ─────── Apple MLX + Metal gap-fill kernels
         torch_cuda ────── CUDA stubs    torch_rocm ── ROCm stubs

Layer 0  torch_core ────── Tensor, Storage, DType, Device, dispatch, allocator, RNG
         torch_shims ───── FFI bindings to vendor libraries

3-Tier GPU Dispatch

Every tensor operation follows a dispatch chain that maximizes GPU utilization:

  1. MLX -- lazy evaluation on Apple Silicon GPU via MLX library. Operations are deferred and batched into a single Metal command buffer.
  2. Metal -- custom Metal Shading Language kernels for operations not covered by MLX (batched matmul, elementwise ops, softmax).
  3. CPU -- Rust implementation as final fallback.
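
The sketch below illustrates that fallback order for a single elementwise op; Backend, try_mlx_add, and try_metal_add are illustrative placeholders rather than torch-magpie's real API.

// Illustrative sketch of the 3-tier fallback for one elementwise op; the names
// Backend, try_mlx_add, and try_metal_add are placeholders, not torch-magpie's API.
#[derive(Debug)]
enum Backend { Mlx, Metal, Cpu }

fn dispatch_add(lhs: &[f32], rhs: &[f32]) -> (Backend, Vec<f32>) {
    // 1. MLX: record the op lazily; real GPU work happens later at batch eval.
    if let Some(out) = try_mlx_add(lhs, rhs) {
        return (Backend::Mlx, out);
    }
    // 2. Metal: custom MSL kernel when MLX has no covering op.
    if let Some(out) = try_metal_add(lhs, rhs) {
        return (Backend::Metal, out);
    }
    // 3. CPU: plain Rust as the final fallback.
    (Backend::Cpu, lhs.iter().zip(rhs).map(|(a, b)| a + b).collect())
}

// Stubs so the sketch compiles without the GPU backends present.
fn try_mlx_add(_l: &[f32], _r: &[f32]) -> Option<Vec<f32>> { None }
fn try_metal_add(_l: &[f32], _r: &[f32]) -> Option<Vec<f32>> { None }

fn main() {
    let (backend, out) = dispatch_add(&[1.0, 2.0], &[3.0, 4.0]);
    println!("ran on {:?}: {:?}", backend, out);
}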

Key Design Decisions

  • Handle-based FFI: All cross-boundary types use i64 opaque handles (pointer-as-integer). Tensor handles are Box-managed; module handles use Magpie's ARC runtime (TypeId 2007). A minimal sketch of this pattern follows the list.
  • Fused transformer layers: BERT encoder, GPT-2 decoder, and Llama decoder each have a fused path that operates entirely on MlxArrayHandle without intermediate Tensor allocations.
  • GPU-resident KV-cache: Key/value caches are stored as MlxArrayHandle on the GPU. No CPU round-trip during autoregressive decode.
  • Lazy evaluation + batch eval: MLX operations are deferred; batch_eval submits all pending work as one Metal command buffer at the end of each forward pass.
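
Below is a minimal sketch of the handle-based FFI pattern, assuming a simplified Tensor type; the function names are illustrative placeholders rather than the actual exported symbols.

// Minimal sketch of the handle-based FFI pattern, assuming a simplified Tensor type;
// the function names here are illustrative, not the actual exported symbols.
pub struct Tensor {
    pub data: Vec<f32>,
}

#[no_mangle]
pub extern "C" fn tensor_new_zeros(len: i64) -> i64 {
    let t = Box::new(Tensor { data: vec![0.0; len as usize] });
    // Box::into_raw transfers ownership across the FFI boundary; the raw
    // pointer travels to the caller as an opaque i64 handle.
    Box::into_raw(t) as i64
}

#[no_mangle]
pub extern "C" fn tensor_free(handle: i64) {
    if handle != 0 {
        // Rebuilding the Box reclaims ownership and drops the tensor's storage.
        unsafe { drop(Box::from_raw(handle as *mut Tensor)) };
    }
}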

API Surface

Category                   Count
FFI-exported functions     68
Tensor operations          1,200+
Neural network modules     170+
Optimizers                 13
LR schedulers              15
Data types                 14
Device backends            4 (CPU, CUDA, ROCm, MLX)
Serialization formats      3 (Safetensors, GGUF, PyTorch .pt)

See DOCUMENTATION.md for the complete API reference mapping every torch-magpie endpoint to its Magpie IR calling convention, Rust implementation, and PyTorch equivalent.

Building

# Build the FFI static library (links against Magpie runtime)
cargo build --release -p torch_ffi_lib

# Run benchmarks (requires Magpie compiler for .mp files)
cd benchmark && ./run_llama_benchmark.sh

Prerequisites

  • Rust stable toolchain
  • macOS arm64 (Apple Silicon) for MLX/Metal acceleration
  • Magpie compiler (for compiling .mp benchmark files)
  • Llama weights at LLAMA_MODEL_PATH (for Llama benchmarks)

Project Structure

crates/
  torch_core/          Tensor, Storage, DType, Device, dispatch
  torch_autograd/      Reverse-mode autodiff engine
  torch_cpu/           CPU kernels (arithmetic, math, reduction, shape, indexing)
  torch_mlx/           MLX integration + Metal gap-fill kernels
  torch_nn/            Neural network modules (170+)
  torch_optim/         Optimizers (13) and LR schedulers (15)
  torch_serialize/     Safetensors, GGUF, PyTorch .pt I/O
  torch_ffi_lib/       Magpie FFI exports (staticlib)
  ...                  + 16 more crates (linalg, fft, special, distributions, etc.)

benchmark/
  magpie/              Magpie SSA IR benchmark files (.mp)
  pytorch/             PyTorch reference benchmarks
  report/              Benchmark results (JSON + markdown)

Documentation

Full API reference: DOCUMENTATION.md. Benchmark results (JSON + markdown): benchmark/report/.
