MoE router route buffers are unstable when produced and consumed inside one JIT kernel on NPU #255

@zhaozhaozz

Summary

The DeepSeek V4 MoE path is intended to run as:

hc_pre -> router -> dispatch -> expert -> combine

The same path passes on local a2a3sim, but fails on remote a2a3 NPU in the full MoE kernel. The observed full-path result is:

  • post_ffn: PASS
  • comb_ffn: PASS
  • ffn_out: FAIL
  • typical ffn_out max_abs_diff: 2.281

We tested two dispatch workarounds:

  • Approach A: scalar pl.write route table writes. This is closer to ideal data-dependent scatter, but route table writes are not reliably visible to later inline stages on NPU.
  • Approach B: broadcast + static-mask dispatch. This avoids data-dependent writes by using fixed-offset pl.assemble, but route buffers still become incorrect when the route inputs are small tensors produced by the router inside the same JIT kernel.

The current best description is as follows. The router produces small route tensors (indices/weights, or padded variants) inside one JIT kernel, and dispatch expands them into recv_weights/route_mask. When those route buffers are then consumed internally or exposed as top-level outputs, NPU backend behavior appears sensitive to data versioning, manual scope dependency, lifetime, or liveness.

Background

Ideal MoE dispatch uses the runtime router result to decide the destination expert and slot:

for t in range(T):
    for k in range(TOPK):
        e = indices[t, k]
        p = slot[e]
        recv_x[e, p, :] = x_norm[t, :]
        recv_weights[e, p] = weights[t, k]
        route_mask[e, p] = 1.0
        slot[e] += 1

The current Python DSL cannot directly express this data-dependent tensor write. pl.assemble requires a compile-time literal destination offset, so it cannot directly implement recv_weights[indices[t, k], slot] = weights[t, k].

Verified approaches

| Approach | Idea | Expected | Actual | Current conclusion |
| --- | --- | --- | --- | --- |
| A. scalar pl.write route table writes | Compute a data-dependent flat index and write recv_weights/route_mask with scalar pl.write | Later inline/stage reads see the scalar writes | Works in sim; unreliable on NPU across inline boundaries | Need to define and fix, or reject, cross-inline pl.write visibility |
| B. broadcast + static-mask dispatch | Broadcast recv_x to all experts and build recv_weights/route_mask with fixed-offset pl.assemble | Avoid data-dependent writes; non-routed slots are masked out | Host synthetic routes pass; router-produced route-only snapshot fails; full stage passes only when route outputs are exposed | Current main issue is dependency/lifetime for router-produced small route tensors and fixed-offset route buffers |

Approach A: scalar pl.write route table writes

Approach

The large BF16 recv_x buffer is still broadcast with fixed offsets. The small FP32 route tables are written with scalar pl.write at a data-dependent flat index:

dst = indices[t, k] * RECV_MAX + k * T + t
recv_weights_flat[dst] = weights[t, k]
route_mask_flat[dst] = 1.0
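
As a host-side cross-check, the flat-index formula above can be modeled in NumPy. This is a minimal sketch, not PyPTO code: `scalar_write_reference` and the deterministic `indices`/`weights` test data are hypothetical names introduced here, and the constants are the ones used elsewhere in this issue. Note that only the expert part of `dst` is data-dependent; the slot `k * T + t` is static.

```python
import numpy as np

T, TOPK, N_LOCAL_EXPERTS = 16, 2, 8
RECV_MAX = T * TOPK

def scalar_write_reference(indices, weights):
    """Host model of the Approach A scalar writes (illustrative, not PyPTO)."""
    recv_weights_flat = np.zeros(N_LOCAL_EXPERTS * RECV_MAX, dtype=np.float32)
    route_mask_flat = np.zeros(N_LOCAL_EXPERTS * RECV_MAX, dtype=np.float32)
    for t in range(T):
        for k in range(TOPK):
            # Same flat index as in the issue: expert row, then static slot k*T+t.
            dst = int(indices[t, k]) * RECV_MAX + k * T + t
            recv_weights_flat[dst] = weights[t, k]
            route_mask_flat[dst] = 1.0
    shape = (N_LOCAL_EXPERTS, RECV_MAX)
    return recv_weights_flat.reshape(shape), route_mask_flat.reshape(shape)

# Hypothetical router output: two distinct experts per token, positive weights.
indices = np.array(
    [[(t + 3 * k) % N_LOCAL_EXPERTS for k in range(TOPK)] for t in range(T)],
    dtype=np.int32,
)
weights = np.linspace(0.1, 0.9, T * TOPK, dtype=np.float32).reshape(T, TOPK)

recv_weights, route_mask = scalar_write_reference(indices, weights)
```

Because each (t, k) pair owns a distinct slot, exactly T * TOPK mask entries should be set, which gives a quick sanity check for any NPU snapshot.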

Key code

@pl.jit.inline
def scalar_write_route_dispatch(
    x_norm,         # [T, D]
    indices,        # [T, TOPK]
    weights,        # [T, TOPK]
    recv_x,         # [E, RECV_MAX, D]
    recv_weights,   # [E, RECV_MAX]
    route_mask,     # [E, RECV_MAX]
):
    recv_x_flat = pl.reshape(recv_x, [N_LOCAL_EXPERTS * RECV_MAX, D])
    recv_weights_flat = pl.reshape(recv_weights, [N_LOCAL_EXPERTS * RECV_MAX])
    route_mask_flat = pl.reshape(route_mask, [N_LOCAL_EXPERTS * RECV_MAX])

    # Large tensor path: fixed-offset broadcast.
    for e in pl.static_range(N_LOCAL_EXPERTS):
        for k in pl.static_range(TOPK):
            for d0 in pl.static_range(0, D, 64):
                x_tile = pl.read(x_norm, [0, d0], [T, 64])
                recv_x_flat = pl.assemble(
                    recv_x_flat,
                    x_tile,
                    [e * RECV_MAX + k * T, d0],
                )

    # Small route-table path: scalar writes with a data-dependent flat index.
    for t in pl.static_range(T):
        for k in pl.static_range(TOPK):
            expert_id = pl.read(indices, [t, k])
            weight = pl.read(weights, [t, k])
            dst = expert_id * RECV_MAX + k * T + t
            recv_weights_flat = pl.write(recv_weights_flat, [dst], weight)
            route_mask_flat = pl.write(route_mask_flat, [dst], 1.0)

    recv_weights = pl.reshape(recv_weights_flat, [N_LOCAL_EXPERTS, RECV_MAX])
    route_mask = pl.reshape(route_mask_flat, [N_LOCAL_EXPERTS, RECV_MAX])
    recv_x = pl.reshape(recv_x_flat, [N_LOCAL_EXPERTS, RECV_MAX, D])
    return recv_x, recv_weights, route_mask

The consumer must read recv_weights/route_mask from a later inline/stage:

@pl.jit
def scalar_write_then_consume_case(..., recv_weights_out, route_mask_out, ffn_out):
    x_norm, indices, weights = run_router(...)
    recv_x, recv_weights, route_mask = create_dispatch_buffers()
    recv_x, recv_weights, route_mask = scalar_write_route_dispatch(
        x_norm, indices, weights, recv_x, recv_weights, route_mask,
    )

    recv_weights_out, route_mask_out = copy_route_snapshot(recv_weights, route_mask)
    recv_y, sh = moe_expert(recv_x, recv_weights, x_norm, expert_weights, shared_weights)
    combine(recv_y, route_mask, sh, ffn_out)
    return recv_weights_out, route_mask_out, ffn_out

Expected

Both sim and NPU should observe scalar pl.write updates in later inline/stage reads. Since @pl.jit.inline is still inside the same JIT kernel, users would expect a tensor write in one inline block to be visible to a later inline block.

Actual

Sim observes the writes in execution order. On NPU, later reads are unreliable and may observe old or unwritten values.

Analysis

This should be treated as a compiler/codegen bug or an undefined semantic gap, not as a normal design constraint. If scalar pl.write is not guaranteed to be visible across inline boundaries, it should be documented and diagnosed at compile time instead of silently producing incorrect NPU results.

Approach B: broadcast + static-mask dispatch

Approach

Approach B is the current main path. It avoids scalar pl.write by making all destination offsets compile-time constants:

  • recv_x[e, k*T+t, :] is a fixed-offset broadcast of x_norm[t, :].
  • recv_weights[e, k*T+t] and route_mask[e, k*T+t] are generated with an arithmetic equality mask from indices[t, k] == e.
  • Non-routed slots end up with recv_weights=0 and route_mask=0, so they contribute nothing downstream.

The issue is that when indices/weights are small tensors produced by the router inside the same JIT kernel, NPU route buffer snapshots become incorrect, even though a host-synthetic route control case passes.
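
The arithmetic equality mask described above can be checked on the host against the loop-style golden. The sketch below is a NumPy model under the constants used in this issue; `static_mask_reference`, `golden_route_tables`, and the test data are hypothetical names for illustration and are not PyPTO APIs.

```python
import numpy as np

T, TOPK, N_LOCAL_EXPERTS = 16, 2, 8
RECV_MAX = T * TOPK

def static_mask_reference(indices, weights):
    """Host model of broadcast + static-mask route-table generation."""
    recv_weights = np.zeros((N_LOCAL_EXPERTS, RECV_MAX), dtype=np.float32)
    route_mask = np.zeros((N_LOCAL_EXPERTS, RECV_MAX), dtype=np.float32)
    idx = indices.astype(np.float32)
    for e in range(N_LOCAL_EXPERTS):
        for k in range(TOPK):
            # mask = 1 - min(|idx - e|, 1): 1.0 where idx == e, otherwise 0.0.
            mask = 1.0 - np.minimum(np.abs(idx[:, k] - np.float32(e)), 1.0)
            recv_weights[e, k * T:(k + 1) * T] = weights[:, k] * mask
            route_mask[e, k * T:(k + 1) * T] = mask
    return recv_weights, route_mask

def golden_route_tables(indices, weights):
    """Loop-style golden: write each (t, k) into its expert row at slot k*T+t."""
    expected_weights = np.zeros((N_LOCAL_EXPERTS, RECV_MAX), dtype=np.float32)
    expected_mask = np.zeros((N_LOCAL_EXPERTS, RECV_MAX), dtype=np.float32)
    for t in range(T):
        for k in range(TOPK):
            e = int(indices[t, k])
            expected_weights[e, k * T + t] = weights[t, k]
            expected_mask[e, k * T + t] = 1.0
    return expected_weights, expected_mask

# Hypothetical router output used only for this check.
indices = np.array(
    [[(t + 3 * k) % N_LOCAL_EXPERTS for k in range(TOPK)] for t in range(T)],
    dtype=np.int32,
)
weights = np.linspace(0.1, 0.9, T * TOPK, dtype=np.float32).reshape(T, TOPK)

got_w, got_m = static_mask_reference(indices, weights)
exp_w, exp_m = golden_route_tables(indices, weights)
```

For integer-valued expert ids the mask is exactly 0.0 or 1.0, so the two tables should match bit for bit; any mismatch on NPU is therefore not explained by rounding in the mask arithmetic.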

Key code: fixed-offset dispatch

T = 16
D = 4096
TOPK = 2
N_LOCAL_EXPERTS = 8
RECV_MAX = T * TOPK

@pl.jit.inline
def broadcast_static_mask_dispatch(
    x_norm,        # [T, D]
    indices,       # [T, TOPK], router output
    weights,       # [T, TOPK], router output
    recv_x_flat,   # [N_LOCAL_EXPERTS * RECV_MAX, D]
    recv_weights,  # [N_LOCAL_EXPERTS, RECV_MAX]
    route_mask,    # [N_LOCAL_EXPERTS, RECV_MAX]
    expert_table,  # [N_LOCAL_EXPERTS, RECV_MAX], host-filled expert id table
):
    for e in pl.parallel(N_LOCAL_EXPERTS):
        for k in pl.static_range(TOPK):
            for d0 in pl.static_range(0, D, 64):
                x_tile = pl.read(x_norm, [0, d0], [T, 64])
                recv_x_flat = pl.assemble(
                    recv_x_flat,
                    x_tile,
                    [e * RECV_MAX + k * T, d0],
                )

            idx = pl.read(indices, [0, k], [T, 1])
            w = pl.read(weights, [0, k], [T, 1])
            expert_id = pl.read(expert_table, [e, k * T], [1, T])
            expert_id = pl.reshape(expert_id, [T, 1])

            # mask = 1 - min(abs(idx - expert_id), 1)
            diff = pl.abs(pl.sub(idx, expert_id))
            clipped = pl.minimum(diff, pl.full((T, 1), value=1.0, dtype=pl.FP32))
            mask = pl.sub(pl.full((T, 1), value=1.0, dtype=pl.FP32), clipped)
            masked_weight = pl.mul(w, mask)

            recv_weights = pl.assemble(recv_weights, masked_weight, [e, k * T])
            route_mask = pl.assemble(route_mask, mask, [e, k * T])

    return recv_x_flat, recv_weights, route_mask

Key code: route-only minimal repro

This removes real expert, combine, and fake readers. It keeps only:

hc_pre -> router -> dispatch -> route snapshot

@pl.jit
def route_snapshot_case(
    x,
    norm_w,
    gate_w,
    gate_b,
    recv_x_flat,
    recv_weights,
    route_mask,
    expert_table,
):
    x_norm = hc_pre(x, norm_w)

    # router_padded returns x_norm, indices_pad, and weights_pad.
    # indices_pad/weights_pad are small route tensors produced inside this JIT.
    x_norm, indices_pad, weights_pad = router_padded(
        x_norm,
        gate_w,
        gate_b,
    )

    recv_x_flat, recv_weights, route_mask = broadcast_static_mask_dispatch(
        x_norm,
        indices_pad,
        weights_pad,
        recv_x_flat,
        recv_weights,
        route_mask,
        expert_table,
    )

    return x_norm, recv_weights, route_mask

Expected CPU golden:

expected_weights = zeros([N_LOCAL_EXPERTS, RECV_MAX], dtype=float32)
expected_mask = zeros([N_LOCAL_EXPERTS, RECV_MAX], dtype=float32)

for t in range(T):
    for k in range(TOPK):
        e = int(indices[t, k])
        slot = k * T + t
        expected_weights[e, slot] = float(weights[t, k])
        expected_mask[e, slot] = 1.0

Expected outputs:

  • x_norm matches CPU golden.
  • recv_weights == expected_weights.
  • route_mask == expected_mask.
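
When comparing an NPU snapshot against the golden, it helps to report the first few differing positions rather than only PASS/FAIL. The helper below is a minimal sketch; `first_mismatches` and its `atol`/`limit` parameters are hypothetical names introduced here, and the arrays are synthetic stand-ins, not real NPU output.

```python
import numpy as np

def first_mismatches(actual, expected, atol=1e-6, limit=5):
    """Return up to `limit` (index, actual, expected) tuples in row-major order."""
    bad = np.argwhere(np.abs(actual - expected) > atol)
    report = []
    for ij in bad[:limit]:
        pos = tuple(int(i) for i in ij)
        report.append((pos, float(actual[pos]), float(expected[pos])))
    return report

# Synthetic example: one routed weight shows up in the wrong expert row.
expected = np.zeros((8, 32), dtype=np.float32)
expected[1, 3] = 0.5
actual = np.zeros((8, 32), dtype=np.float32)
actual[6, 3] = 0.5

report = first_mismatches(actual, expected)
```

Each entry carries the (expert, slot) position, so the slot can be decoded back to (k, t) with `k, t = divmod(slot, T)`.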

Approach B variants tested

| Variant | Change | Expected | Actual |
| --- | --- | --- | --- |
| padded route | Router produces padded indices_pad/weights_pad; dispatch reads padded route tables | PASS | x_norm PASS, recv_weights/route_mask FAIL |
| compact-touch | Dispatch reads compact indices/weights; router side adds explicit read/touch to try to preserve route tensors | Improves if padded small-table transfer is the trigger | FAIL; touch variables may still be reported unused |
| padded-return | Dispatch explicitly returns recv_weights/route_mask instead of relying only on InOut side effects | Improves if missing return dependency is the trigger | FAIL |

Approach B layered tests

The following cases are not separate implementation approaches. They are layered repros for Approach B. Each case varies one of four factors: route source, x_norm exposure, route buffer exposure, or downstream consumer type.

C0: synthetic route input + real internal consumer

This bypasses hc_pre/router. x_norm/indices/weights are host/CPU-provided JIT inputs. Dispatch-produced recv_weights/route_mask are consumed internally by real moe_expert and combine, while route snapshots are copied before and after the consumer for observation.

@pl.jit
def case_synthetic_route_real_consumer(
    x_norm, indices, weights, expert_weights, shared_weights, expert_table,
    recv_weights_pre, route_mask_pre, recv_weights_post, route_mask_post,
    recv_y_out, sh_out, ffn_out,
):
    recv_x, recv_weights, route_mask = create_dispatch_buffers()
    recv_x, recv_weights, route_mask = broadcast_static_mask_dispatch(
        x_norm, indices, weights, recv_x, recv_weights, route_mask, expert_table,
    )

    recv_weights_pre, route_mask_pre = copy_route_snapshot(recv_weights, route_mask)
    moe_expert(recv_x, recv_weights, x_norm, expert_weights, shared_weights, recv_y_out, sh_out)
    recv_weights_post, route_mask_post = copy_route_snapshot(recv_weights, route_mask)
    combine(recv_y_out, route_mask, sh_out, ffn_out)
    return recv_weights_pre, route_mask_pre, recv_weights_post, route_mask_post, recv_y_out, sh_out, ffn_out

Purpose: establish that dispatch, real expert, combine, and internal route buffer consumption are not unconditionally broken when route inputs come from host tensors.

C1: real router + fake consumer + top-level x_norm

The real hc_pre/router produces x_norm/indices_pad/weights_pad. The router writes x_norm directly to a top-level output, and dispatch uses that same output. The downstream consumer is a lightweight route-table reader, not real expert.

@pl.jit
def case_router_fake_consumer_top_level_xnorm(..., x_norm_out, recv_weights_pre, route_mask_pre, fake_out, recv_weights_post, route_mask_post):
    x_mixed = hc_pre(...)
    indices_pad, weights_pad = create_route_tensors()
    router_padded(x_mixed, norm_w, gate_w, gate_bias, x_norm_out, indices_pad, weights_pad)

    recv_x, recv_weights, route_mask = create_dispatch_buffers()
    recv_x, recv_weights, route_mask = broadcast_static_mask_dispatch(
        x_norm_out, indices_pad, weights_pad, recv_x, recv_weights, route_mask, expert_table,
    )

    recv_weights_pre, route_mask_pre = copy_route_snapshot(recv_weights, route_mask)
    fake_weighted_consumer(recv_x, recv_weights, route_mask, fake_out)
    recv_weights_post, route_mask_post = copy_route_snapshot(recv_weights, route_mask)
    return x_norm_out, recv_weights_pre, route_mask_pre, fake_out, recv_weights_post, route_mask_post

The fake consumer only reads recv_x/recv_weights/route_mask and accumulates weighted values:

def fake_weighted_consumer(recv_x, recv_weights, route_mask, fake_out):
    # Pseudocode. Per the dispatch signature, recv_x is the flat
    # [N_LOCAL_EXPERTS * RECV_MAX, D] buffer, while recv_weights/route_mask
    # are [N_LOCAL_EXPERTS, RECV_MAX].
    for d0 in range(0, D, chunk):
        accum = zeros([T, chunk])
        for e in range(N_LOCAL_EXPERTS):
            for k in range(TOPK):
                row0 = e * RECV_MAX + k * T
                x_blk = recv_x[row0 : row0 + T, d0 : d0 + chunk]
                w = recv_weights[e, k * T : (k + 1) * T]
                m = route_mask[e, k * T : (k + 1) * T]
                accum += x_blk * w[:, None] * m[:, None]
        fake_out[:, d0 : d0 + chunk] = accum

Purpose: check whether a lightweight downstream reader is enough to preserve/read route buffers. Actual results show fake_out can PASS while route snapshots already FAIL, so this consumer is not a sufficient correctness check.
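
One possible reason fake_out can pass while the snapshots fail is that recv_x holds the same rows for every expert, so the accumulated sum does not depend on which expert row a weight lands in. The NumPy sketch below illustrates that property only; it is not a diagnosis of the NPU failure. `fake_weighted_consumer_ref`, the small `D`, and the rolled "corrupted" tables are assumptions made for illustration.

```python
import numpy as np

T, D, TOPK, N_LOCAL_EXPERTS = 16, 64, 2, 8  # D shrunk for illustration
RECV_MAX = T * TOPK

def fake_weighted_consumer_ref(recv_x_flat, recv_weights, route_mask):
    """Host model of the fake consumer: sum of x * w * m over (e, k)."""
    out = np.zeros((T, D), dtype=np.float32)
    for e in range(N_LOCAL_EXPERTS):
        for k in range(TOPK):
            rows = slice(e * RECV_MAX + k * T, e * RECV_MAX + (k + 1) * T)
            cols = slice(k * T, (k + 1) * T)
            scale = recv_weights[e, cols] * route_mask[e, cols]
            out += recv_x_flat[rows] * scale[:, None]
    return out

rng = np.random.default_rng(0)
x_norm = rng.standard_normal((T, D)).astype(np.float32)
# Fixed-offset broadcast: row e*RECV_MAX + k*T + t holds x_norm[t].
recv_x_flat = np.tile(x_norm, (N_LOCAL_EXPERTS * TOPK, 1))

indices = np.array(
    [[(t + 3 * k) % N_LOCAL_EXPERTS for k in range(TOPK)] for t in range(T)]
)
weights = np.linspace(0.1, 0.9, T * TOPK, dtype=np.float32).reshape(T, TOPK)

good_w = np.zeros((N_LOCAL_EXPERTS, RECV_MAX), dtype=np.float32)
good_m = np.zeros((N_LOCAL_EXPERTS, RECV_MAX), dtype=np.float32)
for t in range(T):
    for k in range(TOPK):
        good_w[indices[t, k], k * T + t] = weights[t, k]
        good_m[indices[t, k], k * T + t] = 1.0

# Corrupt the tables by shifting every entry to the next expert row.
bad_w = np.roll(good_w, 1, axis=0)
bad_m = np.roll(good_m, 1, axis=0)

out_good = fake_weighted_consumer_ref(recv_x_flat, good_w, good_m)
out_bad = fake_weighted_consumer_ref(recv_x_flat, bad_w, bad_m)
```

In this model the corrupted tables differ from the golden in many positions, yet both outputs equal `x_norm * weights.sum(axis=1, keepdims=True)`. A consumer that applies per-expert weights, such as the real expert, does not have this invariance.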

C2: real router + fake consumer + internal x_norm

Same as C1, except router writes x_norm to an internal tensor. x_norm_out is only a copied observation output; dispatch uses the internal x_norm.

@pl.jit
def case_router_fake_consumer_internal_xnorm(..., x_norm_out, recv_weights_pre, route_mask_pre, fake_out, recv_weights_post, route_mask_post):
    x_mixed = hc_pre(...)
    x_norm = pl.create_tensor([T, D], dtype=pl.BF16)
    indices_pad, weights_pad = create_route_tensors()
    router_padded(x_mixed, norm_w, gate_w, gate_bias, x_norm, indices_pad, weights_pad)
    copy_x_norm(x_norm, x_norm_out)

    recv_x, recv_weights, route_mask = create_dispatch_buffers()
    recv_x, recv_weights, route_mask = broadcast_static_mask_dispatch(
        x_norm, indices_pad, weights_pad, recv_x, recv_weights, route_mask, expert_table,
    )

    recv_weights_pre, route_mask_pre = copy_route_snapshot(recv_weights, route_mask)
    fake_weighted_consumer(recv_x, recv_weights, route_mask, fake_out)
    recv_weights_post, route_mask_post = copy_route_snapshot(recv_weights, route_mask)
    return x_norm_out, recv_weights_pre, route_mask_pre, fake_out, recv_weights_post, route_mask_post

Purpose: isolate whether top-level vs internal x_norm affects route table propagation.

C3: route-only padded

Keeps real hc_pre/router and padded route tables, but removes fake consumer, real expert, and combine. Dispatch immediately snapshots recv_weights/route_mask.

@pl.jit
def case_route_only_padded(..., x_norm_out, recv_weights_out, route_mask_out):
    x_mixed = hc_pre(...)
    x_norm = pl.create_tensor([T, D], dtype=pl.BF16)
    indices_pad, weights_pad = create_padded_route_tensors()
    router_padded(x_mixed, norm_w, gate_w, gate_bias, x_norm, indices_pad, weights_pad)
    copy_x_norm(x_norm, x_norm_out)

    recv_x, recv_weights, route_mask = create_dispatch_buffers()
    recv_x, recv_weights, route_mask = broadcast_static_mask_dispatch(
        x_norm, indices_pad, weights_pad, recv_x, recv_weights, route_mask, expert_table,
    )
    recv_weights_out, route_mask_out = copy_route_snapshot(recv_weights, route_mask)
    return x_norm_out, recv_weights_out, route_mask_out

Purpose: minimal failure path. No real expert, combine, or fake reader is required for route snapshots to become wrong.

C4: route-only compact-touch

Same as C3, except dispatch reads compact indices/weights instead of padded indices_pad/weights_pad. The router side also reads/touches compact route tensors to try to preserve them.

@pl.jit
def case_route_only_compact_touch(...):
    x_mixed = hc_pre(...)
    x_norm = pl.create_tensor([T, D], dtype=pl.BF16)
    indices = pl.create_tensor([T, TOPK], dtype=pl.INT32)
    weights = pl.create_tensor([T, TOPK], dtype=pl.FP32)
    indices_pad, weights_pad = create_padded_route_tensors()
    router_padded(x_mixed, norm_w, gate_w, gate_bias, x_norm, indices, weights, indices_pad, weights_pad)
    copy_x_norm(x_norm, x_norm_out)

    recv_x, recv_weights, route_mask = create_dispatch_buffers()
    recv_x, recv_weights, route_mask = broadcast_static_mask_dispatch(
        x_norm, indices, weights, recv_x, recv_weights, route_mask, expert_table,
    )
    return x_norm_out, copy_route_snapshot(recv_weights, route_mask)

Purpose: rule out the explanation that only padded route tables fail, and check whether touch-to-pin creates a reliable dependency.

C5: route-only padded-return

Same as C3, except dispatch explicitly returns recv_weights/route_mask instead of relying only on InOut side effects.

@pl.jit.inline
def dispatch_return_route(x_norm, indices_pad, weights_pad, recv_x, recv_weights, route_mask, expert_table):
    recv_x, recv_weights, route_mask = broadcast_static_mask_dispatch(
        x_norm, indices_pad, weights_pad, recv_x, recv_weights, route_mask, expert_table,
    )
    return recv_x, recv_weights, route_mask

@pl.jit
def case_route_only_padded_return(...):
    x_norm, indices_pad, weights_pad = run_router_padded(...)
    copy_x_norm(x_norm, x_norm_out)  # as in C3, so the returned x_norm_out is actually written
    recv_x, recv_weights, route_mask = create_dispatch_buffers()
    recv_x, recv_weights, route_mask = dispatch_return_route(
        x_norm, indices_pad, weights_pad, recv_x, recv_weights, route_mask, expert_table,
    )
    return x_norm_out, copy_route_snapshot(recv_weights, route_mask)

Purpose: rule out missing return dependency for route buffers as the only issue.

C6: full stage debug with route outputs exposed

Full chain, but key intermediate values are top-level outputs: x_norm, recv_x, recv_weights, route_mask, recv_y, sh, and ffn_out. recv_weights_out/route_mask_out are not just observations; they are directly consumed by real expert and combine.

@pl.jit
def case_full_stage_outputs(..., x_norm_out, recv_x_out, recv_weights_out, route_mask_out, recv_y_out, sh_out, ffn_out):
    x_mixed = hc_pre(...)
    indices_pad, weights_pad = create_padded_route_tensors()
    router_padded(x_mixed, norm_w, gate_w, gate_bias, x_norm_out, indices_pad, weights_pad)

    recv_x = create_recv_x()
    recv_x, recv_weights_out, route_mask_out = broadcast_static_mask_dispatch(
        x_norm_out, indices_pad, weights_pad, recv_x, recv_weights_out, route_mask_out, expert_table,
    )
    copy_recv_x(recv_x, recv_x_out)

    moe_expert(recv_x, recv_weights_out, x_norm_out, expert_weights, shared_weights, recv_y_out, sh_out)
    combine(recv_y_out, route_mask_out, sh_out, ffn_out)
    return ffn_out, x_norm_out, recv_x_out, recv_y_out, sh_out, recv_weights_out, route_mask_out

Purpose: test whether top-level output exposure changes route buffer liveness/lifetime. This case passes, which suggests the full math can work when route buffers are kept alive as outputs.

C7: full MoE without route outputs

Target shape. x_norm/recv_weights/route_mask/recv_x/recv_y/sh are internal tensors; the kernel returns only final MoE outputs and a few upstream debug outputs. Route buffers are materialized internally before real expert/combine.

@pl.jit
def case_full_moe_internal_route(..., ffn_out):
    x_mixed = hc_pre(...)
    x_norm, indices_pad, weights_pad = run_router_padded(...)

    recv_x, recv_weights, route_mask = create_dispatch_buffers()
    recv_x, recv_weights, route_mask = broadcast_static_mask_dispatch(
        x_norm, indices_pad, weights_pad, recv_x, recv_weights, route_mask, expert_table,
    )

    recv_weights_use, route_mask_use = materialize_route_buffers(recv_weights, route_mask)
    recv_y, sh = moe_expert(recv_x, recv_weights_use, x_norm, expert_weights, shared_weights)
    combine(recv_y, route_mask_use, sh, ffn_out)
    return ffn_out

Purpose: target path. This fails while C6 passes, making route-buffer liveness/top-level-output behavior a key suspect.

Approach B results

| ID | Actual | Key output | Conclusion |
| --- | --- | --- | --- |
| C0 | PASS | recv_weights_pre/post, route_mask_pre/post, recv_y_out/sh_out/ffn_out PASS | Host-synthetic route inputs can be consumed internally by real expert/combine. |
| C1 | FAIL | x_norm_out PASS, fake_out PASS, recv_weights_pre/post, route_mask_pre/post FAIL | Fake consumer is not enough to prove route table correctness. |
| C2 | FAIL | x_norm_out/fake_out PASS, route snapshot FAIL | Internal x_norm path also fails, but is not the only trigger. |
| C3 | FAIL | x_norm_out PASS, recv_weights_out/route_mask_out FAIL | Minimal failure path; real expert/combine are not required. |
| C4 | FAIL | route snapshot FAIL | Compact/flat read and touch-to-pin are not enough. |
| C5 | FAIL | route snapshot FAIL | Explicitly returning route buffers is not enough. |
| C6 | PASS | ffn_out/post_ffn/comb_ffn/x_norm_out/recv_x_out/recv_y_out/sh_out/recv_weights_out/route_mask_out all PASS | Full chain can pass when route buffers are top-level outputs. |
| C7 | FAIL | post_ffn/comb_ffn PASS, ffn_out FAIL, max_abs_diff=2.281 | Target path still fails. |
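
The results can also be encoded as data to check which factors line up with FAIL. The encoding below is one reading of the case descriptions (route source, and whether recv_weights/route_mask are internal tensors or top-level outputs); the field names and the `predicted` rule are hypothetical and only summarize the observations. They do not establish a cause.

```python
# Each case: where the routes come from, where the route buffers live, and the
# observed NPU result reported in this issue.
cases = {
    "C0": {"route_source": "host", "route_buffers": "internal", "result": "PASS"},
    "C1": {"route_source": "router", "route_buffers": "internal", "result": "FAIL"},
    "C2": {"route_source": "router", "route_buffers": "internal", "result": "FAIL"},
    "C3": {"route_source": "router", "route_buffers": "internal", "result": "FAIL"},
    "C4": {"route_source": "router", "route_buffers": "internal", "result": "FAIL"},
    "C5": {"route_source": "router", "route_buffers": "internal", "result": "FAIL"},
    "C6": {"route_source": "router", "route_buffers": "top_level", "result": "PASS"},
    "C7": {"route_source": "router", "route_buffers": "internal", "result": "FAIL"},
}

def predicted(case):
    """Candidate rule: FAIL when router-produced routes feed internal buffers."""
    router_made = case["route_source"] == "router"
    internal = case["route_buffers"] == "internal"
    return "FAIL" if router_made and internal else "PASS"

disagreements = [cid for cid, c in cases.items() if predicted(c) != c["result"]]
```

Under this encoding the rule agrees with all eight cases. A case with host-provided routes and top-level route buffers is not in the matrix, so the rule is untested for that combination.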

Typical route snapshot failure:

  • recv_weights has many zero-expected slots with actual=0.5.
  • route_mask has many zero-expected slots with actual=1.0.
  • Some real routed slots are expected nonzero but read as 0.

This looks like a data versioning/lifetime problem for route tensors or expanded route buffers, not a numerical precision issue.
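
To quantify the three failure patterns listed above, mismatches can be bucketed by kind. This is a minimal host-side sketch; `classify_route_mismatches` and its bucket names are hypothetical, and the arrays are synthetic examples rather than real NPU snapshots.

```python
import numpy as np

def classify_route_mismatches(actual, expected, atol=1e-6):
    """Count mismatches by kind for a route table snapshot."""
    differs = np.abs(actual - expected) > atol
    # Expected zero (non-routed slot) but a nonzero value was read.
    spurious = differs & (expected == 0)
    # Expected a routed value but the slot read back as zero.
    missing = differs & (expected != 0) & (np.abs(actual) <= atol)
    # Both nonzero but unequal.
    wrong_value = differs & ~spurious & ~missing
    return {
        "spurious": int(spurious.sum()),
        "missing": int(missing.sum()),
        "wrong_value": int(wrong_value.sum()),
    }

# Synthetic example: one mask entry moved from expert 2 to expert 3.
expected = np.zeros((8, 32), dtype=np.float32)
expected[2, 5] = 1.0
expected[4, 7] = 1.0
actual = expected.copy()
actual[2, 5] = 0.0
actual[3, 5] = 1.0

counts = classify_route_mismatches(actual, expected)
```

A snapshot dominated by "spurious" and "missing" counts with almost no "wrong_value" entries would be consistent with misplaced or stale data rather than rounding error.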

Broader implementation options

Option 1: Data-dependent scatter

Plan: use indices[t, k] to directly choose recv_x/recv_weights/route_mask destinations and implement compact per-expert dispatch.

Expected: correct semantics, no non-routed slots, no mask overhead.

Actual: current PyPTO tensor APIs cannot express this directly. pl.assemble requires literal offsets, and Approach A's scalar pl.write workaround is unreliable across inline boundaries on NPU.

Open issue: provide reliable data-dependent scatter/gather APIs, or define and fix/reject scalar pl.write cross-inline visibility.

Option 2: Packed dispatch

Plan: store all (token, topk) pairs as [T*TOPK, D], with per-row expert id and route weight. This avoids data-dependent writes.

Expected: fixed write offsets, every row is valid, no mask.

Actual: not implemented end-to-end.

Open issue: expert kernels then need to select expert weights by per-row recv_expert_id, which is a data-dependent read source and is not currently expressible with static pl.slice style indexing.
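
The packed layout and the data-dependent read it requires can be sketched in NumPy. This is a host-side illustration only: `packed_dispatch`, `packed_expert`, the row order r = k*T + t, and the single-matrix "expert" are assumptions made here, not the real expert kernel.

```python
import numpy as np

T, D, TOPK, N_LOCAL_EXPERTS = 4, 8, 2, 8  # small sizes for illustration

def packed_dispatch(x_norm, indices, weights):
    """Pack all (token, topk) pairs; row r = k*T + t holds token t."""
    recv_x = np.tile(x_norm, (TOPK, 1))        # [T*TOPK, D], fixed offsets
    recv_expert_id = indices.T.reshape(-1)     # [T*TOPK], per-row expert id
    recv_weight = weights.T.reshape(-1)        # [T*TOPK], per-row route weight
    return recv_x, recv_expert_id, recv_weight

def packed_expert(recv_x, recv_expert_id, expert_w):
    """Stand-in expert: one [D, D] matrix per expert, chosen per row."""
    # This gather is the data-dependent read that static slicing cannot express.
    w_rows = expert_w[recv_expert_id]          # [T*TOPK, D, D]
    return np.einsum("rd,rdh->rh", recv_x, w_rows)

rng = np.random.default_rng(0)
x_norm = rng.standard_normal((T, D)).astype(np.float32)
expert_w = rng.standard_normal((N_LOCAL_EXPERTS, D, D)).astype(np.float32)
indices = np.array(
    [[(t + 3 * k) % N_LOCAL_EXPERTS for k in range(TOPK)] for t in range(T)]
)
weights = np.linspace(0.1, 0.9, T * TOPK, dtype=np.float32).reshape(T, TOPK)

recv_x, recv_expert_id, recv_weight = packed_dispatch(x_norm, indices, weights)
recv_y = packed_expert(recv_x, recv_expert_id, expert_w)
```

The write side uses only static offsets, but `expert_w[recv_expert_id]` moves the data dependence to the read side, which is the open issue named above.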

Option 3: Broadcast + static-mask dispatch

Plan: current main path. Broadcast recv_x to all experts and generate recv_weights/route_mask with fixed-offset pl.assemble.

Expected: avoids data-dependent writes, while expert weight selection remains static by expert id.

Actual: sim passes, but NPU behavior depends on route source and top-level output exposure. Route-only snapshots fail when routes are produced by the router inside the same JIT.

Open issue: route tensor and route buffer dependency/lifetime across producer/consumer stages.

Option 4: All-experts broadcast

Plan: compute every expert for every token and combine with gate weights.

Expected: avoids dispatch and route-buffer issues.

Actual: works for small expert counts.

Open issue: computationally unacceptable for DeepSeek V4 scale.
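
For reference, the all-experts form and its cost relative to routed execution can be modeled as follows. This is a sketch under assumptions: each expert is a single [D, D] matrix, the top-k experts of a token are distinct, and `all_experts_combine`/`routed_combine` are hypothetical names.

```python
import numpy as np

T, D, TOPK, N_LOCAL_EXPERTS = 4, 8, 2, 8  # small T and D for illustration

def routed_combine(x_norm, indices, weights, expert_w):
    """Reference: evaluate only the TOPK selected experts per token."""
    out = np.zeros((T, D), dtype=np.float32)
    for t in range(T):
        for k in range(TOPK):
            out[t] += weights[t, k] * (x_norm[t] @ expert_w[indices[t, k]])
    return out

def all_experts_combine(x_norm, indices, weights, expert_w):
    """Evaluate every expert for every token, then combine with a dense gate."""
    gate = np.zeros((T, N_LOCAL_EXPERTS), dtype=np.float32)
    for k in range(TOPK):
        # Assumes distinct experts per token, so no gate entry is overwritten.
        gate[np.arange(T), indices[:, k]] = weights[:, k]
    y_all = np.einsum("td,edh->teh", x_norm, expert_w)
    return np.einsum("te,teh->th", gate, y_all)

rng = np.random.default_rng(0)
x_norm = rng.standard_normal((T, D)).astype(np.float32)
expert_w = rng.standard_normal((N_LOCAL_EXPERTS, D, D)).astype(np.float32)
indices = np.array(
    [[(t + 3 * k) % N_LOCAL_EXPERTS for k in range(TOPK)] for t in range(T)]
)
weights = np.linspace(0.1, 0.9, T * TOPK, dtype=np.float32).reshape(T, TOPK)

dense = all_experts_combine(x_norm, indices, weights, expert_w)
routed = routed_combine(x_norm, indices, weights, expert_w)
extra_work = (T * N_LOCAL_EXPERTS) / (T * TOPK)  # expert evaluations, dense vs routed
```

With N_LOCAL_EXPERTS = 8 and TOPK = 2 as used in this issue, the dense form performs 4 times as many expert evaluations, and the ratio grows linearly with the expert count.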

Option 5: mscatter hybrid

Plan: keep fixed-offset broadcast for recv_x, but write recv_weights/route_mask with tile-level mscatter to GM.

Expected: closer to real scatter semantics and may avoid fixed-offset route-buffer lifetime issues.

Actual: no complete runnable implementation yet.

Open issue: no high-level API, index/source tile construction is complex, GM writes must be validated across later inline reads, aliasing must be avoided, dtype constraints apply.

Option 6: Borrow flat read / SSA style from sparse attention block_table

Plan: sparse attention successfully uses host-provided block_table with flat reads and local SSA-style assembly:

ori_block_table_flat = pl.reshape(ori_block_table, [B * ORI_MAX_BLOCKS])
cmp_block_table_flat = pl.reshape(cmp_block_table, [B * CMP_MAX_BLOCKS])

raw_idx = pl.read(cmp_sparse_indices_flat, [raw_idx_pos])
ori_slot = raw_idx // BLOCK_SIZE
ori_block_pos = ori_block_base + ori_slot
ori_blk = pl.read(ori_block_table_flat, [ori_block_pos])
ori_row = ori_blk * BLOCK_SIZE + ori_intra
kv_topk_batch = pl.assemble(kv_topk_batch, ori_kv_flat[ori_row : ori_row + 1, 0 : HEAD_DIM], [kk, 0])

This pattern was tested in MoE via compact route and explicit route-buffer return variants.

Expected: if the issue is only padded small-table transfer or missing InOut dependency, compact route or explicit return should help.

Actual: both variants fail. x_norm still passes, but recv_weights/route_mask snapshots fail. The compact-touch path may still report touch variables as unused.

Why sparse attention works but MoE does not directly inherit it:

  • Sparse attention block_table is a host readonly input; MoE indices/weights are small tensors dynamically produced by router inside the same JIT.
  • Sparse attention reads block_table and immediately assembles local scratch in the same logic; MoE passes small route tensors across router -> dispatch -> expert/combine.
  • Sparse attention does not need to produce mutable recv_weights/route_mask for later stages; MoE does.
  • MoE results change depending on whether route buffers are top-level outputs, which points to liveness or memory reuse interactions not present in the block_table path.

Open issue: flat read / explicit return is not sufficient. The part worth borrowing is the "produce local scratch and consume it immediately" pattern, which suggests the fused router-dispatch option below.

Option 7: Fuse router and dispatch route-table generation

Plan: fuse router top-k and dispatch route-table expansion into the same inline/stage, avoiding indices_pad/weights_pad as cross-inline small tensors.

Expected: shorter route tensor lifetime and fewer producer/consumer boundaries.

Actual: not implemented yet.

Open issue: if this passes, the current router/dispatch boundary is the trigger; if it fails, the issue points more directly at fixed-offset route-table outputs and consumer liveness.
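
The intended semantics of the fused stage can be written as a host-side reference that goes from gate scores straight to the expanded route tables, with no intermediate indices/weights tensors crossing a stage boundary. This is a sketch under assumptions not stated in the issue: `scores` is a [T, E] gate-score matrix, the route weight is the raw selected score (no normalization), and ties are broken by the lowest expert id.

```python
import numpy as np

T, TOPK, N_LOCAL_EXPERTS = 16, 2, 8
RECV_MAX = T * TOPK

def fused_route_tables(scores):
    """Top-k selection and route-table expansion in one step (host model)."""
    recv_weights = np.zeros((N_LOCAL_EXPERTS, RECV_MAX), dtype=np.float32)
    route_mask = np.zeros((N_LOCAL_EXPERTS, RECV_MAX), dtype=np.float32)
    # Stable descending sort; the first TOPK columns are the selected experts.
    order = np.argsort(-scores, axis=1, kind="stable")[:, :TOPK]
    for k in range(TOPK):
        for e in range(N_LOCAL_EXPERTS):
            # Static destination [e, k*T : (k+1)*T]; only the mask is data-dependent.
            mask = (order[:, k] == e).astype(np.float32)
            recv_weights[e, k * T:(k + 1) * T] = scores[:, e] * mask
            route_mask[e, k * T:(k + 1) * T] = mask
    return recv_weights, route_mask

rng = np.random.default_rng(0)
scores = rng.random((T, N_LOCAL_EXPERTS), dtype=np.float32) + 0.1  # positive scores

recv_weights, route_mask = fused_route_tables(scores)
```

A fused kernel on NPU could be compared against this reference in the same way the route-only cases are compared against the CPU golden.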

Current assessment

  • Approach A shows scalar pl.write cross-inline visibility needs a defined semantic contract; current NPU behavior should not silently compute wrong results.
  • Approach B avoids scalar pl.write, but still fails when router-produced route tensors feed fixed-offset route buffers, so the issue is broader than the old pl.write workaround.
  • Host-synthetic route inputs pass, so it is not true that all internal route-buffer consumers are broken.
  • The full stage-debug path passes when route buffers are top-level outputs, while the target path without route outputs fails. Route-buffer liveness, memory reuse, or top-level-output scheduling behavior is likely involved.

Questions for PyPTO/PTOAS/backend owners

  1. Does PyPTO/PTOAS guarantee that an internal tensor written by fixed-offset pl.assemble in one inline/stage is visible with correct dependency and lifetime to later inline consumers?
  2. How are InOut side effects, inline return values, top-level outputs, manual scopes, and memory reuse modeled for small internal tensors such as indices/weights and recv_weights/route_mask?
  3. Why does the host-synthetic route input case pass, while the router-produced route-only snapshot fails?
  4. Why does the full stage-debug path pass when route buffers are top-level outputs, while the target path fails when route buffers are internal?
  5. Should scalar pl.write be visible across @pl.jit.inline boundaries inside the same JIT kernel? If yes, the current NPU behavior needs a fix. If no, the compiler should reject or warn instead of silently producing incorrect results.
  6. Should the DSL expose a reliable data-dependent scatter/gather primitive suitable for MoE dispatch, rather than requiring broadcast+mask workarounds?

Suggested issue attachments

  • Full remote a2a3 NPU logs or CI artifacts.
  • First mismatches for recv_weights/route_mask.
  • PASS/FAIL matrix for the control and failing cases.
