Build-once HIP-graph decode replay (coherent, gated)#45
Open
Geramy wants to merge 5 commits into
Open
Conversation
Rework the pure-graph decode path to capture the whole forward into one HIP graph once and relaunch the cached exec every token, instead of rebuilding per token. Output is byte-identical to eager on gfx1151 and gfx1201. - step_pure_graph: record state captures call_fn into one exec (async_eval so there is no blocking sync mid-capture); replay state relaunches it and samples the overwritten logits buffer; arena reset-to-floor preserves the recorded buffers while sampling allocates above them. - pure_logits_ holds the recorded logits array; the relaunch overwrites its buffer and convert_to_token reads it fresh. - GDN recurrent state ping-pongs through scratch [2]/[3] with a per-token copy back to [0]/[1] (required: in-place [0]/[1] in the captured graph does not hold the fixed address the relaunch bakes). - qwen35_moe: cache the all-zeros slice_update start index in gdn_state_overwrite_ (was a redundant memset node per GDN layer per token). Gated by MLX_DECODE_GRAPH_PURE (default off). On ROCm the captured graph is a linear chain and the executor adds a barrier per edge, so replay does not yet beat eager for this matmul-bound 35B decode.
Add an optional ping-pong-parity replay mode (MLX_PURE_PARITY, default off): record two captured execs — parity 0 reads state [0]/[1] writes scratch [2]/[3], parity 1 reads [2]/[3] writes [0]/[1] — and relaunch by token parity, so the write of one relaunch is the read of the next. This removes the per-token [2]->[0] state copy the single-graph path needs. Both read/write distinct slots (no in-place aliasing), so addresses stay fixed for the baked graphs. Coherent and byte-identical to eager on gfx1151/gfx1201. - graph_decode: add graph_decode_parity()/set_graph_decode_parity(). - qwen35_moe: GDN read/write slots are parity-driven (parity 0 == prior single-graph behavior). Cache the all-zeros slice_update start index in gdn_state_overwrite_ (was a redundant memset node per GDN layer per token). - step_pure_graph: state machine records one exec (single) or two (parity), alternating replay; single-graph still copies [2]->[0]. Note: parity is TPS-neutral here (~same as single-graph) — it removes the copies but the dominant cost is the ROCm graph executor's per-edge barriers, not the copies. Kept as a cleaner, validated zero-copy option; single-graph stays default.
…9700) Replace the GDN ping-pong scratch + the KV slice_update copy with true in-place writes, so the build-once decode graph carries no per-token state/KV copies. - GDN recurrent state: the fused kernels (gdn_fused_decode, gdn_conv_step) gain an in-place variant whose state output ALIASES the state input (forced alias, like kv_inplace_update). The model updates cache slots [0]/[1] in place — no scratch [2]/[3], no parity ping-pong, no copy. The kernels read the full state before writing, so the alias is race-free, and (unlike slice_update donation) the aliased buffer's address is preserved under replay. - KV: add kv_inplace_update_at(cache, new_kv, pos) — the existing in-place accessor but reading the write position from the device pos buffer. update_at_pos uses it, so the math output (roped K/V) is written directly into KV[pos] with no slice_update array op. - Remove the now-obsolete machinery and stale/contradictory comments: ping-pong parity, the two-parity mode, the [2]->[0] engine copy loop, gdn_state_overwrite_, graph_decode_parity, pure_step_. Captured decode graph drops from ~1313 nodes (+80 memsets) to ~1033 nodes (0 memsets). Coherent and byte-identical to eager on gfx1151 and gfx1201. TPS: gfx1201 (R9700) pure replay 72 vs eager 63 (+15%, now BEATS eager — this is the launch-bound target); gfx1151 (APU) pure 57 vs eager 62 (compute-bound, still under but up from 52.5). Gated by MLX_DECODE_GRAPH_PURE.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Makes the pure-graph decode path a true build-once / replay: capture the entire single-token forward into one HIP graph once, then relaunch the cached exec every token instead of rebuilding. Output is byte-identical to eager on both gfx1151 (APU) and gfx1201 (R9700).
Requires the matching rocm-support MLX commit (
d3a4d4a9):decode_capture_begin/end_record/replay, the worker host-func capture guard, fine-grained decode arena + replay floor.What changed
step_pure_graph: record state capturescall_fninto one exec (async_eval, no blocking sync mid-capture); replay state relaunches it and samples the overwritten logits buffer. Arena reset-to-floor preserves the recorded buffers while per-token sampling allocates above them.pure_logits_holds the recorded logits array; the relaunch overwrites its buffer andconvert_to_tokenreads it fresh.[2]/[3]with a per-token copy back to[0]/[1](required for correctness: in-place[0]/[1]inside the captured graph does not hold the fixed address the relaunch bakes).qwen35_moe: cache the all-zerosslice_updatestart index ingdn_state_overwrite_(previously a redundant memset node per GDN layer per token).Status / perf
Gated by
MLX_DECODE_GRAPH_PURE(default off); the default decode path is unchanged.Profiling (rocprof, steady-state decode) shows this 35B-A3B decode is ~83% GPU-busy and matmul-dominated (
qmv_wide+gather_qmv_wide≈ 70%), with only ~15-17% idle. On ROCm the captured graph is a single-stream linear chain and the executor inserts a barrier/cache-flush per dependency edge, so replay currently lands ~0.85-0.96x of eager rather than faster — the dispatch gap the graph would fill is offset by per-edge barrier cost. The mechanism is correct and preserved here; the real decode lever for this model is fewer kernels (fusion) + faster matmul.