Build-once HIP-graph decode (device-position) + graphs-off-by-default #42

Merged
Geramy merged 11 commits into main from geramy/graph-prefill
Jun 26, 2026

@Geramy Geramy commented Jun 26, 2026

Build-once / pure-relaunch HIP-graph decode for the Qwen3.6 hybrid (GatedDeltaNet + attention + MoE), plus defaulting HIP graphs off (eager) since the rebuild-per-token path is a net loss on the integrated APU.

Highlights

  • Graphs off by default (opt in with MLX_USE_HIP_GRAPHS=1). On gfx1151, eager decode (~65 TPS) beats the rebuild-per-token graph path (~49 TPS).
  • Device-position decode (opt in with MLX_DECODE_GRAPH_PURE): RoPE, the KV write, and the causal mask all read a fixed-address device position buffer; the KV cache gains update_at_pos and reserve_to; the GDN recurrent state lives in a static buffer with decoupled write-back (scratch→state between relaunches). A single graph is built once and then relaunched, with output verified to bit-match eager.
  • Deterministic decode arena + pure record/replay in the ROCm backend (no SetParams/ExecUpdate).
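
A minimal sketch of the "off by default, opt in by environment variable" selection described in the first highlight. Only the variable name `MLX_USE_HIP_GRAPHS` comes from this PR; the helper names `env_flag_enabled` and `select_decode_mode`, and the rule that `"0"` or an empty value counts as off, are assumptions made for illustration and may not match how the backend parses the flag.

```cpp
#include <cassert>
#include <cstdlib>
#include <cstring>

// Hypothetical helper: a flag is on only when it is set to a non-empty
// value other than "0". An unset variable means "off".
inline bool env_flag_enabled(const char* name) {
  const char* value = std::getenv(name);
  return value != nullptr && value[0] != '\0' && std::strcmp(value, "0") != 0;
}

enum class DecodeMode { Eager, HipGraph };

// Eager is the default; the graph path is chosen only on explicit opt-in.
inline DecodeMode select_decode_mode() {
  return env_flag_enabled("MLX_USE_HIP_GRAPHS") ? DecodeMode::HipGraph
                                                 : DecodeMode::Eager;
}
```

With this shape, running without the variable keeps the eager path, and `MLX_USE_HIP_GRAPHS=1` selects the graph path.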

Status

  • Default path (eager) verified coherent on 35B-A3B MoE and 27B dense.
  • Build-once replay is correct but currently re-runs the model before each relaunch, so it is slower than eager; relaunch-only is the next step.

Geramy added 11 commits June 25, 2026 07:13
- gdn_fused_decode: fold q/k-RMSNorm + beta/g + the delta recurrence into one
  kernel (column/coalesced layout, unrolled loops); wired into the GDN T=1
  decode fast path. +18% decode on gfx1151. Gated MLX_GDN_NO_FUSED2.
- gated_rms_norm: silu(gate)*rmsnorm(x)*weight in a single kernel (was
  rms_norm + swiglu); wired into Qwen3NextRMSNormGated. Gated MLX_FUSED_NORM_MXOPS.
- kv_inplace_update: write new K/V into the KV-cache slice in place (kernel
  output aliased to the cache buffer) instead of slice_update, whose COW
  donation fails under the async one-behind pipeline and copies the whole
  cache — a variable per-token copy count. This stabilizes the decode-graph
  topology (constant node count per token). Gated MLX_KV_INPLACE_OFF.

All default-on, coherent over long greedy runs (gfx1151, --device 0).
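
For reference, the math that gated_rms_norm fuses can be written as a plain CPU routine over one row. This is a sketch for checking values, not the kernel from this PR: the function name `gated_rms_norm_ref`, the `eps` default, and the float accumulation are assumptions.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// One row: out[i] = silu(gate[i]) * rmsnorm(x)[i] * weight[i],
// where rmsnorm(x)[i] = x[i] / sqrt(mean(x^2) + eps).
// The eps value is an assumption of this sketch.
inline std::vector<float> gated_rms_norm_ref(const std::vector<float>& x,
                                             const std::vector<float>& gate,
                                             const std::vector<float>& weight,
                                             float eps = 1e-6f) {
  const std::size_t n = x.size();
  assert(n > 0 && gate.size() == n && weight.size() == n);

  float sum_sq = 0.0f;
  for (float v : x) sum_sq += v * v;
  const float inv_rms = 1.0f / std::sqrt(sum_sq / static_cast<float>(n) + eps);

  std::vector<float> out(n);
  for (std::size_t i = 0; i < n; ++i) {
    // silu(g) = g * sigmoid(g)
    const float silu = gate[i] / (1.0f + std::exp(-gate[i]));
    out[i] = silu * (x[i] * inv_rms) * weight[i];
  }
  return out;
}
```

A fused kernel would compute the same expression in one pass instead of materializing the rms_norm output before applying the gate.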
Re-introduce graph_decode.{h,cpp}: a fixed-address [1] int32 device position
buffer, advanced IN PLACE between replays (loop-owned advance, never inside the
captured graph — an in-graph pos++ races the RoPE/mask readers). set/advance call
the backend gpu_kv_pos_* kernels; graph_capturing() gates capture-vs-eager;
graph_decode_enabled() opt-in via MLX_DECODE_GRAPH. Foundation for documented
capture-once + pure-relaunch decode (Phase C proved bit-exact); dormant until the
attention/GDN device-position wiring + capture path land.
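
The loop-owned advance can be illustrated with a host-only model. In the sketch below, `PositionBuffer`, `RecordedGraph`, `record_decode`, and `decode_loop` are hypothetical stand-ins (no HIP calls are made); the point is that recorded nodes hold the buffer's address rather than its value, and only the loop mutates the position between launches.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <vector>

// Host stand-in for the fixed-address [1] int32 device position buffer.
struct PositionBuffer {
  std::int32_t value = 0;
  void set(std::int32_t p) { value = p; }
  void advance() { ++value; }  // called by the loop, never by a graph node
};

// Stand-in for a captured graph: a fixed list of nodes replayed in order.
struct RecordedGraph {
  std::vector<std::function<void()>> nodes;
  void launch() const {
    for (const auto& node : nodes) node();
  }
};

// Record once. Both readers (think RoPE offset and mask bound) load the
// position at launch time through the captured pointer.
inline RecordedGraph record_decode(const PositionBuffer* pos,
                                   std::vector<std::int32_t>* rope_seen,
                                   std::vector<std::int32_t>* mask_seen) {
  RecordedGraph g;
  g.nodes.push_back([pos, rope_seen] { rope_seen->push_back(pos->value); });
  g.nodes.push_back([pos, mask_seen] { mask_seen->push_back(pos->value); });
  return g;
}

// Relaunch the same graph; the position changes only between launches,
// so every reader inside one launch sees the same value.
inline void decode_loop(const RecordedGraph& g, PositionBuffer& pos, int steps) {
  for (int i = 0; i < steps; ++i) {
    g.launch();
    pos.advance();
  }
}
```

Because the increment sits outside `launch()`, the two readers can never observe different positions within one replay, which is the race the commit message warns about for an in-graph pos++.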
KVCacheSimple::update_at_pos(k, v, pos) writes the new token at the device-side
slot `pos` via DynamicSliceUpdate (slice_update axis 2), returning the full
pre-allocated buffer. The offset advances device-side so the built decode graph
relaunches correctly as the loop advances the position. Dormant until the model
device-position attention + engine relaunch loop are wired.
When the engine drives a built decode graph (graph_external_pos, or the
MLX_DECODE_DEVICE_POS opt-in for standalone testing), the L==1 attention reads
the fixed-address [1] int32 device position buffer for RoPE offset, the KV write
slot (update_at_pos), and the causal mask (cols <= pos). The same graph then
relaunches correctly as the loop advances pos device-side — no per-token
SetParams/re-pointing. KVCache/CompoundCache gain update_at_pos dispatch
(std::visit + if-constexpr; throws for unsupported cache types). Default path
(gmode off) is unchanged and verified coherent.
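
The dispatch pattern named above (std::visit + if-constexpr, throwing for unsupported types) and the `cols <= pos` mask rule can be sketched as follows. The cache types `SimpleCache` and `UnsupportedCache`, the int payload, and `causal_row` are hypothetical and much simpler than the real KVCache/CompoundCache classes.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <type_traits>
#include <variant>
#include <vector>

// Hypothetical cache types; only SimpleCache supports positional writes.
struct SimpleCache {
  std::vector<int> slots;  // pre-allocated, one entry per position
  void update_at_pos(int token, std::size_t pos) { slots.at(pos) = token; }
};
struct UnsupportedCache {};

using AnyCache = std::variant<SimpleCache, UnsupportedCache>;

// Compile-time branch per alternative; unsupported types throw at run time.
inline void update_at_pos(AnyCache& cache, int token, std::size_t pos) {
  std::visit(
      [&](auto& c) {
        using T = std::decay_t<decltype(c)>;
        if constexpr (std::is_same_v<T, SimpleCache>) {
          c.update_at_pos(token, pos);
        } else {
          throw std::invalid_argument("update_at_pos: unsupported cache type");
        }
      },
      cache);
}

// Causal mask row for a single query at position pos: column j is
// visible when j <= pos.
inline std::vector<bool> causal_row(std::size_t capacity, std::size_t pos) {
  std::vector<bool> row(capacity);
  for (std::size_t j = 0; j < capacity; ++j) row[j] = (j <= pos);
  return row;
}
```

Since the buffer is pre-allocated to capacity, the mask is what hides the not-yet-written slots beyond `pos`.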
…ode)

In graph mode (graph_external_pos, S==1) the GatedDeltaNet conv + SSM state
becomes a [2, …] device-parity ping-pong buffer: read slot pos&1, write slot
(pos+1)&1 via dynamic slice / slice_update. The built graph relaunches with an
advancing device pos, so each replay reads the previous replay's write and
accumulates — an in-place RMW at a fixed address would not accumulate across
relaunches. Prefill's single-buffer state is promoted (both slots seeded) on
the first graph step; parity is snapshotted outside capture to avoid the async
pipeline reading the wrong slot. Default path (gdn_dbuf off) is byte-identical
and verified coherent.
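
The parity indexing in this commit can be modeled with a two-slot array. This is a toy sketch: `PingPongState`, `promote`, `gdn_step`, and `current_state` are hypothetical names, and the "recurrence" is just an addition standing in for the real conv + SSM update.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Stand-in for the [2, ...] state buffer: one float per slot.
struct PingPongState {
  std::array<float, 2> slot{};
};

// Promote a single-buffer prefill state by seeding both slots.
inline PingPongState promote(float prefill_state) {
  PingPongState s;
  s.slot[0] = prefill_state;
  s.slot[1] = prefill_state;
  return s;
}

// One decode step: read slot pos&1, write slot (pos+1)&1.
inline void gdn_step(PingPongState& s, std::int32_t pos, float delta) {
  const std::size_t rd = static_cast<std::size_t>(pos & 1);
  const std::size_t wr = static_cast<std::size_t>((pos + 1) & 1);
  s.slot[wr] = s.slot[rd] + delta;
}

// The live state for a given position is whatever the previous step wrote.
inline float current_state(const PingPongState& s, std::int32_t pos) {
  return s.slot[static_cast<std::size_t>(pos & 1)];
}
```

Each step reads the slot the previous step wrote, so the state accumulates as the position advances while the read and write addresses within one step stay distinct.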
Allocates an identical multi-op sequence across two arena resets and asserts the
output buffer lands at the same device address each time (p1==p2==p3, equal
high-water, no overflow). Verified on gfx1151: DETERMINISTIC OK.
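
The property being tested can be reproduced with a small host-side bump allocator. The sketch below is an assumption-laden model, not the ROCm backend's arena: `BumpArena`, `run_sequence`, the 256-byte alignment, and the allocation sizes are all made up for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical bump arena: allocations are carved from one block in call
// order, so an identical call sequence after reset() yields identical
// addresses.
class BumpArena {
 public:
  explicit BumpArena(std::size_t capacity) : storage_(capacity) {}

  // Returns nullptr on overflow. align must be a power of two; it aligns
  // the offset inside the block, not the absolute address.
  void* allocate(std::size_t bytes, std::size_t align = 256) {
    const std::size_t start = (offset_ + align - 1) & ~(align - 1);
    if (start + bytes > storage_.size()) {
      overflowed_ = true;
      return nullptr;
    }
    offset_ = start + bytes;
    if (offset_ > high_water_) high_water_ = offset_;
    return storage_.data() + start;
  }

  void reset() { offset_ = 0; }  // rewind; the backing block is kept
  std::size_t high_water() const { return high_water_; }
  bool overflowed() const { return overflowed_; }

 private:
  std::vector<unsigned char> storage_;
  std::size_t offset_ = 0;
  std::size_t high_water_ = 0;
  bool overflowed_ = false;
};

// A fixed multi-op sequence; returns the address of the final "output".
inline void* run_sequence(BumpArena& arena) {
  arena.allocate(1000);
  arena.allocate(4096);
  return arena.allocate(512);
}
```

Calling `run_sequence` three times with a `reset()` in between should return the same pointer each time, which is the address stability a recorded graph relies on.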
…GRAPH_PURE)

Wires the full build-once graph-decode path, off by default:
- graph_decode: fixed-address [1,1] input-token buffer + device-copy feed.
- KVCacheSimple::reserve_to: pre-grow KV to capacity so device-offset writes
  never realloc; update_at_pos + GDN dbuf writes std::move the old buffer so
  slice_update DONATES (verified in-place via test_donate — no full copy).
- TokenIterator::step_pure_graph: warmup -> record -> replay state machine.
  Per token (eager, between relaunches): feed input + advance device pos, then
  rewind the arena and relaunch the recorded chain.
- Diagnostics: MLX_PURE_NOREPLAY / _FREEZE / _FREEZE_POS / _FREEZE_INPUT and
  MLX_GDN_NO_DBUF for isolating the path.
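
The warmup -> record -> replay progression can be sketched as a three-state machine. `PureGraphStepper` and its counters are hypothetical, and the number of warmup steps is a free parameter here because the PR does not state it; the real `step_pure_graph` also feeds the input, advances the device position, and rewinds the arena, which this sketch only notes in comments.

```cpp
#include <cassert>

enum class Phase { Warmup, Record, Replay };

// Hypothetical model of the per-token state machine.
struct PureGraphStepper {
  Phase phase;
  int warmup_left;
  int eager_runs = 0;
  int records = 0;
  int replays = 0;

  explicit PureGraphStepper(int warmup_steps)
      : phase(warmup_steps > 0 ? Phase::Warmup : Phase::Record),
        warmup_left(warmup_steps) {}

  // Handles one token and returns the phase that handled it.
  Phase step() {
    const Phase handled = phase;
    switch (handled) {
      case Phase::Warmup:
        ++eager_runs;  // plain eager decode, nothing recorded
        if (--warmup_left == 0) phase = Phase::Record;
        break;
      case Phase::Record:
        ++records;  // record the op chain exactly once
        phase = Phase::Replay;
        break;
      case Phase::Replay:
        // Real path: feed input, advance device pos, rewind arena, relaunch.
        ++replays;
        break;
    }
    return handled;
  }
};
```

Once the machine reaches `Replay` it never leaves, so `records` stays at one for the lifetime of the stepper.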

Validated so far on gfx1151 (Qwen3.6-A3B q4):
- Device-position decode (NOREPLAY) is bit-correct vs eager — attention AND GDN
  double-buffer (snapshot parity path) reproduce eager output.
- The recorded graph is faithful: frozen pos+input reproduces the record token's
  output on every relaunch.
- Remaining bug: under live relaunch the first replayed token diverges by a small
  numeric perturbation (adjacent argmax) that then compounds. Suspected causes are
  injection propagation or a subtle stale state read; pinning it down needs per-op
  checksum diffing.
…write-back)

The recorded decode graph reads GDN state [0]/[1] and writes the new state to
scratch [2]/[3] (separate fixed buffers — no read==write hazard on relaunch);
the decode loop copies [2]->[0], [3]->[1] between relaunches (loop-owned, like
the position advance). MambaCache gains scratch slots [2]/[3]. Single graph, no
parity. Verified on gfx1151: pure replay output matches eager.

Known issue: replay still re-runs the model to build a tape that is then discarded,
so TPS is below eager. The speedup requires skipping that re-run (relaunch-only),
which is the next step.
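
The decoupled write-back from this commit can be modeled on the host with four slots. In this sketch `GdnSlots`, `graph_body`, and `write_back` are hypothetical, and the arithmetic is a placeholder for the real state update; what it shows is that the recorded body reads only [0]/[1] and writes only [2]/[3], while the loop performs the copy between relaunches.

```cpp
#include <array>
#include <cassert>

// [0]/[1] hold the live state; [2]/[3] are scratch for the new values.
struct GdnSlots {
  std::array<float, 4> buf{};
};

// Stand-in for the recorded graph body: reads [0]/[1], writes [2]/[3].
// The second line still sees the old [0] because nothing wrote to it.
inline void graph_body(GdnSlots& s, float x) {
  s.buf[2] = s.buf[0] + x;
  s.buf[3] = s.buf[1] * 0.5f + s.buf[0];
}

// Loop-owned copy between relaunches: [2]->[0], [3]->[1].
inline void write_back(GdnSlots& s) {
  s.buf[0] = s.buf[2];
  s.buf[1] = s.buf[3];
}
```

A decode loop would call `graph_body` then `write_back` once per token. Compared with the parity ping-pong, the graph's read and write addresses are the same on every relaunch, so no position-dependent slot selection is needed inside the graph.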
@Geramy Geramy merged commit c62dcd5 into main Jun 26, 2026
9 checks passed