diff --git a/.github/workflows/build-mlx-engine.yml b/.github/workflows/build-mlx-engine.yml index d36c74dc..c172ab2c 100644 --- a/.github/workflows/build-mlx-engine.yml +++ b/.github/workflows/build-mlx-engine.yml @@ -225,7 +225,10 @@ jobs: retention-days: 30 build-macos: - runs-on: macos-latest + # Pinned: macos-latest moved to macOS 26 (Xcode 26 / SDK 26), whose Metal + # toolchain compiles the metallib as MSL 4.0 — not loadable on the test + # runner's Metal runtime. macos-15 keeps the metallib at MSL 3.2. + runs-on: macos-15 needs: prepare-matrix if: needs.prepare-matrix.outputs.should_build_macos == 'true' @@ -342,7 +345,7 @@ jobs: # timeout_minutes: 30 - os: macos-arm64 artifact: mlx-engine-macos-arm64 - runs_on: '"macos-latest"' + runs_on: '"macos-15"' install_rocm: false model: mlx-community/Qwen3.5-0.8B-4bit use_mtp: false diff --git a/CMakeLists.txt b/CMakeLists.txt index 5e7cc029..760e5a6e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -122,6 +122,7 @@ add_library(mlx-lm-common src/common/quantize_utils.cpp src/common/chat_template.cpp src/common/gated_delta.cpp + src/common/graph_decode.cpp src/llm/models/mtp_head.cpp src/llm/models/mtp_moe.cpp ) @@ -255,6 +256,15 @@ if(MLX_LM_BUILD_EXAMPLES) add_executable(test_dynslice examples/test_dynslice.cpp) target_link_libraries(test_dynslice PRIVATE mlx) + if(MLX_BUILD_ROCM) + # ROCm-only: link the decode-arena bridge (undefined on CPU/Metal). + add_executable(test_donate examples/test_donate.cpp) + target_link_libraries(test_donate PRIVATE mlx) + + add_executable(test_arena examples/test_arena.cpp) + target_link_libraries(test_arena PRIVATE mlx) + endif() + add_executable(test_sdpa_ref examples/test_sdpa_ref.cpp) target_link_libraries(test_sdpa_ref PRIVATE mlx) diff --git a/examples/test_arena.cpp b/examples/test_arena.cpp new file mode 100644 index 00000000..65ad73fd --- /dev/null +++ b/examples/test_arena.cpp @@ -0,0 +1,60 @@ +// Standalone check: the decode arena hands out identical device addresses for +// an identical allocation sequence across token resets. This determinism is the +// precondition for build-once HIP-graph relaunch. +#include +#include + +namespace mx = mlx::core; + +namespace mlx::core { +bool decode_arena_begin(size_t capacity, int device, void* stream); +void decode_arena_reset(); +void decode_arena_end(); +size_t decode_arena_high_water(); +bool decode_arena_overflowed(); +} + +int main() { + fprintf(stderr, "[arena] start\n"); + mx::set_default_device(mx::Device::gpu); + fprintf(stderr, "[arena] device set; warmup...\n"); + { auto w = mx::add(mx::ones({4}), mx::ones({4})); mx::eval(w); } + fprintf(stderr, "[arena] warmup done; begin arena\n"); + + const size_t cap = size_t(256) * 1024 * 1024; + if (!mx::decode_arena_begin(cap, 0, nullptr)) { + printf("decode_arena_begin failed\n"); + return 1; + } + fprintf(stderr, "[arena] begin ok\n"); + + // A small but multi-op "token": several allocations in a fixed order. + auto run = []() -> void* { + auto a = mx::ones({512, 512}, mx::float32); + auto b = mx::matmul(a, a); + auto c = mx::add(b, a); + auto d = mx::multiply(c, mx::array(2.0f)); + mx::eval(d); + return static_cast(d.data()); + }; + + void* p1 = run(); + size_t hw1 = mx::decode_arena_high_water(); + + mx::decode_arena_reset(); + void* p2 = run(); + + mx::decode_arena_reset(); + void* p3 = run(); + + size_t hw3 = mx::decode_arena_high_water(); + bool overflow = mx::decode_arena_overflowed(); + mx::decode_arena_end(); + + printf("p1=%p p2=%p p3=%p\n", p1, p2, p3); + printf("high_water token1=%zu token3=%zu overflow=%d\n", hw1, hw3, + (int)overflow); + bool ok = (p1 == p2) && (p2 == p3) && !overflow && (hw1 == hw3); + printf(ok ? "DETERMINISTIC OK\n" : "NONDETERMINISTIC FAIL\n"); + return ok ? 0 : 2; +} diff --git a/examples/test_donate.cpp b/examples/test_donate.cpp new file mode 100644 index 00000000..eb3d6338 --- /dev/null +++ b/examples/test_donate.cpp @@ -0,0 +1,21 @@ +// Does mx::slice_update donate (in-place) or copy? Compare the buffer pointer +// before/after when the input is the sole owner (std::move pattern). +#include +#include +namespace mx = mlx::core; +int main() { + mx::set_default_device(mx::Device::gpu); + int B=1,H=2,CAP=512,D=256, slot=5; + auto buf = mx::zeros({B,H,CAP,D}, mx::float32); + mx::eval(buf); + void* p0 = (void*)buf.data(); + auto pos = mx::array({slot}, {1}, mx::int32); + // sole-owner update (mirrors update_at_pos after std::move) + auto k = std::move(buf); + auto out = mx::slice_update(k, mx::full({B,H,1,D},3.0f,mx::float32), pos, std::vector{2}); + mx::eval(out); + void* p1 = (void*)out.data(); + printf("before=%p after=%p %s\n", p0, p1, + p0==p1 ? "DONATED (in-place)" : "COPIED (new buffer)"); + return 0; +} diff --git a/include/mlx-lm/common/gated_delta.h b/include/mlx-lm/common/gated_delta.h index 8e249282..054ec928 100644 --- a/include/mlx-lm/common/gated_delta.h +++ b/include/mlx-lm/common/gated_delta.h @@ -52,6 +52,27 @@ std::pair gated_delta_update( mlx::core::array inplace_write(const mlx::core::array& dst, const mlx::core::array& src); +// In-place KV-cache slice write: writes new_kv [B,H,N,D] into cache [B,H,ALLOC,D] +// at [:,:,offset:offset+N,:]. The output ALIASES the cache buffer, so no copy is +// made (replaces slice_update, whose donation fails under the async pipeline and +// copies the whole cache → variable copy count → non-replayable decode graph). +mlx::core::array kv_inplace_update( + const mlx::core::array& cache, const mlx::core::array& new_kv, int offset); + +// FlashQLA-style fused GDN decode step (T=1): folds q/k-RMSNorm + beta/g + +// the delta recurrence into ONE kernel (replaces rms_norm(q)+rms_norm(k)+ +// compiled beta/g + gated_delta_step). q,k: [B,1,Hk,Dk], v: [B,1,Hv,Dv], +// a,b: [B,1,Hv], a_log,dt_bias: [Hv], q_norm_w,k_norm_w: [Dk], +// state: [B,Hv,Dv,Dk]. Returns (y [B,1,Hv,Dv], new_state). The output gate +// (norm_, reduces over Dv) stays separate. MLX_GDN_FUSED2_MXOPS=1 falls back. +std::pair gdn_fused_decode( + const mlx::core::array& q, const mlx::core::array& k, + const mlx::core::array& v, const mlx::core::array& a, + const mlx::core::array& b, const mlx::core::array& a_log, + const mlx::core::array& dt_bias, + const mlx::core::array& q_norm_w, const mlx::core::array& k_norm_w, + const mlx::core::array& state); + // Fused GDN conv1d decode step: causal depthwise conv (KS taps) + silu + state // shift in one kernel. conv_state [B,KS-1,CD], qkv [B,1,CD], weight [CD,1,KS]. // Returns (conv_out [B,1,CD] silu'd, new_state [B,KS-1,CD]). Replaces the @@ -70,6 +91,13 @@ std::pair add_rms_norm( const mlx::core::array& weight, float eps); +// Fused gated RMSNorm (GDN/attention output gate): silu(gate) * rmsnorm(x) * +// weight in one kernel. Replaces rms_norm + sigmoid + multiply. x,gate: [..,H], +// weight: [H]. MLX_FUSED_NORM_MXOPS=1 falls back to the op chain. +mlx::core::array gated_rms_norm( + const mlx::core::array& x, const mlx::core::array& gate, + const mlx::core::array& weight, float eps); + // Fused MoE router (norm_topk_prob): top-k of the router logits + softmax over // just those k, in one kernel (replaces argpartition+slice+take_along+softmax). // Returns (indices [.., k] uint32, scores [.., k]). ROCm fast path needs diff --git a/include/mlx-lm/common/generate.h b/include/mlx-lm/common/generate.h index c1f10202..00690696 100644 --- a/include/mlx-lm/common/generate.h +++ b/include/mlx-lm/common/generate.h @@ -328,6 +328,15 @@ class TokenIterator { std::optional max_tokens_; int token_count_ = 0; + // --- Pure-relaunch graph decode (build-once, deterministic arena) --- + // State machine: 0 warmup, 1 record parity A, 2 record parity B, 3 replay, + // 9 disabled. Two graphs (one per pos&1) bake the GDN ping-pong state slots. + int pure_graph_state_ = 0; + int pure_graph_cap_ = 0; // reserved KV capacity + int pure_pos_ = 0; // host mirror of the device decode position + // Run one decode step under the pure-graph path; returns the sampled token. + mlx::core::array step_pure_graph(const LMInput::Text& previous); + // KV cache quantization parameters. std::optional kv_bits_; int kv_group_size_ = 64; diff --git a/include/mlx-lm/common/graph_decode.h b/include/mlx-lm/common/graph_decode.h new file mode 100644 index 00000000..ab809fb9 --- /dev/null +++ b/include/mlx-lm/common/graph_decode.h @@ -0,0 +1,38 @@ +// Copyright © 2025 +#pragma once + +#include + +namespace mlx_lm { + +// Persistent device-position scalar for HIP-graph decode (fixed-address [1] int32). +mlx::core::array& graph_decode_pos(); + +// In-place device write of the absolute position via a raw kernel. +void set_graph_decode_pos(int offset); + +// Advance the device position in place by delta (loop-owned, between replays). +void advance_graph_decode_pos(int delta); + +// When true, the decode loop owns the position; false on the eager path. +bool graph_external_pos(); +void set_graph_external_pos(bool on); + +// Persistent device input-token buffer for build-once graph decode (fixed +// address [1,1] int32). The recorded graph's embedding gather reads this buffer; +// the loop feeds each new token into it (device copy) so the buffer is never +// reallocated between relaunches. +mlx::core::array& graph_decode_input(); + +// Device-copy the freshly-sampled token (a [1]/[1,1] int32 array) into the +// fixed-address input buffer. Runs on-stream, ordered after the producing step. +void set_graph_decode_input_from(mlx::core::array& token); + +// Whether HIP-graph decode is active (ROCm only; false elsewhere). +bool graph_decode_enabled(); + +// True only while the single decode step is being captured (and during replay). +bool graph_capturing(); +void set_graph_capturing(bool on); + +} // namespace mlx_lm diff --git a/include/mlx-lm/common/kv_cache.h b/include/mlx-lm/common/kv_cache.h index 84714474..e5546b01 100644 --- a/include/mlx-lm/common/kv_cache.h +++ b/include/mlx-lm/common/kv_cache.h @@ -56,6 +56,21 @@ class KVCacheSimple : public KVCacheBase { int trim_impl(int n); public: + // Device-position write for build-once HIP-graph decode: write new_keys/values + // at slot `pos` (a [1] int32 device array) into the pre-allocated buffer via + // DynamicSliceUpdate (offset advances device-side on replay). The full buffer + // is returned; the caller attends over it with a device-pos length mask. + // Requires the buffer pre-allocated to capacity (no growth during decode). + std::pair update_at_pos( + const mlx::core::array& new_keys, const mlx::core::array& new_values, + const mlx::core::array& pos); + + // Pre-grow the buffer to `capacity` columns (axis 2) with zeros, keeping the + // logical offset. Required before update_at_pos so device-offset writes never + // grow/realloc the buffer (which would break the build-once graph's baked + // pointers). No-op if not yet populated or already at capacity. + void reserve_to(int capacity); + KVCacheSimple() = default; explicit KVCacheSimple(int initial_capacity) : initial_capacity_(initial_capacity) {} KVCacheSimple(int initial_capacity, int reserve) @@ -161,7 +176,11 @@ class QuantizedKVCache : public KVCacheBase { // Mamba-style state space model cache. // Stores conv_state (index 0) and ssm_state (index 1). class MambaCache { - std::optional states_[2]; + // [0]=conv_state, [1]=ssm_state. [2]/[3] are scratch "next-state" buffers for + // build-once graph decode: the recorded graph reads [0]/[1] and writes the new + // state to [2]/[3] (different buffers → no read==write hazard on relaunch); + // the decode loop copies [2]->[0], [3]->[1] between relaunches. + std::optional states_[4]; int offset_ = 0; public: @@ -291,6 +310,28 @@ class CompoundCache { update(const mlx::core::array& keys, const mlx::core::array& values) { return std::visit([&](auto& c) { return c.update(keys, values); }, kv_); } + std::pair + update_at_pos(const mlx::core::array& keys, const mlx::core::array& values, + const mlx::core::array& pos) { + return std::visit([&](auto& c) + -> std::pair { + using T = std::decay_t; + if constexpr (std::is_same_v) { + return c.update_at_pos(keys, values, pos); + } else { + throw std::runtime_error( + "update_at_pos unsupported for rotating KV sub-cache"); + } + }, kv_); + } + void reserve_to(int capacity) { + std::visit([&](auto& c) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + c.reserve_to(capacity); + } + }, kv_); + } bool is_trimmable() const { return std::visit([](const auto& c) { return c.is_trimmable(); }, kv_); } @@ -343,6 +384,34 @@ class KVCache { return std::visit([&](auto& c) { return c.update(keys, values); }, impl_); } + // Device-offset write at `pos` (build-once graph decode). Only the simple / + // compound KV caches support it. + std::pair + update_at_pos(const mlx::core::array& keys, const mlx::core::array& values, + const mlx::core::array& pos) { + return std::visit([&](auto& c) + -> std::pair { + using T = std::decay_t; + if constexpr (std::is_same_v || + std::is_same_v) { + return c.update_at_pos(keys, values, pos); + } else { + throw std::runtime_error( + "update_at_pos unsupported for this cache type"); + } + }, impl_); + } + + void reserve_to(int capacity) { + std::visit([&](auto& c) { + using T = std::decay_t; + if constexpr (std::is_same_v || + std::is_same_v) { + c.reserve_to(capacity); + } + }, impl_); + } + bool is_trimmable() const { return std::visit([](const auto& c) { return c.is_trimmable(); }, impl_); } diff --git a/src/common/gated_delta.cpp b/src/common/gated_delta.cpp index ad2ae8a2..61f5772c 100644 --- a/src/common/gated_delta.cpp +++ b/src/common/gated_delta.cpp @@ -217,6 +217,34 @@ static mx::fast::CustomKernelFunction& get_gdn_kernel() { } // In-place copy of `src` into `dst`'s buffer (output 0 aliases input 0). +// In-place KV-cache slice write: writes new_kv [B,H,N,D] into cache [B,H,ALLOC,D] +// at [:,:,offset:offset+N,:]. Output ALIASES the cache buffer (no copy, no +// donation check) — `offset` is a runtime scalar so the kernel/topology stays +// fixed token-to-token (this is what makes the decode graph replayable). +static const char* kv_inplace_update_hip_source = R"( + long idx = (long)blockIdx.x * blockDim.x + threadIdx.x; + long total = (long)B * H * N * D; + if (idx >= total) return; + int d = idx % D; + long t = idx / D; + int i = t % N; + long bh = t / N; + int off = offset[0]; + long dst = (bh * (long)ALLOC + (off + i)) * D + d; + out[dst] = new_kv[idx]; +)"; + +static mx::fast::CustomKernelFunction& get_kv_inplace_update_kernel() { + static auto kernel = mx::fast::hip_kernel( + "kv_inplace_update", + {"cache", "new_kv", "offset"}, + {"out"}, + kv_inplace_update_hip_source, + /*header=*/"", /*ensure_row_contiguous=*/true, /*shared_memory=*/0, + /*output_input_aliases=*/{{0, 0}}); + return kernel; +} + static mx::fast::CustomKernelFunction& get_inplace_copy_kernel() { static auto kernel = mx::fast::hip_kernel( "inplace_copy", @@ -299,6 +327,40 @@ static mx::fast::CustomKernelFunction& get_add_rms_norm_kernel() { return kernel; } +// Fused gated RMSNorm (GDN/attention output gate): out = silu(gate) * +// rmsnorm(x) * weight, in one pass. Replaces rms_norm + sigmoid + mul. +static const char* gated_rms_norm_hip_source = R"( + int r = blockIdx.y * blockDim.y + threadIdx.y; // row + int lane = threadIdx.x; // 0..31 + if (r >= N) return; + const InT* xr = x + (long)r * H; + const InT* gr = gate + (long)r * H; + InT* outr = out + (long)r * H; + float ss = 0.0f; + for (int h = lane; h < H; h += 32) { + float s = static_cast(xr[h]); + ss += s * s; + } + for (int off = 16; off > 0; off >>= 1) + ss += __shfl_xor(ss, off); + float scale = rsqrtf(ss / (float)H + eps[0]); + for (int h = lane; h < H; h += 32) { + float gv = static_cast(gr[h]); + float silu = gv / (1.0f + expf(-gv)); + outr[h] = static_cast(static_cast(xr[h]) * scale * + static_cast(weight[h]) * silu); + } +)"; + +static mx::fast::CustomKernelFunction& get_gated_rms_norm_kernel() { + static auto kernel = mx::fast::hip_kernel( + "gated_rms_norm", + {"x", "gate", "weight", "eps"}, + {"out"}, + gated_rms_norm_hip_source); + return kernel; +} + // Fused MoE router (norm_topk_prob): one warp per row finds the top-K of E // logits and softmaxes over just those K, in a single pass — replaces // argpartition(block_sort) + slice + take_along + softmax. Each lane owns @@ -350,6 +412,82 @@ static mx::fast::CustomKernelFunction& get_moe_route_kernel() { return kernel; } +// FlashQLA-style fused GDN decode (T=1): folds beta/g + q/k-RMSNorm + the delta +// recurrence into ONE kernel, replacing rms_norm(q)+rms_norm(k)+compiled_beta_g +// + gated_delta_step (4-5 launches/GDN-layer -> 1). One warp per (batch,v-head); +// the 32 dk-lanes hold Dk so q/k RMSNorm is a warp __shfl reduction. The output +// gate (norm_, reduces over Dv across blocks) stays separate. eps = 1e-6. +static const char* gdn_fused_decode_hip_source = R"( + int n = blockIdx.z * blockDim.z + threadIdx.z; + int b_idx = n / Hv; + int hv_idx = n % Hv; + int hk_idx = hv_idx / (Hv / Hk); + constexpr int n_per_t = Dk / 32; + int dk_idx = threadIdx.x; + int dv_idx = blockIdx.y * blockDim.y + threadIdx.y; + + const InT* q_ = q + (long)b_idx * Hk * Dk + hk_idx * Dk; + const InT* k_ = k + (long)b_idx * Hk * Dk + hk_idx * Dk; + const InT* v_ = v + (long)b_idx * Hv * Dv + hv_idx * Dv; + const InT* i_state = state_in + ((long)n * Dv + dv_idx) * Dk; + InT* o_state = state_out + ((long)n * Dv + dv_idx) * Dk; + + // beta, g (per head) + float beta = 1.0f / (1.0f + expf(-static_cast(b[(long)b_idx * Hv + hv_idx]))); + float sp = logf(expf(static_cast(a[(long)b_idx * Hv + hv_idx]) + + static_cast(dt_bias[hv_idx])) + 1.0f); + float g = expf(-expf(static_cast(a_log[hv_idx])) * sp); + + // load q/k, RMSNorm over Dk (warp reduce across the 32 dk-lanes). + // Column layout (s = dk_idx + 32*i): consecutive lanes hit consecutive Dk + // -> coalesced loads + independent per-i ops for dual-issue. + float ql[n_per_t], kl[n_per_t]; + float sq = 0.0f, sk = 0.0f; + #pragma unroll + for (int i = 0; i < n_per_t; ++i) { + int s = dk_idx + 32 * i; + ql[i] = static_cast(q_[s]); kl[i] = static_cast(k_[s]); + sq += ql[i] * ql[i]; sk += kl[i] * kl[i]; + } + for (int o = 16; o > 0; o >>= 1) { sq += __shfl_xor(sq, o); sk += __shfl_xor(sk, o); } + float scq = rsqrtf(sq / (float)Dk + 1e-6f); + float sck = rsqrtf(sk / (float)Dk + 1e-6f); + #pragma unroll + for (int i = 0; i < n_per_t; ++i) { + int s = dk_idx + 32 * i; + ql[i] = ql[i] * scq * static_cast(q_norm_w[s]); + kl[i] = kl[i] * sck * static_cast(k_norm_w[s]); + } + + // state load (column layout, coalesced) + float state[n_per_t]; + #pragma unroll + for (int i = 0; i < n_per_t; ++i) state[i] = static_cast(i_state[dk_idx + 32 * i]); + + // recurrence (single timestep) + float kv_mem = 0.0f; + #pragma unroll + for (int i = 0; i < n_per_t; ++i) { state[i] *= g; kv_mem += state[i] * kl[i]; } + for (int o = 16; o > 0; o >>= 1) kv_mem += __shfl_xor(kv_mem, o); + float delta = (static_cast(v_[dv_idx]) - kv_mem) * beta; + float out = 0.0f; + #pragma unroll + for (int i = 0; i < n_per_t; ++i) { state[i] += kl[i] * delta; out += state[i] * ql[i]; } + for (int o = 16; o > 0; o >>= 1) out += __shfl_xor(out, o); + if (dk_idx == 0) y[(long)b_idx * Hv * Dv + hv_idx * Dv + dv_idx] = static_cast(out); + #pragma unroll + for (int i = 0; i < n_per_t; ++i) o_state[dk_idx + 32 * i] = static_cast(state[i]); +)"; + +static mx::fast::CustomKernelFunction& get_gdn_fused_decode_kernel() { + static auto kernel = mx::fast::hip_kernel( + "gdn_fused_decode", + {"q", "k", "v", "b", "a", "a_log", "dt_bias", "q_norm_w", "k_norm_w", "state_in"}, + {"y", "state_out"}, + gdn_fused_decode_hip_source); + return kernel; +} + // --------------------------------------------------------------------------- // gatedDeltaKernel — dispatch the fused HIP kernel // --------------------------------------------------------------------------- @@ -678,6 +816,39 @@ std::pair gated_delta_update( return gated_delta_ops(q, k, v, g, beta, s, mask, inplace_state); } +std::pair gdn_fused_decode( + const mx::array& q, const mx::array& k, const mx::array& v, + const mx::array& a, const mx::array& b, + const mx::array& a_log, const mx::array& dt_bias, + const mx::array& q_norm_w, const mx::array& k_norm_w, + const mx::array& state) +{ + int B = q.shape(0); + int Hk = q.shape(2), Dk = q.shape(3); + int Hv = v.shape(2), Dv = v.shape(3); +#if defined(MLX_BUILD_ROCM) && MLX_BUILD_ROCM + static const bool force_fallback = std::getenv("MLX_GDN_FUSED2_MXOPS") != nullptr; + if (!force_fallback) { + auto t = q.dtype(); + auto al = mx::astype(a_log, mx::float32); + auto db = mx::astype(dt_bias, mx::float32); + auto results = get_gdn_fused_decode_kernel()( + {q, k, v, b, a, al, db, q_norm_w, k_norm_w, state}, + {{B, 1, Hv, Dv}, state.shape()}, + {t, t}, + {32, Dv, B * Hv}, // grid (total threads, Metal-style) + {32, 4, 1}, // threadgroup + {{"InT", t}, {"Dk", Dk}, {"Dv", Dv}, {"Hk", Hk}, {"Hv", Hv}}, + std::nullopt, true, {}); + return {results[0], results[1]}; + } +#endif + auto qn = mx::fast::rms_norm(q, q_norm_w, 1e-6f); + auto kn = mx::fast::rms_norm(k, k_norm_w, 1e-6f); + return gated_delta_update(qn, kn, v, a, b, a_log, dt_bias, state, + std::nullopt, false); +} + std::pair gdn_conv_step( const mx::array& conv_state, // [B, KS-1, CD] const mx::array& qkv, // [B, 1, CD] @@ -738,6 +909,31 @@ std::pair add_rms_norm( return {s, mx::fast::rms_norm(s, weight, eps)}; } +mx::array gated_rms_norm( + const mx::array& x, const mx::array& gate, + const mx::array& weight, float eps) +{ + int H = x.shape(-1); + int N = static_cast(x.size() / H); + auto t = x.dtype(); +#if defined(MLX_BUILD_ROCM) && MLX_BUILD_ROCM + static const bool force_mxops = std::getenv("MLX_FUSED_NORM_MXOPS") != nullptr; + if (!force_mxops) { + auto results = get_gated_rms_norm_kernel()( + {x, gate, weight, mx::array(eps)}, + {x.shape()}, + {t}, + {32, N, 1}, // grid (total threads): one wave per row + {32, 1, 1}, // threadgroup + {{"InT", t}, {"H", H}, {"N", N}}, + std::nullopt, true, {}); + return results[0]; + } +#endif + auto normed = mx::fast::rms_norm(x, weight, eps); + return mx::multiply(mx::multiply(gate, mx::sigmoid(gate)), normed); +} + std::pair moe_route(const mx::array& logits, int k) { int E = logits.shape(-1); int rows = static_cast(logits.size() / E); @@ -766,6 +962,29 @@ std::pair moe_route(const mx::array& logits, int k) { } // In-place write of src into dst's device buffer (same total element count). +mx::array kv_inplace_update( + const mx::array& cache, const mx::array& new_kv, int offset) +{ +#if defined(MLX_BUILD_ROCM) && MLX_BUILD_ROCM + int B = cache.shape(0), H = cache.shape(1); + int ALLOC = cache.shape(2), D = cache.shape(3); + int N = new_kv.shape(2); + long total = (long)B * H * N * D; + auto res = get_kv_inplace_update_kernel()( + {cache, new_kv, mx::array(offset, mx::int32)}, + {cache.shape()}, {cache.dtype()}, + {static_cast(total), 1, 1}, {256, 1, 1}, + {{"B", B}, {"H", H}, {"ALLOC", ALLOC}, {"D", D}, {"N", N}}, + std::nullopt, true, {}); + return res[0]; +#else + mx::Shape st{0, 0, offset, 0}; + mx::Shape sp{cache.shape(0), cache.shape(1), offset + new_kv.shape(2), + cache.shape(3)}; + return mx::slice_update(cache, new_kv, st, sp, mx::Shape{1, 1, 1, 1}); +#endif +} + mx::array inplace_write(const mx::array& dst, const mx::array& src) { #if defined(MLX_BUILD_ROCM) && MLX_BUILD_ROCM int n = static_cast(src.size()); diff --git a/src/common/generate.cpp b/src/common/generate.cpp index 016b5cbf..f429d14d 100644 --- a/src/common/generate.cpp +++ b/src/common/generate.cpp @@ -16,8 +16,19 @@ #if defined(MLX_BUILD_ROCM) // Decode-mode toggle (defined in mlx/backend/rocm/eval.cpp; declared here to // avoid pulling HIP headers into engine code). +#include namespace mlx::core { void gpu_set_graph_decode_mode(bool v); +// Build-once pure-relaunch decode + deterministic arena (rocm backend bridge). +void decode_pure_record(int slot); +void decode_pure_replay(int slot); +void decode_pure_off(); +size_t decode_pure_chain_len(int slot); +bool decode_arena_begin(size_t capacity, int device, void* stream); +void decode_arena_reset(); +void decode_arena_end(); +bool decode_arena_overflowed(); +void gpu_buffer_copy(array& dst, array& src); } // namespace mlx::core #endif @@ -345,6 +356,110 @@ mx::array TokenIterator::step(const LMInput::Text& previous) { return convert_to_token(result.logits); } +#if defined(MLX_BUILD_ROCM) +// Build-once pure-relaunch decode step. State machine: +// 0 warmup — engage device-pos; warm mx::compile caches (no record) +// 1 record — record the per-token graph chain once +// 2 replay — relaunch the recorded chain every token +// 9 disabled — fell back to the normal path (arena overflow / mismatch) +// One graph suffices: the GatedDeltaNet recurrent state lives in a single static +// buffer updated in place (no parity ping-pong), and position + input token are +// injected each step via fixed-address device buffers. +mx::array TokenIterator::step_pure_graph(const LMInput::Text& previous) { + StreamGuard sg(generation_stream()); + namespace mc = mlx::core; + + static const size_t arena_bytes = [] { + const char* e = std::getenv("MLX_DECODE_ARENA_MB"); + return size_t(e ? std::atoll(e) : 1024) << 20; + }(); + static const bool noreplay = std::getenv("MLX_PURE_NOREPLAY") != nullptr; + + LMInput::Text in(mlx_lm::graph_decode_input()); // [1,1] int32, fixed addr + + // Feed input + advance position via IMMEDIATE launches (loop-owned, between + // relaunches) — never recorded graph nodes. + mc::gpu_set_graph_decode_mode(false); + mx::array prev_tok = previous.tokens; + mlx_lm::set_graph_decode_input_from(prev_tok); // device copy -> fixed buffer + if (pure_graph_state_ == 0) { + mlx_lm::set_graph_external_pos(true); + int off = 0; + for (auto& c : cache_) off = std::max(off, c.offset()); + mlx_lm::set_graph_decode_pos(off); + pure_pos_ = off; + for (auto& c : cache_) c.reserve_to(pure_graph_cap_); + } else { + mlx_lm::advance_graph_decode_pos(1); + pure_pos_ += 1; + } + // Move GDN scratch next-state [2]/[3] -> read state [0]/[1] (immediate). + static const bool cpdbg = std::getenv("MLX_COPY_DEBUG") != nullptr; + int n_mamba = 0, n_scratch = 0; + if (pure_graph_state_ >= 1) { + for (auto& c : cache_) { + auto* m = c.as_mamba(); + if (!m) continue; + n_mamba++; + if ((*m)[2].has_value()) { + n_scratch++; + mc::gpu_buffer_copy((*m)[0].value(), (*m)[2].value()); + mc::gpu_buffer_copy((*m)[1].value(), (*m)[3].value()); + } + } + if (cpdbg) fprintf(stderr, "[cp] mamba=%d scratch=%d\n", n_mamba, n_scratch); + } + mc::gpu_set_graph_decode_mode(true); + + // Single static recurrent-state buffer (in-place RMW) -> ONE graph, no + // parity. Record once, then relaunch the same chain every token. + if (!noreplay) { + if (pure_graph_state_ == 1) { + mc::decode_arena_begin(arena_bytes, 0, nullptr); + mc::decode_arena_reset(); + mc::decode_pure_record(0); + } else if (pure_graph_state_ == 2) { + mc::decode_arena_reset(); + mc::decode_pure_replay(0); + } + } + + auto result = context_.call_fn( + in, cache_.empty() ? nullptr : &cache_, + state_.has_value() ? &state_.value() : nullptr); + state_ = result.state; + auto token = convert_to_token(result.logits); + // Force-eval token + GDN scratch states (the loop reads their raw buffers). + std::vector ev{token}; + for (auto& c : cache_) { + auto* m = c.as_mamba(); + if (m && (*m)[2].has_value()) { ev.push_back((*m)[2].value()); ev.push_back((*m)[3].value()); } + } + mx::eval(ev); + + static const bool pure_dbg = std::getenv("MLX_PURE_DEBUG") != nullptr; + if (pure_dbg) { + fprintf(stderr, "[pure] state=%d pos=%d in=%d sampled=%d\n", + pure_graph_state_, pure_pos_, + mlx_lm::graph_decode_input().item(), token.item()); + } + + auto disable = [&]() { + mc::decode_pure_off(); + mc::decode_arena_end(); + mlx_lm::set_graph_external_pos(false); + pure_graph_state_ = 9; + }; + if (pure_graph_state_ == 0) { + pure_graph_state_ = 1; // next token records + } else if (pure_graph_state_ == 1) { + if (mc::decode_arena_overflowed()) disable(); + else pure_graph_state_ = 2; // recorded -> replay + } + return token; +} +#endif + // --------------------------------------------------------------------------- // TokenIterator — prepare (prompt prefill) // --------------------------------------------------------------------------- @@ -812,6 +927,28 @@ std::optional TokenIterator::next() { // Standard path: single token generation. static const bool g_sync_decode = std::getenv("MLX_SYNC_DECODE") != nullptr; + +#if defined(MLX_BUILD_ROCM) + // Build-once pure-relaunch graph decode (opt-in, qwen35-moe device-pos path). + static const bool pure_enabled = + std::getenv("MLX_DECODE_GRAPH_PURE") != nullptr; + if (pure_enabled && pure_graph_state_ != 9 && !cache_.empty()) { + if (pure_graph_cap_ == 0) { + int off = 0; + for (auto& c : cache_) off = std::max(off, c.offset()); + int remaining = max_tokens_.has_value() + ? std::max(0, max_tokens_.value() - token_count_) : 256; + pure_graph_cap_ = off + remaining + 8; + } + auto previous_y = y_; + auto token = step_pure_graph(previous_y); + y_ = LMInput::Text(token); + token_count_++; + measure_prefill_boundary_(); + return token.item(); + } +#endif + auto previous_y = y_; auto token = step(previous_y); y_ = LMInput::Text(token); diff --git a/src/common/graph_decode.cpp b/src/common/graph_decode.cpp new file mode 100644 index 00000000..13364f00 --- /dev/null +++ b/src/common/graph_decode.cpp @@ -0,0 +1,92 @@ +// Copyright © 2025 +#include "mlx-lm/common/graph_decode.h" +#include + +namespace mx = mlx::core; + +// In-place device-scalar kernels (ROCm backend): mutate the pos buffer contents +// without reallocating, keeping the captured graph's baked address valid. +namespace mlx::core { +void gpu_kv_pos_set(array& pos, int v); +void gpu_kv_pos_increment(array& pos, int delta); +void gpu_scalar_copy_i32(array& dst, array& src); +} + +namespace mlx_lm { + +static bool g_external = false; +static bool g_capturing = false; + +// Constructed lazily on first use (not at static-init time, before --device +// selection). +mx::array& graph_decode_pos() { + static mx::array* g_pos = nullptr; + if (g_pos == nullptr) { + g_pos = new mx::array(mx::zeros({1}, mx::int32)); + mx::eval(*g_pos); + } + return *g_pos; +} + +void set_graph_decode_pos(int offset) { + // Mutate the pos buffer in place via a raw kernel. +#if defined(MLX_BUILD_ROCM) && MLX_BUILD_ROCM + auto& p = graph_decode_pos(); + mx::gpu_kv_pos_set(p, offset); + mx::synchronize(mx::default_stream(mx::default_device())); +#else + auto& p = graph_decode_pos(); + p = mx::slice_update(p, mx::broadcast_to(mx::array(offset, mx::int32), p.shape()), + mx::Shape(p.ndim(), 0), p.shape()); + mx::eval(p); +#endif +} + +// Advance the device position in place by delta (loop-owned, between replays). +void advance_graph_decode_pos(int delta) { +#if defined(MLX_BUILD_ROCM) && MLX_BUILD_ROCM + auto& p = graph_decode_pos(); + mx::gpu_kv_pos_increment(p, delta); +#else + set_graph_decode_pos(0); // non-ROCm has no graph path +#endif +} + +bool graph_external_pos() { return g_external; } +void set_graph_external_pos(bool on) { g_external = on; } + +// Fixed-address [1,1] int32 input-token buffer. Constructed lazily (after device +// selection) and kept resident so its device address never changes. +mx::array& graph_decode_input() { + static mx::array* g_input = nullptr; + if (g_input == nullptr) { + g_input = new mx::array(mx::zeros({1, 1}, mx::int32)); + mx::eval(*g_input); + } + return *g_input; +} + +void set_graph_decode_input_from(mx::array& token) { +#if defined(MLX_BUILD_ROCM) && MLX_BUILD_ROCM + auto& dst = graph_decode_input(); + // token may be [1] or [1,1]; the kernel copies element 0 either way. + mx::gpu_scalar_copy_i32(dst, token); +#else + (void)token; +#endif +} + +bool graph_capturing() { return g_capturing; } +void set_graph_capturing(bool on) { g_capturing = on; } + +bool graph_decode_enabled() { +#if defined(MLX_BUILD_ROCM) && MLX_BUILD_ROCM + // Opt-in during bring-up. + static const bool on = std::getenv("MLX_DECODE_GRAPH") != nullptr; + return on; +#else + return false; +#endif +} + +} // namespace mlx_lm diff --git a/src/common/kv_cache.cpp b/src/common/kv_cache.cpp index fc3f2ad0..32d0134c 100644 --- a/src/common/kv_cache.cpp +++ b/src/common/kv_cache.cpp @@ -1,8 +1,10 @@ // Copyright © 2024-2025 Apple Inc. — Ported to C++ #include +#include #include #include +#include namespace mlx_lm { @@ -63,10 +65,21 @@ KVCacheSimple::update_impl( int current_alloc = keys_.value().shape(2); if (offset_ + n_new <= current_alloc) { - keys_ = mx::slice_update(keys_.value(), new_keys, - mx::Shape{0, 0, offset_, 0}, mx::Shape{B, H, offset_ + n_new, D}); - values_ = mx::slice_update(values_.value(), new_values, - mx::Shape{0, 0, offset_, 0}, mx::Shape{B, H, offset_ + n_new, D}); + // In-place slice write (output aliases the cache buffer) instead of + // slice_update, whose COW donation fails under the async one-behind + // pipeline and copies the whole cache — a variable per-token copy count + // that makes the decode graph non-replayable. MLX_KV_INPLACE_OFF opts out. + static const bool kv_inplace = + std::getenv("MLX_KV_INPLACE_OFF") == nullptr; + if (kv_inplace) { + keys_ = kv_inplace_update(keys_.value(), new_keys, offset_); + values_ = kv_inplace_update(values_.value(), new_values, offset_); + } else { + keys_ = mx::slice_update(keys_.value(), new_keys, + mx::Shape{0, 0, offset_, 0}, mx::Shape{B, H, offset_ + n_new, D}); + values_ = mx::slice_update(values_.value(), new_values, + mx::Shape{0, 0, offset_, 0}, mx::Shape{B, H, offset_ + n_new, D}); + } offset_ += n_new; return {mx::slice(keys_.value(), mx::Shape{0,0,0,0}, mx::Shape{B,H,offset_,D}), mx::slice(values_.value(), mx::Shape{0,0,0,0}, mx::Shape{B,H,offset_,D})}; @@ -128,6 +141,42 @@ void KVCacheSimple::set_position(size_t pos) { trim_impl(delta); } +std::pair KVCacheSimple::update_at_pos( + const mlx::core::array& new_keys, const mlx::core::array& new_values, + const mlx::core::array& pos) { + // DynamicSliceUpdate at the device-side `pos` (axis 2). The buffer must be + // pre-allocated to capacity; the offset advances device-side so the built + // graph relaunches correctly as the loop advances pos. std::move releases the + // cache's reference so slice_update can donate (update in place) — keeping the + // buffer at a FIXED address, which the build-once graph's nodes bake into. + auto k = std::move(keys_.value()); + auto v = std::move(values_.value()); + keys_ = mx::slice_update(k, new_keys, pos, {2}); + values_ = mx::slice_update(v, new_values, pos, {2}); + offset_ += new_keys.shape(2); + return {keys_.value(), values_.value()}; +} + +void KVCacheSimple::reserve_to(int capacity) { + if (!keys_.has_value()) { + return; + } + auto& k = keys_.value(); + int cur = k.shape(2); + if (cur >= capacity) { + return; + } + int pad = capacity - cur; + auto kpad = k.shape(); + kpad[2] = pad; + auto vpad = values_.value().shape(); + vpad[2] = pad; + keys_ = mx::concatenate({k, mx::zeros(kpad, k.dtype())}, 2); + values_ = mx::concatenate( + {values_.value(), mx::zeros(vpad, values_.value().dtype())}, 2); + mx::eval(keys_.value(), values_.value()); +} + // --- RotatingKVCache --- std::pair diff --git a/src/llm/models/qwen35_moe.cpp b/src/llm/models/qwen35_moe.cpp index 71c249af..ae30cd45 100644 --- a/src/llm/models/qwen35_moe.cpp +++ b/src/llm/models/qwen35_moe.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -279,17 +280,38 @@ mx::array Qwen35MoEAttention::operator()(const mx::array& x, // RoPE with partial rotary factor. int offset = cache ? cache->offset() : 0; - queries = mx::fast::rope(queries, rope_dims_, false, rope_theta_, 1.0f, offset); - keys = mx::fast::rope(keys, rope_dims_, false, rope_theta_, 1.0f, offset); - // KV cache update + SDPA. + // Device-position decode path (build-once HIP graph): RoPE offset, KV write + // slot, and causal mask all read a fixed-address [1] int32 device buffer so + // the same graph relaunches correctly as the loop advances pos device-side. + static const bool devpos_env = std::getenv("MLX_DECODE_DEVICE_POS") != nullptr; + bool gmode = (mlx_lm::graph_external_pos() || devpos_env) && L == 1 && cache; + mx::array output(0.0f); - if (cache) { - auto [k, v] = cache->update(keys, values); - keys = k; - values = v; + if (gmode) { + auto& pos = mlx_lm::graph_decode_pos(); // fixed-address [1] int32 + queries = mx::fast::rope(queries, rope_dims_, false, rope_theta_, 1.0f, pos); + keys = mx::fast::rope(keys, rope_dims_, false, rope_theta_, 1.0f, pos); + auto [k, v] = cache->update_at_pos(keys, values, pos); + int CAP = k.shape(2); + auto cols = mx::arange(0, CAP, mx::int32); + float ninf = -std::numeric_limits::infinity(); + auto addmask = mx::astype( + mx::reshape(mx::where(mx::less_equal(cols, pos), + mx::array(0.0f), mx::array(ninf)), + {1, 1, 1, CAP}), + x.dtype()); + output = sdpa(queries, k, v, scale_, AttentionMask::from_array(addmask)); + } else { + queries = mx::fast::rope(queries, rope_dims_, false, rope_theta_, 1.0f, offset); + keys = mx::fast::rope(keys, rope_dims_, false, rope_theta_, 1.0f, offset); + if (cache) { + auto [k, v] = cache->update(keys, values); + keys = k; + values = v; + } + output = sdpa(queries, keys, values, scale_, mask); } - output = sdpa(queries, keys, values, scale_, mask); output = mx::reshape(mx::transpose(output, {0, 2, 1, 3}), {B, L, -1}); // Swift: oProj(sigmoidMultiply(output, gate)) @@ -360,6 +382,16 @@ void Qwen35MoEGatedDeltaNet::ensure_in_proj_fused() { in_proj_fused_weight_); } +// In-place overwrite of dst with src via native slice_update donation (start=0). +static mlx::core::array gdn_state_overwrite_(mlx::core::array dst, + const mlx::core::array& src) { + int nd = dst.ndim(); + std::vector axes(nd); + for (int i = 0; i < nd; ++i) axes[i] = i; + return mx::slice_update(std::move(dst), src, + mx::zeros({nd}, mx::int32), axes); +} + mx::array Qwen35MoEGatedDeltaNet::operator()( const mx::array& inputs, const std::optional& mask, @@ -393,6 +425,14 @@ mx::array Qwen35MoEGatedDeltaNet::operator()( // Conv1d processing. auto dtype = inputs.dtype(); + // Build-once HIP-graph decode: keep the conv + SSM recurrent state in a SINGLE + // static buffer updated IN PLACE (inplace_write). The recorded graph reads the + // fixed-address state, computes the new state, and writes it back into the same + // buffer — one kernel's read-before-write is safe within a launch, and the + // in-place write accumulates across relaunches. No double buffer / parity is + // needed (only stream capture required ping-pong; build+relaunch does not). + bool gdn_inplace = mlx_lm::graph_external_pos() && S == 1 && cache; + mx::array conv_state(0.0f); if (cache && (*cache)[0].has_value()) { conv_state = (*cache)[0].value(); @@ -410,7 +450,14 @@ mx::array Qwen35MoEGatedDeltaNet::operator()( if (S == 1 && cache) { auto [conv_out, new_state] = gdn_conv_step(conv_state, qkv, conv1d_weight_); - (*cache)[0] = new_state; + if (gdn_inplace && (*cache)[0].has_value()) { + // Read [0], write new state to scratch [2]; loop copies [2]->[0]. + if (!(*cache)[2].has_value()) + (*cache)[2] = mx::zeros_like((*cache)[0].value()); + (*cache)[2] = gdn_state_overwrite_(std::move((*cache)[2].value()), new_state); + } else { + (*cache)[0] = new_state; + } // Split into q, k, v auto q_out = mx::reshape(mx::slice(conv_out, {0, 0, 0}, {B, 1, key_dim_}), @@ -426,11 +473,20 @@ mx::array Qwen35MoEGatedDeltaNet::operator()( q_norm_w_ = mx::full({head_k_dim_}, inv_scale * inv_scale, dtype); k_norm_w_ = mx::full({head_k_dim_}, inv_scale, dtype); } - q_out = mx::fast::rms_norm(q_out, *q_norm_w_, 1e-6f); - k_out = mx::fast::rms_norm(k_out, *k_norm_w_, 1e-6f); + // GDN decode step. MLX_GDN_NO_FUSED=1 -> inline mx::compile recurrence. + // MLX_GDN_NO_FUSED2=1 -> the per-op fused path (rms_norm + beta/g + + // gated_delta_step). Default: the FlashQLA-style gdn_fused_decode kernel + // that folds q/k-RMSNorm + beta/g + recurrence into one launch. + static const bool use_fused_gdn = std::getenv("MLX_GDN_NO_FUSED") == nullptr; + const bool use_fused2 = + use_fused_gdn && std::getenv("MLX_GDN_NO_FUSED2") == nullptr; + + // The fused kernel folds the q/k norm; only normalize here otherwise. + if (!use_fused2) { + q_out = mx::fast::rms_norm(q_out, *q_norm_w_, 1e-6f); + k_out = mx::fast::rms_norm(k_out, *k_norm_w_, 1e-6f); + } - // GDN decode step via the fused HIP kernel (gated_delta_update). - // MLX_GDN_NO_FUSED=1 falls back to the inline mx::compile recurrence. mx::array ssm_state(0.0f); if ((*cache)[1].has_value()) { ssm_state = (*cache)[1].value(); @@ -438,12 +494,32 @@ mx::array Qwen35MoEGatedDeltaNet::operator()( ssm_state = mx::zeros({B, num_v_heads_, head_v_dim_, head_k_dim_}, dtype); } - static const bool use_fused_gdn = std::getenv("MLX_GDN_NO_FUSED") == nullptr; + static const bool st_ck = std::getenv("MLX_STATE_CKSUM") != nullptr; + if (st_ck) { + auto c = mx::sum(mx::abs(mx::astype(ssm_state, mx::float32))); + mx::eval(c); + fprintf(stderr, "[st] read_ssm %.6e\n", c.item()); + } + if (use_fused_gdn) { - auto [o, ns] = gated_delta_update( - q_out, k_out, v_out, a_val, b_val, a_log_, dt_bias_, ssm_state, - std::nullopt, /*inplace=*/false); - (*cache)[1] = ns; + mx::array o(0.0f), ns(0.0f); + if (use_fused2) { + std::tie(o, ns) = gdn_fused_decode( + q_out, k_out, v_out, a_val, b_val, a_log_, dt_bias_, + *q_norm_w_, *k_norm_w_, ssm_state); + } else { + std::tie(o, ns) = gated_delta_update( + q_out, k_out, v_out, a_val, b_val, a_log_, dt_bias_, + ssm_state, std::nullopt, /*inplace=*/false); + } + if (gdn_inplace && (*cache)[1].has_value()) { + // Read [1], write new state to scratch [3]; loop copies [3]->[1]. + if (!(*cache)[3].has_value()) + (*cache)[3] = mx::zeros_like((*cache)[1].value()); + (*cache)[3] = gdn_state_overwrite_(std::move((*cache)[3].value()), ns); + } else { + (*cache)[1] = ns; + } auto normalized = norm_(o, z); return linear_fwd(mx::reshape(normalized, {B, S, -1}), out_proj_weight_); } @@ -786,10 +862,17 @@ mx::array Qwen35MoEModelInner::operator()(const mx::array& inputs, std::vector ssm_mask; + static const bool cksum = std::getenv("MLX_PURE_CKSUM") != nullptr; for (size_t i = 0; i < layers_.size(); ++i) { KVCache* lc = (cache && i < cache->size()) ? &(*cache)[i] : nullptr; auto attn_mask = layers_[i].is_linear() ? AttentionMask{} : fa_mask; h = layers_[i](h, attn_mask, ssm_mask, lc); + if (cksum) { + auto c = mx::sum(mx::abs(mx::astype(h, mx::float32))); + mx::eval(c); + fprintf(stderr, "[ck] L%02zu %s %.7e\n", i, + layers_[i].is_linear() ? "gdn " : "attn", c.item()); + } } return mx::fast::rms_norm(h, norm_weight_, rms_norm_eps_); @@ -867,10 +950,17 @@ mx::array Qwen35MoEModelInner::forward_prenorm(const mx::array& inputs, std::vec std::optional ssm_mask; + static const bool cksum = std::getenv("MLX_PURE_CKSUM") != nullptr; for (size_t i = 0; i < layers_.size(); ++i) { KVCache* lc = (cache && i < cache->size()) ? &(*cache)[i] : nullptr; auto attn_mask = layers_[i].is_linear() ? AttentionMask{} : fa_mask; h = layers_[i](h, attn_mask, ssm_mask, lc); + if (cksum) { + auto c = mx::sum(mx::abs(mx::astype(h, mx::float32))); + mx::eval(c); + fprintf(stderr, "[ck] L%02zu %s %.7e\n", i, + layers_[i].is_linear() ? "gdn " : "attn", c.item()); + } } return h; // pre-norm diff --git a/src/llm/models/qwen3_next.cpp b/src/llm/models/qwen3_next.cpp index dfb0819d..bca1eb53 100644 --- a/src/llm/models/qwen3_next.cpp +++ b/src/llm/models/qwen3_next.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -76,12 +77,11 @@ Qwen3NextRMSNormGated::Qwen3NextRMSNormGated(int dimensions, float eps) mx::array Qwen3NextRMSNormGated::operator()(const mx::array& x, const std::optional& gate) { - auto result = mx::fast::rms_norm(x, weight_, eps_); if (gate.has_value()) { - // Swift: silu(gate) * result - result = swiglu(*gate, result); + // silu(gate) * rmsnorm(x) * weight, fused into one kernel on ROCm. + return gated_rms_norm(x, *gate, weight_, eps_); } - return result; + return mx::fast::rms_norm(x, weight_, eps_); } std::unordered_map Qwen3NextRMSNormGated::weight_map() {