Build-once HIP-graph decode replay (coherent, gated) #45

Open
Geramy wants to merge 5 commits into main from geramy/graph_replay_fix

Geramy (Member) commented Jun 27, 2026

Summary

Makes the pure-graph decode path a true build-once / replay: capture the entire single-token forward into one HIP graph once, then relaunch the cached exec every token instead of rebuilding. Output is byte-identical to eager on both gfx1151 (APU) and gfx1201 (R9700).

Requires the matching rocm-support MLX commit (d3a4d4a9): decode_capture_begin/end_record/replay, the worker host-func capture guard, fine-grained decode arena + replay floor.

What changed

  • step_pure_graph: record state captures call_fn into one exec (async_eval, no blocking sync mid-capture); replay state relaunches it and samples the overwritten logits buffer. Arena reset-to-floor preserves the recorded buffers while per-token sampling allocates above them.
  • pure_logits_ holds the recorded logits array; the relaunch overwrites its buffer and convert_to_token reads it fresh.
  • GDN recurrent state ping-pongs through scratch [2]/[3] with a per-token copy back to [0]/[1] (required for correctness: updating [0]/[1] in place inside the captured graph does not preserve the fixed buffer address that the relaunch has baked in).
  • qwen35_moe: cache the all-zeros slice_update start index in gdn_state_overwrite_ (previously a redundant memset node per GDN layer per token).
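
The record/replay flow and the arena reset-to-floor described above can be sketched on the CPU. This is a minimal model, not the MLX implementation: `Arena`, `Exec`, and `PureGraphDecoder` are hypothetical names, a `std::function` stands in for the cached HIP graph exec, and the "forward pass" is a toy that puts the highest logit at `(token + 1) % vocab`. It also assumes the recording step launches the exec it has just recorded.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical bump arena (counts floats, not bytes, to keep the sketch short).
class Arena {
 public:
  explicit Arena(std::size_t n_floats) : storage_(n_floats, 0.0f) {}

  float* alloc(std::size_t n) {
    assert(top_ + n <= storage_.size() && "arena exhausted");
    float* p = storage_.data() + top_;
    top_ += n;
    return p;
  }
  void set_floor() { floor_ = top_; }       // everything below is preserved
  void reset_to_floor() { top_ = floor_; }  // drops only per-token scratch
  std::size_t top() const { return top_; }

 private:
  std::vector<float> storage_;
  std::size_t top_ = 0;
  std::size_t floor_ = 0;
};

// Stand-in for a cached graph exec: a closure with buffer addresses baked in.
using Exec = std::function<void()>;

class PureGraphDecoder {
 public:
  PureGraphDecoder(Arena& arena, std::size_t vocab)
      : arena_(arena), vocab_(vocab) {}
  PureGraphDecoder(const PureGraphDecoder&) = delete;  // exec_ holds &input_

  // One decode step: record on the first call, replay on every call.
  std::size_t step(std::size_t token) {
    if (!exec_) record();
    input_ = token;  // the input is updated in place at a fixed address
    exec_();         // the relaunch overwrites the recorded logits buffer

    // Per-token sampling scratch is allocated above the floor...
    float* scratch = arena_.alloc(vocab_);
    std::size_t best = 0;
    for (std::size_t i = 0; i < vocab_; ++i) {
      scratch[i] = logits_[i];
      if (scratch[i] > scratch[best]) best = i;
    }
    arena_.reset_to_floor();  // ...and released without touching logits_
    return best;
  }
  int records() const { return records_; }

 private:
  void record() {
    logits_ = arena_.alloc(vocab_);
    arena_.set_floor();  // recorded buffers now sit below the floor
    const std::size_t* in = &input_;
    float* out = logits_;
    const std::size_t v = vocab_;
    exec_ = [in, out, v] {  // toy forward pass with baked-in addresses
      for (std::size_t i = 0; i < v; ++i)
        out[i] = (i == (*in + 1) % v) ? 1.0f : 0.0f;
    };
    ++records_;
  }

  Arena& arena_;
  std::size_t vocab_;
  std::size_t input_ = 0;
  float* logits_ = nullptr;
  Exec exec_;
  int records_ = 0;
};
```

Calling `step` repeatedly records once and then only relaunches; the arena top returns to the floor after every token, so the logits buffer keeps its address across replays.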

Status / perf

Gated by MLX_DECODE_GRAPH_PURE (default off); the default decode path is unchanged.

Profiling (rocprof, steady-state decode) shows this 35B-A3B decode is ~83% GPU-busy and matmul-dominated (qmv_wide + gather_qmv_wide ≈ 70%), with only ~15-17% idle. On ROCm the captured graph is a single-stream linear chain and the executor inserts a barrier/cache-flush per dependency edge, so replay currently runs at ~0.85-0.96x of eager throughput rather than faster: the dispatch gap the graph would close is offset by the per-edge barrier cost. The mechanism is correct and preserved here; the real decode lever for this model is fewer kernels (fusion) + faster matmul.
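
As a rough back-of-the-envelope reading of those numbers (an assumption-laden estimate, not a measurement): let $T_e$ be the eager per-token time, $T_g$ the replay per-token time, and $b \approx 0.83$ the GPU-busy fraction, and assume kernel time itself is unchanged under replay.

```latex
% Best case: replay removes all idle time and adds no overhead.
T_g \ge b\,T_e
\quad\Longrightarrow\quad
\frac{T_e}{T_g} \le \frac{1}{b} \approx \frac{1}{0.83} \approx 1.20

% Observed: s = T_e / T_g \in [0.85,\ 0.96], so
\frac{T_g}{T_e} = \frac{1}{s} \in [1.04,\ 1.18]

% Overhead relative to the best case, attributed here to per-edge barriers:
\frac{T_g - b\,T_e}{T_e} = \frac{1}{s} - b \in [0.21,\ 0.35]
```

Under these assumptions the graph path could gain at most about 20% on this model, and the added executor cost currently exceeds the idle time it reclaims.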

Geramy added 5 commits June 25, 2026 19:16
Rework the pure-graph decode path to capture the whole forward into one HIP
graph once and relaunch the cached exec every token, instead of rebuilding
per token. Output is byte-identical to eager on gfx1151 and gfx1201.

- step_pure_graph: record state captures call_fn into one exec (async_eval
  so there is no blocking sync mid-capture); replay state relaunches it and
  samples the overwritten logits buffer; arena reset-to-floor preserves the
  recorded buffers while sampling allocates above them.
- pure_logits_ holds the recorded logits array; the relaunch overwrites its
  buffer and convert_to_token reads it fresh.
- GDN recurrent state ping-pongs through scratch [2]/[3] with a per-token
  copy back to [0]/[1] (required: in-place [0]/[1] in the captured graph does
  not hold the fixed address the relaunch bakes).
- qwen35_moe: cache the all-zeros slice_update start index in
  gdn_state_overwrite_ (was a redundant memset node per GDN layer per token).

Gated by MLX_DECODE_GRAPH_PURE (default off). On ROCm the captured graph is a
linear chain and the executor adds a barrier per edge, so replay does not yet
beat eager for this matmul-bound 35B decode.

Add an optional ping-pong-parity replay mode (MLX_PURE_PARITY, default off):
record two captured execs — parity 0 reads state [0]/[1] writes scratch [2]/[3],
parity 1 reads [2]/[3] writes [0]/[1] — and relaunch by token parity, so the
write of one relaunch is the read of the next. This removes the per-token
[2]->[0] state copy the single-graph path needs. Both read/write distinct slots
(no in-place aliasing), so addresses stay fixed for the baked graphs. Coherent
and byte-identical to eager on gfx1151/gfx1201.

- graph_decode: add graph_decode_parity()/set_graph_decode_parity().
- qwen35_moe: GDN read/write slots are parity-driven (parity 0 == prior
  single-graph behavior). Cache the all-zeros slice_update start index in
  gdn_state_overwrite_ (was a redundant memset node per GDN layer per token).
- step_pure_graph: state machine records one exec (single) or two (parity),
  alternating replay; single-graph still copies [2]->[0].

Note: parity is TPS-neutral here (~same as single-graph) — it removes the copies
but the dominant cost is the ROCm graph executor's per-edge barriers, not the
copies. Kept as a cleaner, validated zero-copy option; single-graph stays default.
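
The two-exec parity scheme described in this commit can be illustrated with a small CPU model. This is a sketch under stated assumptions: `StateSlots`, `record`, and `ParityReplayer` are hypothetical names, one float per slot stands in for a whole state tensor, and the recurrence is a toy. The point it shows is that each exec reads and writes distinct slots at fixed addresses, and that alternating by parity makes one relaunch's output the next relaunch's input without a copy.

```cpp
#include <array>
#include <cassert>
#include <functional>

// Four fixed slots, mirroring the [0]/[1] and [2]/[3] pairs in the text.
struct StateSlots {
  std::array<float, 4> slot{};
};

// Stand-in for a captured graph exec; x is the per-token input.
using Exec = std::function<void(float)>;

// Record an exec whose read and write slots are baked in at capture time.
Exec record(StateSlots& s, int read_base, int write_base) {
  assert(read_base != write_base);  // distinct slots: no in-place aliasing
  const float* r = &s.slot[read_base];
  float* w = &s.slot[write_base];
  return [r, w](float x) {
    w[0] = r[0] + x;         // toy recurrence: running sum
    w[1] = r[1] * 0.5f + x;  // toy recurrence: decayed sum
  };
}

class ParityReplayer {
 public:
  explicit ParityReplayer(StateSlots& s)
      : slots_(s), execs_{record(s, 0, 2), record(s, 2, 0)} {}

  // Parity 0 reads [0]/[1] and writes [2]/[3]; parity 1 does the reverse.
  void step(float x) {
    execs_[parity_](x);
    parity_ ^= 1;
  }

  // The newest state lives in the slots written by the last relaunch.
  const float* latest() const { return &slots_.slot[parity_ == 1 ? 2 : 0]; }

 private:
  StateSlots& slots_;
  std::array<Exec, 2> execs_;
  int parity_ = 0;
};
```

After three steps with inputs 1, 2, 3 the newest state sits in slots [2]/[3], and no slot-to-slot copy has been issued.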
…9700)

Replace the GDN ping-pong scratch + the KV slice_update copy with true in-place
writes, so the build-once decode graph carries no per-token state/KV copies.

- GDN recurrent state: the fused kernels (gdn_fused_decode, gdn_conv_step) gain
  an in-place variant whose state output ALIASES the state input (forced alias,
  like kv_inplace_update). The model updates cache slots [0]/[1] in place — no
  scratch [2]/[3], no parity ping-pong, no copy. The kernels read the full state
  before writing, so the alias is race-free, and (unlike slice_update donation)
  the aliased buffer's address is preserved under replay.
- KV: add kv_inplace_update_at(cache, new_kv, pos) — the existing in-place
  accessor but reading the write position from the device pos buffer. update_at_pos
  uses it, so the math output (roped K/V) is written directly into KV[pos] with no
  slice_update array op.
- Remove the now-obsolete machinery and stale/contradictory comments: ping-pong
  parity, the two-parity mode, the [2]->[0] engine copy loop, gdn_state_overwrite_,
  graph_decode_parity, pure_step_.
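
The two in-place ideas in the list above can be modeled on the CPU. This is a hedged sketch, not the kernels from this PR: `state_update_inplace_sketch` and `kv_write_at_sketch` are hypothetical names with simplified signatures, and the update math is a toy. The first function shows why reading the full state before writing makes an aliased input/output safe; the second shows a write position that is read from a buffer at launch time, so the same recorded call can target a different row on each replay.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// In-place state update: `state` is both the input and the output.
// The whole old state is read before any element is written, so no read
// ever observes a partially updated state, and the buffer address is kept.
void state_update_inplace_sketch(std::vector<float>& state, float x) {
  const std::vector<float> old = state;  // read the full state first
  const std::size_t n = old.size();
  for (std::size_t i = 0; i < n; ++i) {
    // Toy mixing step: each output depends on a different input element,
    // which would be wrong if it read from `state` while writing it.
    state[i] = 0.5f * old[(i + 1) % n] + x;
  }
}

// KV write whose row index comes from a buffer instead of a host constant.
// A recorded call bakes in the address of pos_buf, not its current value.
void kv_write_at_sketch(std::vector<float>& cache, std::size_t row_len,
                        const float* new_row, const int* pos_buf) {
  const std::size_t pos = static_cast<std::size_t>(*pos_buf);
  assert((pos + 1) * row_len <= cache.size() && "position out of range");
  for (std::size_t j = 0; j < row_len; ++j) {
    cache[pos * row_len + j] = new_row[j];  // written directly into row pos
  }
}
```

Between replays only the contents of the position buffer and the new row change; the cache and state buffers keep the addresses that were captured.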

Captured decode graph drops from ~1313 nodes (+80 memsets) to ~1033 nodes (0
memsets). Coherent and byte-identical to eager on gfx1151 and gfx1201.

TPS: gfx1201 (R9700) pure replay 72 vs eager 63 (+15%, now BEATS eager — this is
the launch-bound target); gfx1151 (APU) pure 57 vs eager 62 (compute-bound, still
under but up from 52.5). Gated by MLX_DECODE_GRAPH_PURE.