Add Mac Metal MLX inference #50
Open
Ma-Dan wants to merge 16 commits into
Conversation
added 5 commits
April 27, 2026 21:31
- _interpolate_pos_embed: force float32 when extracting pos_embed to numpy for scipy.ndimage.zoom, which rejects float16 arrays with a RuntimeError (see the sketch after this list).
- Add a max_special_tokens cap to CausalAttentionMLX / StreamingBlock; thread kv_cache_max_special_frames through AggregatorMLX, CameraHeadMLX, GCTStreamMLX; expose it as the --max-special-frames CLI flag (default: None). Without the cap, k_special grows 6 tokens/frame indefinitely; set it to e.g. 100 for sequences >300 frames with a small --kv-sliding-window.
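A minimal sketch of the cast, assuming a [H, W, C] pos_embed layout (the layout and zoom factors here are illustrative; only the float32 round-trip is the actual fix):

```python
import numpy as np
import mlx.core as mx
from scipy import ndimage

def _interpolate_pos_embed(pos_embed, new_h, new_w):
    # scipy.ndimage.zoom raises RuntimeError on float16 input, so cast to
    # float32 before handing the array to numpy/scipy.
    src = np.array(pos_embed.astype(mx.float32))   # [H, W, C], assumed layout
    h, w, _ = src.shape
    zoomed = ndimage.zoom(src, (new_h / h, new_w / w, 1), order=1)
    return mx.array(zoomed).astype(pos_embed.dtype)
```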
Replace the double sequential fancy indexing (x[:, y0, :, :][:, :, x0, :]) with precomputed flat 1D gather indices into a reshaped [B, H*W, C] view. Indices are computed once per (H, W, h, w) with numpy and cached as concrete MLX int32 arrays, removing 8 lazy-graph arange/clip/floor ops per call. The bilinear blend is also factored (6 muls vs 8). Together these give a 2.1-2.4x speedup on the sampling step. End-to-end: 1.5 fps -> 1.9 fps at 294x518, kv-sw=8, float16. A sketch of the pattern follows.
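A sketch of the pattern, assuming a [B, H, W, C] layout (function and variable names here are illustrative, not the PR's actual code):

```python
import functools
import numpy as np
import mlx.core as mx

@functools.lru_cache(maxsize=None)
def _bilinear_tables(H, W, h, w):
    # Computed once per (H, W, h, w) in numpy; cached as concrete MLX arrays,
    # so no arange/clip/floor ops ever enter the lazy graph at call time.
    ys = (np.arange(h) + 0.5) * H / h - 0.5
    xs = (np.arange(w) + 0.5) * W / w - 0.5
    y0 = np.clip(np.floor(ys), 0, H - 1).astype(np.int64)
    x0 = np.clip(np.floor(xs), 0, W - 1).astype(np.int64)
    y1 = np.minimum(y0 + 1, H - 1)
    x1 = np.minimum(x0 + 1, W - 1)
    dy = mx.array((ys - y0).astype(np.float32)[:, None, None])  # [h, 1, 1]
    dx = mx.array((xs - x0).astype(np.float32)[None, :, None])  # [1, w, 1]
    # Flat indices into the [B, H*W, C] view: idx = y * W + x
    flat = lambda yy, xx: mx.array(
        (yy[:, None] * W + xx[None, :]).reshape(-1).astype(np.int32))
    return flat(y0, x0), flat(y0, x1), flat(y1, x0), flat(y1, x1), dy, dx

def bilinear_resize(x, h, w):
    B, H, W, C = x.shape
    i00, i01, i10, i11, dy, dx = _bilinear_tables(H, W, h, w)
    v = x.reshape(B, H * W, C)
    g = lambda i: mx.take(v, i, axis=1).reshape(B, h, w, C)  # one flat gather
    v00, v01, v10, v11 = g(i00), g(i01), g(i10), g(i11)
    # Factored blend: two horizontal lerps + one vertical lerp
    top = v00 + dx * (v01 - v00)
    bot = v10 + dx * (v11 - v10)
    return top + dy * (bot - top)
```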
In steady state (cache full at scale_frames + sliding_window frames, k_special capped at max_special_tokens), all tensor shapes are constant per frame. mx.compile(_step) lets Metal reuse the same compiled program every frame instead of recompiling the 24-block attention graph.

Conditions to enter the compiled path:
- num_frame_per_block == 1 (single-frame streaming)
- kv_cache not skip_append
- rope_cos provided (2D RoPE, the normal streaming path)
- keep_special=True and max_special_tokens set (requires --max-special-frames)
- kv_cache['k'].shape[2] == scale_frames + sliding_window (cache full)
- kv_cache['k_special'].shape[2] == max_special_tokens (special cap reached)

_make_steady_fn builds the compiled pure function lazily on first use. It captures the layer weights in its closure; all ops are pure MLX. The k_special rotation keeps a constant shape: drop the oldest patch_start_idx tokens, append the newly evicted frame's special tokens.

Verified: the compiled path's output matches the manual reference with 0.0 abs diff. A toy sketch of the pattern follows.
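A toy sketch of the compile pattern (single head, simplified cache; the real _step captures all 24 blocks plus the k_special rotation in its closure):

```python
import mlx.core as mx

D, CACHE, T = 64, 128, 16   # head dim, fixed cache length, tokens per frame
wq, wk, wv = (mx.random.normal((D, D)) for _ in range(3))

def _step(x, k_cache, v_cache):
    # x: [1, T, D]; caches: [1, CACHE, D]. Every shape is constant per frame,
    # so mx.compile re-traces nothing after the first call.
    q, k_new, v_new = x @ wq, x @ wk, x @ wv
    k = mx.concatenate([k_cache, k_new], axis=1)
    v = mx.concatenate([v_cache, v_new], axis=1)
    out = mx.softmax(q @ k.transpose(0, 2, 1) / D ** 0.5, axis=-1) @ v
    # Constant-shape rotation: drop the oldest frame, append the new one.
    k_cache = mx.concatenate([k_cache[:, T:], k_new], axis=1)
    v_cache = mx.concatenate([v_cache[:, T:], v_new], axis=1)
    return out, k_cache, v_cache

step = mx.compile(_step)    # one Metal program, reused every frame

x = mx.random.normal((1, T, D))
k, v = mx.zeros((1, CACHE, D)), mx.zeros((1, CACHE, D))
for _ in range(3):
    out, k, v = step(x, k, v)   # identical shapes hit the compiled cache
mx.eval(out)
```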
Qodo code review: Severity: action required | Category: reliability. How to fix: add MLX as an optional dependency.
added 10 commits
April 30, 2026 21:00
…l sync points

Three incremental optimizations that together give a ~7% end-to-end speedup (1.88 fps → 2.01 fps measured in the component benchmark, sw=16 float16):

1. layers.py _make_steady_fn: eliminate the 32 MB k_cat intermediate tensor. k_cat = concat([k_cache, k_new]) was the largest intermediate allocation per steady-state attention call (32.5 MB). Replace it with:
   - evict_k drawn directly from the old k_cache at index sf (zero alloc)
   - new_k_cache = 3-way concat([k_cache[:sf], k_cache[sf+1:], k_new])
   This reduces concatenations from 8 to 6 per block; 24 blocks × 32 MB = 768 MB less memory traffic per frame. (Sketch after this list.)

2. aggregator.py _get_rope: cache the RoPE cos/sin tables. Patch positions are purely a function of image resolution, so cos/sin are identical for every streaming frame of the same size. Cache per (B, S, H, W, head_dim, dtype) key; subsequent frames get pre-evaluated, dtype-cast arrays with zero recompute. (Also sketched below.)

3. aggregator.py __call__: remove 4 intermediate mx.eval sync points. mx.eval(tokens_global) was called after output groups 4/11/17/23 to break the lazy graph into Metal-compilable segments. With compiled steady-state global blocks, the full 24-pair graph is small enough to evaluate lazily in one pass. Measured 7 ms/frame savings from fewer Metal-to-CPU round-trips.

Correctness: confirmed by running full inference_streaming on 20 frames and verifying early-frame outputs are bit-for-bit identical with and without the compiled path.
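Sketches of optimizations 1 and 2 with simplified axes (the real caches are threaded through layers.py / aggregator.py with more dimensions, and the position math below stands in for the actual 2D RoPE):

```python
import mlx.core as mx

# 1. Evict without materializing k_cat = concat([k_cache, k_new]).
def rotate_k(k_cache, k_new, sf):
    evict_k = k_cache[:, sf:sf + 1]             # drawn directly, zero alloc
    new_k_cache = mx.concatenate(               # 3-way concat, no k_cat
        [k_cache[:, :sf], k_cache[:, sf + 1:], k_new], axis=1)
    return evict_k, new_k_cache

# 2. Cache RoPE cos/sin tables per shape/dtype key.
_rope_cache = {}

def _get_rope(S, head_dim, dtype):
    key = (S, head_dim, dtype)
    if key not in _rope_cache:
        pos = mx.arange(S, dtype=mx.float32)
        inv_freq = mx.power(
            10000.0, -mx.arange(0, head_dim, 2, dtype=mx.float32) / head_dim)
        ang = pos[:, None] * inv_freq[None, :]
        cos, sin = mx.cos(ang).astype(dtype), mx.sin(ang).astype(dtype)
        mx.eval(cos, sin)   # pre-evaluate once; later frames recompute nothing
        _rope_cache[key] = (cos, sin)
    return _rope_cache[key]
```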
…ools

With compiled steady-state global blocks, letting the backbone fuse lazily into the full aggregator graph gives a ~3 ms improvement and reduces Metal round-trips. The previous reasoning (break the 72-block graph into two segments) is no longer compelling.

bench_compile.py / bench_depth.py: per-component timing utilities that require N_WARMUP >= 27 (16 frames to fill the sliding window + 10 to cap k_special) to measure true steady-state performance.

Verified steady-state baseline (sw=16, float16, max-special-frames=10): agg ≈ 452 ms, cam ≈ 18 ms, depth ≈ 63 ms, total ≈ 534 ms → 1.87 fps.

Depth breakdown: refinenets = 30 ms, output_conv = 28 ms, resize = 4 ms (bilinear + conv dominated; ConvTranspose2d is not a bottleneck).
Aggregator, camera, and depth head lazy graphs now fuse into one graph evaluated at mx.eval(frame_out). This removes a Metal-to-CPU round-trip between the aggregator and the heads; MLX lazy evaluation correctly tracks all data dependencies transitively.

bench_e2e.py: end-to-end per-frame timing through the model() call. Measured improvement (sw=16, float16, steady state, N=40):
- Before: 534 ms → 1.87 fps (agg eval forced mid-forward)
- After: 527 ms → 1.90 fps (fully fused, lower std too)

A minimal sketch of the harness follows.
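A minimal steady-state harness in the spirit of bench_e2e.py (model and next_frame are hypothetical stand-ins for the streaming model and input feed):

```python
import time
import mlx.core as mx

# Hypothetical stand-ins so the sketch runs; swap in the real streaming model.
model = lambda frame: mx.sum(frame * frame)
next_frame = lambda: mx.random.normal((294, 518, 3))

N_WARMUP = 27   # 16 frames fill the sliding window + 10 more cap k_special
N = 40

for _ in range(N_WARMUP):
    mx.eval(model(next_frame()))        # off the clock until steady state

t0 = time.perf_counter()
for _ in range(N):
    mx.eval(model(next_frame()))        # heads + aggregator fuse into one eval
dt = (time.perf_counter() - t0) / N
print(f"{dt * 1e3:.1f} ms/frame -> {1 / dt:.2f} fps")
```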
…v-sliding-window token count
…ens_per_frame for non-square images
This is great, thanks for sharing!