Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions .github/workflows/build-mlx-engine.yml
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,10 @@ jobs:
retention-days: 30

build-macos:
runs-on: macos-latest
# Pinned: macos-latest moved to macOS 26 (Xcode 26 / SDK 26), whose Metal
# toolchain compiles the metallib as MSL 4.0 — not loadable on the test
# runner's Metal runtime. macos-15 keeps the metallib at MSL 3.2.
runs-on: macos-15
needs: prepare-matrix
if: needs.prepare-matrix.outputs.should_build_macos == 'true'

Expand Down Expand Up @@ -342,7 +345,7 @@ jobs:
# timeout_minutes: 30
- os: macos-arm64
artifact: mlx-engine-macos-arm64
runs_on: '"macos-latest"'
runs_on: '"macos-15"'
install_rocm: false
model: mlx-community/Qwen3.5-0.8B-4bit
use_mtp: false
Expand Down
10 changes: 10 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ add_library(mlx-lm-common
src/common/quantize_utils.cpp
src/common/chat_template.cpp
src/common/gated_delta.cpp
src/common/graph_decode.cpp
src/llm/models/mtp_head.cpp
src/llm/models/mtp_moe.cpp
)
Expand Down Expand Up @@ -255,6 +256,15 @@ if(MLX_LM_BUILD_EXAMPLES)
add_executable(test_dynslice examples/test_dynslice.cpp)
target_link_libraries(test_dynslice PRIVATE mlx)

if(MLX_BUILD_ROCM)
# ROCm-only: link the decode-arena bridge (undefined on CPU/Metal).
add_executable(test_donate examples/test_donate.cpp)
target_link_libraries(test_donate PRIVATE mlx)

add_executable(test_arena examples/test_arena.cpp)
target_link_libraries(test_arena PRIVATE mlx)
endif()

add_executable(test_sdpa_ref examples/test_sdpa_ref.cpp)
target_link_libraries(test_sdpa_ref PRIVATE mlx)

Expand Down
60 changes: 60 additions & 0 deletions examples/test_arena.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Standalone check: the decode arena hands out identical device addresses for
// an identical allocation sequence across token resets. This determinism is the
// precondition for build-once HIP-graph relaunch.
#include <cstdio>
#include <mlx/mlx.h>

namespace mx = mlx::core;

namespace mlx::core {
bool decode_arena_begin(size_t capacity, int device, void* stream);
void decode_arena_reset();
void decode_arena_end();
size_t decode_arena_high_water();
bool decode_arena_overflowed();
}

int main() {
fprintf(stderr, "[arena] start\n");
mx::set_default_device(mx::Device::gpu);
fprintf(stderr, "[arena] device set; warmup...\n");
{ auto w = mx::add(mx::ones({4}), mx::ones({4})); mx::eval(w); }
fprintf(stderr, "[arena] warmup done; begin arena\n");

const size_t cap = size_t(256) * 1024 * 1024;
if (!mx::decode_arena_begin(cap, 0, nullptr)) {
printf("decode_arena_begin failed\n");
return 1;
}
fprintf(stderr, "[arena] begin ok\n");

// A small but multi-op "token": several allocations in a fixed order.
auto run = []() -> void* {
auto a = mx::ones({512, 512}, mx::float32);
auto b = mx::matmul(a, a);
auto c = mx::add(b, a);
auto d = mx::multiply(c, mx::array(2.0f));
mx::eval(d);
return static_cast<void*>(d.data<float>());
};

void* p1 = run();
size_t hw1 = mx::decode_arena_high_water();

mx::decode_arena_reset();
void* p2 = run();

mx::decode_arena_reset();
void* p3 = run();

size_t hw3 = mx::decode_arena_high_water();
bool overflow = mx::decode_arena_overflowed();
mx::decode_arena_end();

printf("p1=%p p2=%p p3=%p\n", p1, p2, p3);
printf("high_water token1=%zu token3=%zu overflow=%d\n", hw1, hw3,
(int)overflow);
bool ok = (p1 == p2) && (p2 == p3) && !overflow && (hw1 == hw3);
printf(ok ? "DETERMINISTIC OK\n" : "NONDETERMINISTIC FAIL\n");
return ok ? 0 : 2;
}
21 changes: 21 additions & 0 deletions examples/test_donate.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Does mx::slice_update donate (in-place) or copy? Compare the buffer pointer
// before/after when the input is the sole owner (std::move pattern).
#include <mlx/mlx.h>
#include <cstdio>
namespace mx = mlx::core;
int main() {
mx::set_default_device(mx::Device::gpu);
int B=1,H=2,CAP=512,D=256, slot=5;
auto buf = mx::zeros({B,H,CAP,D}, mx::float32);
mx::eval(buf);
void* p0 = (void*)buf.data<float>();
auto pos = mx::array({slot}, {1}, mx::int32);
// sole-owner update (mirrors update_at_pos after std::move)
auto k = std::move(buf);
auto out = mx::slice_update(k, mx::full({B,H,1,D},3.0f,mx::float32), pos, std::vector<int>{2});
mx::eval(out);
void* p1 = (void*)out.data<float>();
printf("before=%p after=%p %s\n", p0, p1,
p0==p1 ? "DONATED (in-place)" : "COPIED (new buffer)");
return 0;
}
28 changes: 28 additions & 0 deletions include/mlx-lm/common/gated_delta.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,27 @@ std::pair<mlx::core::array, mlx::core::array> gated_delta_update(
mlx::core::array inplace_write(const mlx::core::array& dst,
const mlx::core::array& src);

// In-place KV-cache slice write: writes new_kv [B,H,N,D] into cache [B,H,ALLOC,D]
// at [:,:,offset:offset+N,:]. The output ALIASES the cache buffer, so no copy is
// made (replaces slice_update, whose donation fails under the async pipeline and
// copies the whole cache → variable copy count → non-replayable decode graph).
mlx::core::array kv_inplace_update(
const mlx::core::array& cache, const mlx::core::array& new_kv, int offset);

// FlashQLA-style fused GDN decode step (T=1): folds q/k-RMSNorm + beta/g +
// the delta recurrence into ONE kernel (replaces rms_norm(q)+rms_norm(k)+
// compiled beta/g + gated_delta_step). q,k: [B,1,Hk,Dk], v: [B,1,Hv,Dv],
// a,b: [B,1,Hv], a_log,dt_bias: [Hv], q_norm_w,k_norm_w: [Dk],
// state: [B,Hv,Dv,Dk]. Returns (y [B,1,Hv,Dv], new_state). The output gate
// (norm_, reduces over Dv) stays separate. MLX_GDN_FUSED2_MXOPS=1 falls back.
std::pair<mlx::core::array, mlx::core::array> gdn_fused_decode(
const mlx::core::array& q, const mlx::core::array& k,
const mlx::core::array& v, const mlx::core::array& a,
const mlx::core::array& b, const mlx::core::array& a_log,
const mlx::core::array& dt_bias,
const mlx::core::array& q_norm_w, const mlx::core::array& k_norm_w,
const mlx::core::array& state);

// Fused GDN conv1d decode step: causal depthwise conv (KS taps) + silu + state
// shift in one kernel. conv_state [B,KS-1,CD], qkv [B,1,CD], weight [CD,1,KS].
// Returns (conv_out [B,1,CD] silu'd, new_state [B,KS-1,CD]). Replaces the
Expand All @@ -70,6 +91,13 @@ std::pair<mlx::core::array, mlx::core::array> add_rms_norm(
const mlx::core::array& weight,
float eps);

// Fused gated RMSNorm (GDN/attention output gate): silu(gate) * rmsnorm(x) *
// weight in one kernel. Replaces rms_norm + sigmoid + multiply. x,gate: [..,H],
// weight: [H]. MLX_FUSED_NORM_MXOPS=1 falls back to the op chain.
mlx::core::array gated_rms_norm(
const mlx::core::array& x, const mlx::core::array& gate,
const mlx::core::array& weight, float eps);

// Fused MoE router (norm_topk_prob): top-k of the router logits + softmax over
// just those k, in one kernel (replaces argpartition+slice+take_along+softmax).
// Returns (indices [.., k] uint32, scores [.., k]). ROCm fast path needs
Expand Down
9 changes: 9 additions & 0 deletions include/mlx-lm/common/generate.h
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,15 @@ class TokenIterator {
std::optional<int> max_tokens_;
int token_count_ = 0;

// --- Pure-relaunch graph decode (build-once, deterministic arena) ---
// State machine: 0 warmup, 1 record parity A, 2 record parity B, 3 replay,
// 9 disabled. Two graphs (one per pos&1) bake the GDN ping-pong state slots.
int pure_graph_state_ = 0;
int pure_graph_cap_ = 0; // reserved KV capacity
int pure_pos_ = 0; // host mirror of the device decode position
// Run one decode step under the pure-graph path; returns the sampled token.
mlx::core::array step_pure_graph(const LMInput::Text& previous);

// KV cache quantization parameters.
std::optional<int> kv_bits_;
int kv_group_size_ = 64;
Expand Down
38 changes: 38 additions & 0 deletions include/mlx-lm/common/graph_decode.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright © 2025
#pragma once

#include <mlx/mlx.h>

namespace mlx_lm {

// Persistent device-position scalar for HIP-graph decode (fixed-address [1] int32).
mlx::core::array& graph_decode_pos();

// In-place device write of the absolute position via a raw kernel.
void set_graph_decode_pos(int offset);

// Advance the device position in place by delta (loop-owned, between replays).
void advance_graph_decode_pos(int delta);

// When true, the decode loop owns the position; false on the eager path.
bool graph_external_pos();
void set_graph_external_pos(bool on);

// Persistent device input-token buffer for build-once graph decode (fixed
// address [1,1] int32). The recorded graph's embedding gather reads this buffer;
// the loop feeds each new token into it (device copy) so the buffer is never
// reallocated between relaunches.
mlx::core::array& graph_decode_input();

// Device-copy the freshly-sampled token (a [1]/[1,1] int32 array) into the
// fixed-address input buffer. Runs on-stream, ordered after the producing step.
void set_graph_decode_input_from(mlx::core::array& token);

// Whether HIP-graph decode is active (ROCm only; false elsewhere).
bool graph_decode_enabled();

// True only while the single decode step is being captured (and during replay).
bool graph_capturing();
void set_graph_capturing(bool on);

} // namespace mlx_lm
71 changes: 70 additions & 1 deletion include/mlx-lm/common/kv_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,21 @@ class KVCacheSimple : public KVCacheBase<KVCacheSimple> {
int trim_impl(int n);

public:
// Device-position write for build-once HIP-graph decode: write new_keys/values
// at slot `pos` (a [1] int32 device array) into the pre-allocated buffer via
// DynamicSliceUpdate (offset advances device-side on replay). The full buffer
// is returned; the caller attends over it with a device-pos length mask.
// Requires the buffer pre-allocated to capacity (no growth during decode).
std::pair<mlx::core::array, mlx::core::array> update_at_pos(
const mlx::core::array& new_keys, const mlx::core::array& new_values,
const mlx::core::array& pos);

// Pre-grow the buffer to `capacity` columns (axis 2) with zeros, keeping the
// logical offset. Required before update_at_pos so device-offset writes never
// grow/realloc the buffer (which would break the build-once graph's baked
// pointers). No-op if not yet populated or already at capacity.
void reserve_to(int capacity);

KVCacheSimple() = default;
explicit KVCacheSimple(int initial_capacity) : initial_capacity_(initial_capacity) {}
KVCacheSimple(int initial_capacity, int reserve)
Expand Down Expand Up @@ -161,7 +176,11 @@ class QuantizedKVCache : public KVCacheBase<QuantizedKVCache> {
// Mamba-style state space model cache.
// Stores conv_state (index 0) and ssm_state (index 1).
class MambaCache {
std::optional<mlx::core::array> states_[2];
// [0]=conv_state, [1]=ssm_state. [2]/[3] are scratch "next-state" buffers for
// build-once graph decode: the recorded graph reads [0]/[1] and writes the new
// state to [2]/[3] (different buffers → no read==write hazard on relaunch);
// the decode loop copies [2]->[0], [3]->[1] between relaunches.
std::optional<mlx::core::array> states_[4];
int offset_ = 0;

public:
Expand Down Expand Up @@ -291,6 +310,28 @@ class CompoundCache {
update(const mlx::core::array& keys, const mlx::core::array& values) {
return std::visit([&](auto& c) { return c.update(keys, values); }, kv_);
}
std::pair<mlx::core::array, mlx::core::array>
update_at_pos(const mlx::core::array& keys, const mlx::core::array& values,
const mlx::core::array& pos) {
return std::visit([&](auto& c)
-> std::pair<mlx::core::array, mlx::core::array> {
using T = std::decay_t<decltype(c)>;
if constexpr (std::is_same_v<T, KVCacheSimple>) {
return c.update_at_pos(keys, values, pos);
} else {
throw std::runtime_error(
"update_at_pos unsupported for rotating KV sub-cache");
}
}, kv_);
}
void reserve_to(int capacity) {
std::visit([&](auto& c) {
using T = std::decay_t<decltype(c)>;
if constexpr (std::is_same_v<T, KVCacheSimple>) {
c.reserve_to(capacity);
}
}, kv_);
}
bool is_trimmable() const {
return std::visit([](const auto& c) { return c.is_trimmable(); }, kv_);
}
Expand Down Expand Up @@ -343,6 +384,34 @@ class KVCache {
return std::visit([&](auto& c) { return c.update(keys, values); }, impl_);
}

// Device-offset write at `pos` (build-once graph decode). Only the simple /
// compound KV caches support it.
std::pair<mlx::core::array, mlx::core::array>
update_at_pos(const mlx::core::array& keys, const mlx::core::array& values,
const mlx::core::array& pos) {
return std::visit([&](auto& c)
-> std::pair<mlx::core::array, mlx::core::array> {
using T = std::decay_t<decltype(c)>;
if constexpr (std::is_same_v<T, KVCacheSimple> ||
std::is_same_v<T, CompoundCache>) {
return c.update_at_pos(keys, values, pos);
} else {
throw std::runtime_error(
"update_at_pos unsupported for this cache type");
}
}, impl_);
}

void reserve_to(int capacity) {
std::visit([&](auto& c) {
using T = std::decay_t<decltype(c)>;
if constexpr (std::is_same_v<T, KVCacheSimple> ||
std::is_same_v<T, CompoundCache>) {
c.reserve_to(capacity);
}
}, impl_);
}

bool is_trimmable() const {
return std::visit([](const auto& c) { return c.is_trimmable(); }, impl_);
}
Expand Down
Loading
Loading