Summary
The DeepSeek V4 MoE path is intended to run as:
hc_pre -> router -> dispatch -> expert -> combine
The same path passes on local a2a3sim, but fails on remote a2a3 NPU in the full MoE kernel. The observed full-path result is:
- post_ffn: PASS
- comb_ffn: PASS
- ffn_out: FAIL (typical ffn_out max_abs_diff: 2.281)
We tested two dispatch workarounds:
- Approach A: scalar pl.write route table writes. This is closer to ideal data-dependent scatter, but route table writes are not reliably visible to later inline stages on NPU.
- Approach B: broadcast + static-mask dispatch. This avoids data-dependent writes by using fixed-offset pl.assemble, but route buffers still become incorrect when the route inputs are small tensors produced by the router inside the same JIT kernel.
The current best description is: small route tensors produced inside one JIT kernel by the router (indices/weights, or padded variants) are expanded by dispatch into recv_weights/route_mask; when those route buffers are then consumed internally or exposed as top-level outputs, NPU backend behavior appears sensitive to data versioning, manual scope dependency, lifetime, or liveness.
Background
Ideal MoE dispatch uses the runtime router result to decide the destination expert and slot:
for t in range(T):
    for k in range(TOPK):
        e = indices[t, k]
        p = slot[e]
        recv_x[e, p, :] = x_norm[t, :]
        recv_weights[e, p] = weights[t, k]
        route_mask[e, p] = 1.0
        slot[e] += 1
The current Python DSL cannot directly express this data-dependent tensor write. pl.assemble requires a compile-time literal destination offset, so it cannot directly implement recv_weights[indices[t, k], slot] = weights[t, k].
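For reference, the ideal loop above can be written as a runnable CPU sketch in plain Python. This is only an illustration of the intended semantics, not PyPTO code; the function name `ideal_dispatch` and the toy shapes are assumptions made for this example.

```python
def ideal_dispatch(x_norm, indices, weights, n_experts, recv_max):
    """CPU reference of the data-dependent scatter (hypothetical helper)."""
    d = len(x_norm[0])
    recv_x = [[[0.0] * d for _ in range(recv_max)] for _ in range(n_experts)]
    recv_weights = [[0.0] * recv_max for _ in range(n_experts)]
    route_mask = [[0.0] * recv_max for _ in range(n_experts)]
    slot = [0] * n_experts  # next free slot per expert

    for t in range(len(indices)):
        for k in range(len(indices[t])):
            e = indices[t][k]          # destination expert chosen at runtime
            p = slot[e]                # destination slot chosen at runtime
            recv_x[e][p] = list(x_norm[t])
            recv_weights[e][p] = weights[t][k]
            route_mask[e][p] = 1.0
            slot[e] += 1
    return recv_x, recv_weights, route_mask, slot


# Toy example: 3 tokens, top-2 routing over 4 experts.
x_norm = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
indices = [[0, 2], [2, 1], [0, 1]]
weights = [[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]]
recv_x, recv_weights, route_mask, slot = ideal_dispatch(
    x_norm, indices, weights, n_experts=4, recv_max=6
)
print(slot)  # number of tokens received by each expert
```

Both the expert `e` and the slot `p` are runtime values here, which is exactly what a compile-time literal offset cannot express.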
Verified approaches
| Approach | Idea | Expected | Actual | Current conclusion |
| --- | --- | --- | --- | --- |
| A. scalar pl.write route table writes | Compute a data-dependent flat index and write recv_weights/route_mask with scalar pl.write | Later inline/stage reads see the scalar writes | Works in sim; unreliable on NPU across inline boundaries | Need to define and fix, or reject, cross-inline pl.write visibility |
| B. broadcast + static-mask dispatch | Broadcast recv_x to all experts and build recv_weights/route_mask with fixed-offset pl.assemble | Avoid data-dependent writes; non-routed slots are masked out | Host synthetic routes pass; router-produced route-only snapshot fails; full stage passes only when route outputs are exposed | Current main issue is dependency/lifetime for router-produced small route tensors and fixed-offset route buffers |
Approach A: scalar pl.write route table writes
Approach
The large BF16 recv_x buffer is still broadcast with fixed offsets. The small FP32 route tables are written with scalar pl.write at a data-dependent flat index:
dst = indices[t, k] * RECV_MAX + k * T + t
recv_weights_flat[dst] = weights[t, k]
route_mask_flat[dst] = 1.0
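A quick sanity check of this flat index, as a plain-Python sketch: it confirms that `dst` maps each `(expert, k, t)` triple to the 2-D position `[expert, k * T + t]` and that no two triples collide. The constants are the ones used in the Approach B listing later in this report; `flat_dst` is a hypothetical helper name.

```python
T = 16
TOPK = 2
N_LOCAL_EXPERTS = 8
RECV_MAX = T * TOPK


def flat_dst(expert_id, k, t):
    """Flat destination index used by the scalar route-table writes."""
    return expert_id * RECV_MAX + k * T + t


seen = set()
for e in range(N_LOCAL_EXPERTS):
    for k in range(TOPK):
        for t in range(T):
            dst = flat_dst(e, k, t)
            # The flat index must decompose back to [e, k * T + t].
            assert divmod(dst, RECV_MAX) == (e, k * T + t)
            seen.add(dst)

# Every (e, k, t) triple gets its own slot in the flattened table.
all_unique = len(seen) == N_LOCAL_EXPERTS * RECV_MAX
print(all_unique)
```

At runtime only one expert per `(t, k)` pair is written, so the remaining slots keep whatever value the buffer was initialized with.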
Key code
@pl.jit.inline
def scalar_write_route_dispatch(
    x_norm,        # [T, D]
    indices,       # [T, TOPK]
    weights,       # [T, TOPK]
    recv_x,        # [E, RECV_MAX, D]
    recv_weights,  # [E, RECV_MAX]
    route_mask,    # [E, RECV_MAX]
):
    recv_x_flat = pl.reshape(recv_x, [N_LOCAL_EXPERTS * RECV_MAX, D])
    recv_weights_flat = pl.reshape(recv_weights, [N_LOCAL_EXPERTS * RECV_MAX])
    route_mask_flat = pl.reshape(route_mask, [N_LOCAL_EXPERTS * RECV_MAX])

    # Large tensor path: fixed-offset broadcast.
    for e in pl.static_range(N_LOCAL_EXPERTS):
        for k in pl.static_range(TOPK):
            for d0 in pl.static_range(0, D, 64):
                x_tile = pl.read(x_norm, [0, d0], [T, 64])
                recv_x_flat = pl.assemble(
                    recv_x_flat,
                    x_tile,
                    [e * RECV_MAX + k * T, d0],
                )

    # Small route-table path: scalar writes with a data-dependent flat index.
    for t in pl.static_range(T):
        for k in pl.static_range(TOPK):
            expert_id = pl.read(indices, [t, k])
            weight = pl.read(weights, [t, k])
            dst = expert_id * RECV_MAX + k * T + t
            recv_weights_flat = pl.write(recv_weights_flat, [dst], weight)
            route_mask_flat = pl.write(route_mask_flat, [dst], 1.0)

    recv_weights = pl.reshape(recv_weights_flat, [N_LOCAL_EXPERTS, RECV_MAX])
    route_mask = pl.reshape(route_mask_flat, [N_LOCAL_EXPERTS, RECV_MAX])
    recv_x = pl.reshape(recv_x_flat, [N_LOCAL_EXPERTS, RECV_MAX, D])
    return recv_x, recv_weights, route_mask
The consumer must read recv_weights/route_mask from a later inline/stage:
@pl.jit
def scalar_write_then_consume_case(..., recv_weights_out, route_mask_out, ffn_out):
    x_norm, indices, weights = run_router(...)
    recv_x, recv_weights, route_mask = create_dispatch_buffers()
    recv_x, recv_weights, route_mask = scalar_write_route_dispatch(
        x_norm, indices, weights, recv_x, recv_weights, route_mask,
    )
    recv_weights_out, route_mask_out = copy_route_snapshot(recv_weights, route_mask)
    recv_y, sh = moe_expert(recv_x, recv_weights, x_norm, expert_weights, shared_weights)
    combine(recv_y, route_mask, sh, ffn_out)
    return recv_weights_out, route_mask_out, ffn_out
Expected
Both sim and NPU should observe scalar pl.write updates in later inline/stage reads. Since @pl.jit.inline is still inside the same JIT kernel, users would expect a tensor write in one inline block to be visible to a later inline block.
Actual
Sim observes the writes in execution order. On NPU, later reads are unreliable and may observe old or unwritten values.
Analysis
This should be treated as a compiler/codegen bug or an undefined semantic gap, not as a normal design constraint. If scalar pl.write is not guaranteed to be visible across inline boundaries, it should be documented and diagnosed at compile time instead of silently producing incorrect NPU results.
Approach B: broadcast + static-mask dispatch
Approach
Approach B is the current main path. It avoids scalar pl.write by making all destination offsets compile-time constants:
- recv_x[e, k*T+t, :] is a fixed-offset broadcast of x_norm[t, :].
- recv_weights[e, k*T+t] and route_mask[e, k*T+t] are generated with an arithmetic equality mask from indices[t, k] == e.
- Non-routed slots are zeroed by recv_weights=0 and route_mask=0.
The issue is that when indices/weights are small tensors produced by the router inside the same JIT kernel, NPU route buffer snapshots become incorrect, even though a host-synthetic route control case passes.
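The arithmetic behind the static mask can be checked on the CPU. The sketch below, in plain Python, implements `mask = 1 - min(abs(idx - expert_id), 1)` and confirms that the fixed-offset expansion produces the same route tables as a data-dependent scatter. The helper names and toy inputs are assumptions for illustration only.

```python
def equality_mask(idx, expert_id):
    """1.0 when idx == expert_id, else 0.0 (exact for integer-valued inputs)."""
    diff = abs(float(idx) - float(expert_id))
    return 1.0 - min(diff, 1.0)


def static_mask_dispatch(indices, weights, n_experts):
    """Build route tables using only fixed (e, k * T + t) destinations."""
    t_count, topk = len(indices), len(indices[0])
    recv_max = t_count * topk
    recv_weights = [[0.0] * recv_max for _ in range(n_experts)]
    route_mask = [[0.0] * recv_max for _ in range(n_experts)]
    for e in range(n_experts):
        for k in range(topk):
            for t in range(t_count):
                m = equality_mask(indices[t][k], e)
                recv_weights[e][k * t_count + t] = weights[t][k] * m
                route_mask[e][k * t_count + t] = m
    return recv_weights, route_mask


def scatter_golden(indices, weights, n_experts):
    """Same tables built with a data-dependent write, for comparison."""
    t_count, topk = len(indices), len(indices[0])
    recv_max = t_count * topk
    recv_weights = [[0.0] * recv_max for _ in range(n_experts)]
    route_mask = [[0.0] * recv_max for _ in range(n_experts)]
    for t in range(t_count):
        for k in range(topk):
            e = indices[t][k]
            recv_weights[e][k * t_count + t] = weights[t][k]
            route_mask[e][k * t_count + t] = 1.0
    return recv_weights, route_mask


indices = [[0, 2], [2, 1], [0, 1]]
weights = [[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]]
masked = static_mask_dispatch(indices, weights, n_experts=4)
golden = scatter_golden(indices, weights, n_experts=4)
print(masked == golden)
```

One property of this formula worth keeping in mind: it only behaves as an equality test when `idx` and `expert_id` hold integer values. A fractional difference yields a fractional mask, for example `equality_mask(3, 3.5)` evaluates to 0.5.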
Key code: fixed-offset dispatch
T = 16
D = 4096
TOPK = 2
N_LOCAL_EXPERTS = 8
RECV_MAX = T * TOPK

@pl.jit.inline
def broadcast_static_mask_dispatch(
    x_norm,        # [T, D]
    indices,       # [T, TOPK], router output
    weights,       # [T, TOPK], router output
    recv_x_flat,   # [N_LOCAL_EXPERTS * RECV_MAX, D]
    recv_weights,  # [N_LOCAL_EXPERTS, RECV_MAX]
    route_mask,    # [N_LOCAL_EXPERTS, RECV_MAX]
    expert_table,  # [N_LOCAL_EXPERTS, RECV_MAX], host-filled expert id table
):
    for e in pl.parallel(N_LOCAL_EXPERTS):
        for k in pl.static_range(TOPK):
            for d0 in pl.static_range(0, D, 64):
                x_tile = pl.read(x_norm, [0, d0], [T, 64])
                recv_x_flat = pl.assemble(
                    recv_x_flat,
                    x_tile,
                    [e * RECV_MAX + k * T, d0],
                )

            idx = pl.read(indices, [0, k], [T, 1])
            w = pl.read(weights, [0, k], [T, 1])
            expert_id = pl.read(expert_table, [e, k * T], [1, T])
            expert_id = pl.reshape(expert_id, [T, 1])

            # mask = 1 - min(abs(idx - expert_id), 1)
            diff = pl.abs(pl.sub(idx, expert_id))
            clipped = pl.minimum(diff, pl.full((T, 1), value=1.0, dtype=pl.FP32))
            mask = pl.sub(pl.full((T, 1), value=1.0, dtype=pl.FP32), clipped)
            masked_weight = pl.mul(w, mask)

            recv_weights = pl.assemble(recv_weights, masked_weight, [e, k * T])
            route_mask = pl.assemble(route_mask, mask, [e, k * T])
    return recv_x_flat, recv_weights, route_mask
Key code: route-only minimal repro
This removes real expert, combine, and fake readers. It keeps only:
hc_pre -> router -> dispatch -> route snapshot
@pl.jit
def route_snapshot_case(
    x,
    norm_w,
    gate_w,
    gate_b,
    recv_x_flat,
    recv_weights,
    route_mask,
    expert_table,
):
    x_norm = hc_pre(x, norm_w)

    # router_padded returns x_norm, indices_pad, and weights_pad.
    # indices_pad/weights_pad are small route tensors produced inside this JIT.
    x_norm, indices_pad, weights_pad = router_padded(
        x_norm,
        gate_w,
        gate_b,
    )

    recv_x_flat, recv_weights, route_mask = broadcast_static_mask_dispatch(
        x_norm,
        indices_pad,
        weights_pad,
        recv_x_flat,
        recv_weights,
        route_mask,
        expert_table,
    )
    return x_norm, recv_weights, route_mask
Expected CPU golden:
expected_weights = zeros([N_LOCAL_EXPERTS, RECV_MAX], dtype=float32)
expected_mask = zeros([N_LOCAL_EXPERTS, RECV_MAX], dtype=float32)
for t in range(T):
    for k in range(TOPK):
        e = int(indices[t, k])
        slot = k * T + t
        expected_weights[e, slot] = float(weights[t, k])
        expected_mask[e, slot] = 1.0
Expected outputs:
- x_norm matches CPU golden.
- recv_weights == expected_weights.
- route_mask == expected_mask.
Approach B variants tested
| Variant | Change | Expected | Actual |
| --- | --- | --- | --- |
| padded route | Router produces padded indices_pad/weights_pad; dispatch reads padded route tables | PASS | x_norm PASS, recv_weights/route_mask FAIL |
| compact-touch | Dispatch reads compact indices/weights; router side adds explicit read/touch to try to preserve route tensors | Improves if padded small-table transfer is the trigger | FAIL; touch variables may still be reported unused |
| padded-return | Dispatch explicitly returns recv_weights/route_mask instead of relying only on InOut side effects | Improves if missing return dependency is the trigger | FAIL |
Approach B layered tests
The following cases are not separate implementation approaches. They are layered repros for Approach B. Each case varies one key factor: route source, x_norm exposure, route buffer exposure, or downstream consumer type.
C0: synthetic route input + real internal consumer
This bypasses hc_pre/router. x_norm/indices/weights are host/CPU-provided JIT inputs. Dispatch-produced recv_weights/route_mask are consumed internally by real moe_expert and combine, while route snapshots are copied before and after the consumer for observation.
@pl.jit
def case_synthetic_route_real_consumer(
    x_norm, indices, weights, expert_weights, shared_weights, expert_table,
    recv_weights_pre, route_mask_pre, recv_weights_post, route_mask_post,
    recv_y_out, sh_out, ffn_out,
):
    recv_x, recv_weights, route_mask = create_dispatch_buffers()
    recv_x, recv_weights, route_mask = broadcast_static_mask_dispatch(
        x_norm, indices, weights, recv_x, recv_weights, route_mask, expert_table,
    )
    recv_weights_pre, route_mask_pre = copy_route_snapshot(recv_weights, route_mask)
    moe_expert(recv_x, recv_weights, x_norm, expert_weights, shared_weights, recv_y_out, sh_out)
    recv_weights_post, route_mask_post = copy_route_snapshot(recv_weights, route_mask)
    combine(recv_y_out, route_mask, sh_out, ffn_out)
    return recv_weights_pre, route_mask_pre, recv_weights_post, route_mask_post, recv_y_out, sh_out, ffn_out
Purpose: establish that dispatch, real expert, combine, and internal route buffer consumption are not unconditionally broken when route inputs come from host tensors.
C1: real router + fake consumer + top-level x_norm
The real hc_pre/router produces x_norm/indices_pad/weights_pad. The router writes x_norm directly to a top-level output, and dispatch uses that same output. The downstream consumer is a lightweight route-table reader, not real expert.
@pl.jit
def case_router_fake_consumer_top_level_xnorm(..., x_norm_out, recv_weights_pre, route_mask_pre, fake_out, recv_weights_post, route_mask_post):
    x_mixed = hc_pre(...)
    indices_pad, weights_pad = create_route_tensors()
    router_padded(x_mixed, norm_w, gate_w, gate_bias, x_norm_out, indices_pad, weights_pad)
    recv_x, recv_weights, route_mask = create_dispatch_buffers()
    recv_x, recv_weights, route_mask = broadcast_static_mask_dispatch(
        x_norm_out, indices_pad, weights_pad, recv_x, recv_weights, route_mask, expert_table,
    )
    recv_weights_pre, route_mask_pre = copy_route_snapshot(recv_weights, route_mask)
    fake_weighted_consumer(recv_x, recv_weights, route_mask, fake_out)
    recv_weights_post, route_mask_post = copy_route_snapshot(recv_weights, route_mask)
    return x_norm_out, recv_weights_pre, route_mask_pre, fake_out, recv_weights_post, route_mask_post
The fake consumer only reads recv_x/recv_weights/route_mask and accumulates weighted values:
def fake_weighted_consumer(recv_x, recv_weights, route_mask, fake_out):
    for d0 in chunks(D):
        accum = zeros([T, chunk])
        for e in range(N_LOCAL_EXPERTS):
            for k in range(TOPK):
                row0 = e * RECV_MAX + k * T
                x_blk = recv_x[row0 : row0 + T, d0 : d0 + chunk]
                w = recv_weights[row0 : row0 + T]
                m = route_mask[row0 : row0 + T]
                accum += x_blk * w[:, None] * m[:, None]
        fake_out[:, d0 : d0 + chunk] = accum
Purpose: check whether a lightweight downstream reader is enough to preserve/read route buffers. Actual results show fake_out can PASS while route snapshots already FAIL, so this consumer is not a sufficient correctness check.
C2: real router + fake consumer + internal x_norm
Same as C1, except router writes x_norm to an internal tensor. x_norm_out is only a copied observation output; dispatch uses the internal x_norm.
@pl.jit
def case_router_fake_consumer_internal_xnorm(..., x_norm_out, recv_weights_pre, route_mask_pre, fake_out, recv_weights_post, route_mask_post):
    x_mixed = hc_pre(...)
    x_norm = pl.create_tensor([T, D], dtype=pl.BF16)
    indices_pad, weights_pad = create_route_tensors()
    router_padded(x_mixed, norm_w, gate_w, gate_bias, x_norm, indices_pad, weights_pad)
    copy_x_norm(x_norm, x_norm_out)
    recv_x, recv_weights, route_mask = create_dispatch_buffers()
    recv_x, recv_weights, route_mask = broadcast_static_mask_dispatch(
        x_norm, indices_pad, weights_pad, recv_x, recv_weights, route_mask, expert_table,
    )
    recv_weights_pre, route_mask_pre = copy_route_snapshot(recv_weights, route_mask)
    fake_weighted_consumer(recv_x, recv_weights, route_mask, fake_out)
    recv_weights_post, route_mask_post = copy_route_snapshot(recv_weights, route_mask)
    return x_norm_out, recv_weights_pre, route_mask_pre, fake_out, recv_weights_post, route_mask_post
Purpose: isolate whether top-level vs internal x_norm affects route table propagation.
C3: route-only padded
Keeps real hc_pre/router and padded route tables, but removes fake consumer, real expert, and combine. Dispatch immediately snapshots recv_weights/route_mask.
@pl.jit
def case_route_only_padded(..., x_norm_out, recv_weights_out, route_mask_out):
    x_mixed = hc_pre(...)
    x_norm = pl.create_tensor([T, D], dtype=pl.BF16)
    indices_pad, weights_pad = create_padded_route_tensors()
    router_padded(x_mixed, norm_w, gate_w, gate_bias, x_norm, indices_pad, weights_pad)
    copy_x_norm(x_norm, x_norm_out)
    recv_x, recv_weights, route_mask = create_dispatch_buffers()
    recv_x, recv_weights, route_mask = broadcast_static_mask_dispatch(
        x_norm, indices_pad, weights_pad, recv_x, recv_weights, route_mask, expert_table,
    )
    recv_weights_out, route_mask_out = copy_route_snapshot(recv_weights, route_mask)
    return x_norm_out, recv_weights_out, route_mask_out
Purpose: minimal failure path. No real expert, combine, or fake reader is required for route snapshots to become wrong.
C4: route-only compact-touch
Same as C3, except dispatch reads compact indices/weights instead of padded indices_pad/weights_pad. The router side also reads/touches compact route tensors to try to preserve them.
@pl.jit
def case_route_only_compact_touch(...):
    x_mixed = hc_pre(...)
    x_norm = pl.create_tensor([T, D], dtype=pl.BF16)
    indices = pl.create_tensor([T, TOPK], dtype=pl.INT32)
    weights = pl.create_tensor([T, TOPK], dtype=pl.FP32)
    indices_pad, weights_pad = create_padded_route_tensors()
    router_padded(x_mixed, norm_w, gate_w, gate_bias, x_norm, indices, weights, indices_pad, weights_pad)
    copy_x_norm(x_norm, x_norm_out)
    recv_x, recv_weights, route_mask = create_dispatch_buffers()
    recv_x, recv_weights, route_mask = broadcast_static_mask_dispatch(
        x_norm, indices, weights, recv_x, recv_weights, route_mask, expert_table,
    )
    return x_norm_out, copy_route_snapshot(recv_weights, route_mask)
Purpose: rule out the explanation that only padded route tables fail, and check whether touch-to-pin creates a reliable dependency.
C5: route-only padded-return
Same as C3, except dispatch explicitly returns recv_weights/route_mask instead of relying only on InOut side effects.
@pl.jit.inline
def dispatch_return_route(x_norm, indices_pad, weights_pad, recv_x, recv_weights, route_mask, expert_table):
    recv_x, recv_weights, route_mask = broadcast_static_mask_dispatch(
        x_norm, indices_pad, weights_pad, recv_x, recv_weights, route_mask, expert_table,
    )
    return recv_x, recv_weights, route_mask

@pl.jit
def case_route_only_padded_return(...):
    x_norm, indices_pad, weights_pad = run_router_padded(...)
    recv_x, recv_weights, route_mask = create_dispatch_buffers()
    recv_x, recv_weights, route_mask = dispatch_return_route(
        x_norm, indices_pad, weights_pad, recv_x, recv_weights, route_mask, expert_table,
    )
    return x_norm_out, copy_route_snapshot(recv_weights, route_mask)
Purpose: rule out missing return dependency for route buffers as the only issue.
C6: full stage debug with route outputs exposed
Full chain, but key intermediate values are top-level outputs: x_norm, recv_x, recv_weights, route_mask, recv_y, sh, and ffn_out. recv_weights_out/route_mask_out are not just observations; they are directly consumed by real expert and combine.
@pl.jit
def case_full_stage_outputs(..., x_norm_out, recv_x_out, recv_weights_out, route_mask_out, recv_y_out, sh_out, ffn_out):
    x_mixed = hc_pre(...)
    indices_pad, weights_pad = create_padded_route_tensors()
    router_padded(x_mixed, norm_w, gate_w, gate_bias, x_norm_out, indices_pad, weights_pad)
    recv_x = create_recv_x()
    recv_x, recv_weights_out, route_mask_out = broadcast_static_mask_dispatch(
        x_norm_out, indices_pad, weights_pad, recv_x, recv_weights_out, route_mask_out, expert_table,
    )
    copy_recv_x(recv_x, recv_x_out)
    moe_expert(recv_x, recv_weights_out, x_norm_out, expert_weights, shared_weights, recv_y_out, sh_out)
    combine(recv_y_out, route_mask_out, sh_out, ffn_out)
    return ffn_out, x_norm_out, recv_x_out, recv_y_out, sh_out, recv_weights_out, route_mask_out
Purpose: test whether top-level output exposure changes route buffer liveness/lifetime. This case passes, which suggests the full math can work when route buffers are kept alive as outputs.
C7: full MoE without route outputs
Target shape. x_norm/recv_weights/route_mask/recv_x/recv_y/sh are internal tensors; the kernel returns only final MoE outputs and a few upstream debug outputs. Route buffers are materialized internally before real expert/combine.
@pl.jit
def case_full_moe_internal_route(..., ffn_out):
    x_mixed = hc_pre(...)
    x_norm, indices_pad, weights_pad = run_router_padded(...)
    recv_x, recv_weights, route_mask = create_dispatch_buffers()
    recv_x, recv_weights, route_mask = broadcast_static_mask_dispatch(
        x_norm, indices_pad, weights_pad, recv_x, recv_weights, route_mask, expert_table,
    )
    recv_weights_use, route_mask_use = materialize_route_buffers(recv_weights, route_mask)
    recv_y, sh = moe_expert(recv_x, recv_weights_use, x_norm, expert_weights, shared_weights)
    combine(recv_y, route_mask_use, sh, ffn_out)
    return ffn_out
Purpose: target path. This fails while C6 passes, making route-buffer liveness/top-level-output behavior a key suspect.
Approach B results
| ID | Actual | Key output | Conclusion |
| --- | --- | --- | --- |
| C0 | PASS | recv_weights_pre/post, route_mask_pre/post, recv_y_out/sh_out/ffn_out PASS | Host-synthetic route inputs can be consumed internally by real expert/combine. |
| C1 | FAIL | x_norm_out PASS, fake_out PASS, recv_weights_pre/post, route_mask_pre/post FAIL | Fake consumer is not enough to prove route table correctness. |
| C2 | FAIL | x_norm_out/fake_out PASS, route snapshot FAIL | Internal x_norm path also fails, but is not the only trigger. |
| C3 | FAIL | x_norm_out PASS, recv_weights_out/route_mask_out FAIL | Minimal failure path; real expert/combine are not required. |
| C4 | FAIL | route snapshot FAIL | Compact/flat read and touch-to-pin are not enough. |
| C5 | FAIL | route snapshot FAIL | Explicitly returning route buffers is not enough. |
| C6 | PASS | ffn_out/post_ffn/comb_ffn/x_norm_out/recv_x_out/recv_y_out/sh_out/recv_weights_out/route_mask_out all PASS | Full chain can pass when route buffers are top-level outputs. |
| C7 | FAIL | post_ffn/comb_ffn PASS, ffn_out FAIL, max_abs_diff=2.281 | Target path still fails. |
Typical route snapshot failure:
- recv_weights has many zero-expected slots with actual=0.5.
- route_mask has many zero-expected slots with actual=1.0.
- Some real routed slots are expected nonzero but read as 0.
This looks like a data versioning/lifetime problem for route tensors or expanded route buffers, not a numerical precision issue.
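To make such snapshots easier to compare across runs, the mismatches can be bucketed by kind on the host. The following is a minimal plain-Python sketch; `classify_mismatches`, its tolerance, and the toy tables are hypothetical and not part of the existing test harness.

```python
def classify_mismatches(actual, expected, atol=1e-6):
    """Bucket route-table mismatches into spurious, missing, and wrong-value."""
    spurious, missing, wrong = [], [], []
    for e, (row_a, row_x) in enumerate(zip(actual, expected)):
        for p, (a, x) in enumerate(zip(row_a, row_x)):
            if abs(a - x) <= atol:
                continue
            if x == 0.0:
                spurious.append((e, p, a))   # expected zero, read nonzero
            elif a == 0.0:
                missing.append((e, p, x))    # expected routed, read zero
            else:
                wrong.append((e, p, a, x))   # both nonzero but different
    return {"spurious": spurious, "missing": missing, "wrong": wrong}


# Toy 2-expert, 2-slot tables that mimic the reported failure pattern.
expected = [[0.0, 0.7], [0.3, 0.0]]
actual = [[0.5, 0.0], [0.3, 0.0]]
report = classify_mismatches(actual, expected)
print(report)
```

A report dominated by `spurious` and `missing` entries rather than `wrong` entries with small differences would be consistent with the reading that this is not a precision problem.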
Broader implementation options
Option 1: Data-dependent scatter
Plan: use indices[t, k] to directly choose recv_x/recv_weights/route_mask destinations and implement compact per-expert dispatch.
Expected: correct semantics, no non-routed slots, no mask overhead.
Actual: current PyPTO tensor APIs cannot express this directly. pl.assemble requires literal offsets, and Approach A's scalar pl.write workaround is unreliable across inline boundaries on NPU.
Open issue: provide reliable data-dependent scatter/gather APIs, or define and fix/reject scalar pl.write cross-inline visibility.
Option 2: Packed dispatch
Plan: store all (token, topk) pairs as [T*TOPK, D], with per-row expert id and route weight. This avoids data-dependent writes.
Expected: fixed write offsets, every row is valid, no mask.
Actual: not implemented end-to-end.
Open issue: expert kernels then need to select expert weights by per-row recv_expert_id, which is a data-dependent read source and is not currently expressible with static pl.slice style indexing.
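The packed layout and its open issue can be illustrated with a small CPU sketch in plain Python. Everything here is an assumption for illustration: the helper names, the `k * T + t` row order (chosen to match the slot convention used elsewhere in this report), and the toy "expert" that merely scales its input by a per-expert factor.

```python
def packed_dispatch(x_norm, indices, weights):
    """Pack every (token, topk) pair into one row at the fixed offset k * T + t."""
    t_count, topk = len(indices), len(indices[0])
    recv_x, recv_expert_id, recv_weight = [], [], []
    for k in range(topk):
        for t in range(t_count):
            recv_x.append(list(x_norm[t]))
            recv_expert_id.append(indices[t][k])
            recv_weight.append(weights[t][k])
    return recv_x, recv_expert_id, recv_weight


def packed_expert_and_combine(recv_x, recv_expert_id, recv_weight, expert_scale, t_count):
    """Toy expert + combine; note the data-dependent lookup of expert parameters."""
    d = len(recv_x[0])
    out = [[0.0] * d for _ in range(t_count)]
    for row, x in enumerate(recv_x):
        t = row % t_count
        scale = expert_scale[recv_expert_id[row]]  # data-dependent read source
        for j in range(d):
            out[t][j] += recv_weight[row] * scale * x[j]
    return out


x_norm = [[1.0, 2.0], [3.0, 4.0]]
indices = [[0, 1], [1, 0]]
weights = [[0.75, 0.25], [0.5, 0.5]]
recv_x, recv_expert_id, recv_weight = packed_dispatch(x_norm, indices, weights)
out = packed_expert_and_combine(
    recv_x, recv_expert_id, recv_weight, expert_scale=[2.0, 4.0], t_count=2
)
print(recv_expert_id, out)
```

The write side uses only fixed offsets, but the line marked as a data-dependent read is the part that static slicing cannot express: the expert parameters are chosen per row at runtime.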
Option 3: Broadcast + static-mask dispatch
Plan: current main path. Broadcast recv_x to all experts and generate recv_weights/route_mask with fixed-offset pl.assemble.
Expected: avoids data-dependent writes, while expert weight selection remains static by expert id.
Actual: sim passes, but NPU behavior depends on route source and top-level output exposure. Route-only snapshots fail when routes are produced by the router inside the same JIT.
Open issue: route tensor and route buffer dependency/lifetime across producer/consumer stages.
Option 4: All-experts broadcast
Plan: compute every expert for every token and combine with gate weights.
Expected: avoids dispatch and route-buffer issues.
Actual: works for small expert counts.
Open issue: computationally unacceptable for DeepSeek V4 scale.
Option 5: mscatter hybrid
Plan: keep fixed-offset broadcast for recv_x, but write recv_weights/route_mask with tile-level mscatter to GM.
Expected: closer to real scatter semantics and may avoid fixed-offset route-buffer lifetime issues.
Actual: no complete runnable implementation yet.
Open issue: no high-level API, index/source tile construction is complex, GM writes must be validated across later inline reads, aliasing must be avoided, dtype constraints apply.
Option 6: Borrow flat read / SSA style from sparse attention block_table
Plan: sparse attention successfully uses host-provided block_table with flat reads and local SSA-style assembly:
ori_block_table_flat = pl.reshape(ori_block_table, [B * ORI_MAX_BLOCKS])
cmp_block_table_flat = pl.reshape(cmp_block_table, [B * CMP_MAX_BLOCKS])
raw_idx = pl.read(cmp_sparse_indices_flat, [raw_idx_pos])
ori_slot = raw_idx // BLOCK_SIZE
ori_block_pos = ori_block_base + ori_slot
ori_blk = pl.read(ori_block_table_flat, [ori_block_pos])
ori_row = ori_blk * BLOCK_SIZE + ori_intra
kv_topk_batch = pl.assemble(kv_topk_batch, ori_kv_flat[ori_row : ori_row + 1, 0 : HEAD_DIM], [kk, 0])
This pattern was tested in MoE via compact route and explicit route-buffer return variants.
Expected: if the issue is only padded small-table transfer or missing InOut dependency, compact route or explicit return should help.
Actual: both variants fail. x_norm still passes, but recv_weights/route_mask snapshots fail. The compact-touch path may still report touch variables as unused.
Why sparse attention works but MoE does not directly inherit it:
- Sparse attention block_table is a host readonly input; MoE indices/weights are small tensors dynamically produced by router inside the same JIT.
- Sparse attention reads block_table and immediately assembles local scratch in the same logic; MoE passes small route tensors across router -> dispatch -> expert/combine.
- Sparse attention does not need to produce mutable recv_weights/route_mask for later stages; MoE does.
- MoE results change depending on whether route buffers are top-level outputs, which points to liveness or memory reuse interactions not present in the block_table path.
Open issue: flat read / explicit return is not sufficient. The part worth borrowing is the "produce local scratch and consume it immediately" pattern, which suggests the fused router-dispatch option below.
Option 7: Fuse router and dispatch route-table generation
Plan: fuse router top-k and dispatch route-table expansion into the same inline/stage, avoiding indices_pad/weights_pad as cross-inline small tensors.
Expected: shorter route tensor lifetime and fewer producer/consumer boundaries.
Actual: not implemented yet.
Open issue: if this passes, the current router/dispatch boundary is the trigger; if it fails, the issue points more directly at fixed-offset route-table outputs and consumer liveness.
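As a rough CPU illustration of what the fused stage would compute, the plain-Python sketch below selects the top-k experts and expands the route tables inside one function, so no `indices`/`weights` tensors are returned to the caller. The function name, the use of raw scores as route weights, and the tie-breaking by expert order are all assumptions; the actual router math is not specified in this report.

```python
def fused_route_tables(scores, topk):
    """Top-k selection and static-mask route-table expansion in one function."""
    t_count, n_experts = len(scores), len(scores[0])
    recv_max = t_count * topk
    recv_weights = [[0.0] * recv_max for _ in range(n_experts)]
    route_mask = [[0.0] * recv_max for _ in range(n_experts)]
    for t, row in enumerate(scores):
        # Local top-k; these indices never leave this function.
        ranked = sorted(range(n_experts), key=lambda e: row[e], reverse=True)[:topk]
        for k, chosen in enumerate(ranked):
            for e in range(n_experts):
                # Same arithmetic equality mask as the fixed-offset dispatch.
                m = 1.0 - min(abs(float(chosen) - float(e)), 1.0)
                recv_weights[e][k * t_count + t] = row[chosen] * m
                route_mask[e][k * t_count + t] = m
    return recv_weights, route_mask


scores = [[0.1, 0.6, 0.3], [0.5, 0.2, 0.3]]
recv_weights, route_mask = fused_route_tables(scores, topk=2)
print(route_mask)
```

Only the expanded `recv_weights`/`route_mask` tables cross the function boundary here, which is the property the fused option is meant to test.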
Current assessment
- Approach A shows scalar pl.write cross-inline visibility needs a defined semantic contract; current NPU behavior should not silently compute wrong results.
- Approach B avoids scalar pl.write, but still fails when router-produced route tensors feed fixed-offset route buffers, so the issue is broader than the old pl.write workaround.
- Host-synthetic route inputs pass, so it is not true that all internal route-buffer consumers are broken.
- The full stage-debug path passes when route buffers are top-level outputs, while the target path without route outputs fails. Route-buffer liveness, memory reuse, or top-level-output scheduling behavior is likely involved.
Questions for PyPTO/PTOAS/backend owners
- Does PyPTO/PTOAS guarantee that an internal tensor written by fixed-offset pl.assemble in one inline/stage is visible with correct dependency and lifetime to later inline consumers?
- How are InOut side effects, inline return values, top-level outputs, manual scopes, and memory reuse modeled for small internal tensors such as indices/weights and recv_weights/route_mask?
- Why does the host-synthetic route input case pass, while the router-produced route-only snapshot fails?
- Why does the full stage-debug path pass when route buffers are top-level outputs, while the target path fails when route buffers are internal?
- Should scalar pl.write be visible across @pl.jit.inline boundaries inside the same JIT kernel? If yes, the current NPU behavior needs a fix. If no, the compiler should reject or warn instead of silently producing incorrect results.
- Should the DSL expose a reliable data-dependent scatter/gather primitive suitable for MoE dispatch, rather than requiring broadcast+mask workarounds?
Suggested issue attachments
- Full remote a2a3 NPU logs or CI artifacts.
- First mismatches for recv_weights/route_mask.
- PASS/FAIL matrix for the control and failing cases.
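One compact way to attach the matrix is as data, together with a check of the simplest rule that fits it. The plain-Python sketch below encodes the C0 to C7 results as read from the case descriptions in this report; the field names and the `predicted` rule are an illustrative summary, not a confirmed root cause.

```python
# route_source: where indices/weights come from.
# route_buffers_top_level: whether recv_weights/route_mask are themselves
# top-level outputs (snapshots copied from internal buffers do not count).
cases = {
    "C0": {"route_source": "host", "route_buffers_top_level": False, "result": "PASS"},
    "C1": {"route_source": "router", "route_buffers_top_level": False, "result": "FAIL"},
    "C2": {"route_source": "router", "route_buffers_top_level": False, "result": "FAIL"},
    "C3": {"route_source": "router", "route_buffers_top_level": False, "result": "FAIL"},
    "C4": {"route_source": "router", "route_buffers_top_level": False, "result": "FAIL"},
    "C5": {"route_source": "router", "route_buffers_top_level": False, "result": "FAIL"},
    "C6": {"route_source": "router", "route_buffers_top_level": True, "result": "PASS"},
    "C7": {"route_source": "router", "route_buffers_top_level": False, "result": "FAIL"},
}


def predicted(case):
    """Candidate rule: router-produced routes feeding internal route buffers fail."""
    internal = not case["route_buffers_top_level"]
    return "FAIL" if case["route_source"] == "router" and internal else "PASS"


consistent = all(predicted(c) == c["result"] for c in cases.values())
failing = sorted(name for name, c in cases.items() if c["result"] == "FAIL")
print(consistent, failing)
```

The rule matches all eight observed results, but with only one passing router-produced case (C6) it remains a hypothesis that further cases could refute.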
Summary
The DeepSeek V4 MoE path is intended to run as:
The same path passes on local
a2a3sim, but fails on remote a2a3 NPU in the full MoE kernel. The observed full-path result is:post_ffn: PASScomb_ffn: PASSffn_out: FAILffn_out max_abs_diff:2.281We tested two dispatch workarounds:
pl.writeroute table writes. This is closer to ideal data-dependent scatter, but route table writes are not reliably visible to later inline stages on NPU.pl.assemble, but route buffers still become incorrect when the route inputs are small tensors produced by the router inside the same JIT kernel.The current best description is: small route tensors produced inside one JIT kernel by the router (
indices/weights, or padded variants) are expanded by dispatch intorecv_weights/route_mask; when those route buffers are then consumed internally or exposed as top-level outputs, NPU backend behavior appears sensitive to data versioning, manual scope dependency, lifetime, or liveness.Background
Ideal MoE dispatch uses the runtime router result to decide the destination expert and slot:
The current Python DSL cannot directly express this data-dependent tensor write.
pl.assemblerequires a compile-time literal destination offset, so it cannot directly implementrecv_weights[indices[t, k], slot] = weights[t, k].Verified approaches
pl.writeroute table writesrecv_weights/route_maskwith scalarpl.writepl.writevisibilityrecv_xto all experts and buildrecv_weights/route_maskwith fixed-offsetpl.assembleApproach A: scalar
pl.writeroute table writesApproach
The large BF16
recv_xbuffer is still broadcast with fixed offsets. The small FP32 route tables are written with scalarpl.writeat a data-dependent flat index:Key code
The consumer must read
recv_weights/route_maskfrom a later inline/stage:Expected
Both sim and NPU should observe scalar
pl.writeupdates in later inline/stage reads. Since@pl.jit.inlineis still inside the same JIT kernel, users would expect a tensor write in one inline block to be visible to a later inline block.Actual
Sim observes the writes in execution order. NPU later reads are unreliable and may observe old or unwritten values.
Analysis
This should be treated as a compiler/codegen bug or an undefined semantic gap, not as a normal design constraint. If scalar
pl.writeis not guaranteed to be visible across inline boundaries, it should be documented and diagnosed at compile time instead of silently producing incorrect NPU results.Approach B: broadcast + static-mask dispatch
Approach
Approach B is the current main path. It avoids scalar
pl.writeby making all destination offsets compile-time constants:recv_x[e, k*T+t, :]is a fixed-offset broadcast ofx_norm[t, :].recv_weights[e, k*T+t]androute_mask[e, k*T+t]are generated with an arithmetic equality mask fromindices[t, k] == e.recv_weights=0androute_mask=0.The issue is that when
indices/weightsare small tensors produced by the router inside the same JIT kernel, NPU route buffer snapshots become incorrect, even though a host-synthetic route control case passes.Key code: fixed-offset dispatch
Key code: route-only minimal repro
This removes real expert, combine, and fake readers. It keeps only:
Expected CPU golden:
Expected outputs:
x_normmatches CPU golden.recv_weights == expected_weights.route_mask == expected_mask.Approach B variants tested
indices_pad/weights_pad; dispatch reads padded route tablesx_normPASS,recv_weights/route_maskFAILindices/weights; router side adds explicit read/touch to try to preserve route tensorsrecv_weights/route_maskinstead of relying only on InOut side effectsApproach B layered tests
The following cases are not separate implementation approaches. They are layered repros for Approach B. Each case changes one key variable: route source,
x_normexposure, route buffer exposure, and downstream consumer type.C0: synthetic route input + real internal consumer
This bypasses
hc_pre/router.x_norm/indices/weightsare host/CPU-provided JIT inputs. Dispatch-producedrecv_weights/route_maskare consumed internally by realmoe_expertandcombine, while route snapshots are copied before and after the consumer for observation.Purpose: establish that dispatch, real expert, combine, and internal route buffer consumption are not unconditionally broken when route inputs come from host tensors.
C1: real router + fake consumer + top-level x_norm
The real hc_pre/router produces x_norm/indices_pad/weights_pad. The router writes x_norm directly to a top-level output, and dispatch uses that same output. The downstream consumer is a lightweight route-table reader, not real expert.
The fake consumer only reads recv_x/recv_weights/route_mask and accumulates weighted values.
Purpose: check whether a lightweight downstream reader is enough to preserve/read route buffers. Actual results show fake_out can PASS while route snapshots already FAIL, so this consumer is not a sufficient correctness check.
C2: real router + fake consumer + internal x_norm
Same as C1, except router writes x_norm to an internal tensor. x_norm_out is only a copied observation output; dispatch uses the internal x_norm.
Purpose: isolate whether top-level vs internal x_norm affects route table propagation.
C3: route-only padded
Keeps real hc_pre/router and padded route tables, but removes fake consumer, real expert, and combine. Dispatch immediately snapshots recv_weights/route_mask.
Purpose: minimal failure path. No real expert, combine, or fake reader is required for route snapshots to become wrong.
C4: route-only compact-touch
Same as C3, except dispatch reads compact indices/weights instead of padded indices_pad/weights_pad. The router side also reads/touches compact route tensors to try to preserve them.
Purpose: rule out the explanation that only padded route tables fail, and check whether touch-to-pin creates a reliable dependency.
C5: route-only padded-return
Same as C3, except dispatch explicitly returns recv_weights/route_mask instead of relying only on InOut side effects.
Purpose: rule out missing return dependency for route buffers as the only issue.
C6: full stage debug with route outputs exposed
Full chain, but key intermediate values are top-level outputs: x_norm, recv_x, recv_weights, route_mask, recv_y, sh, and ffn_out. recv_weights_out/route_mask_out are not just observations; they are directly consumed by real expert and combine.
Purpose: test whether top-level output exposure changes route buffer liveness/lifetime. This case passes, which suggests the full math can work when route buffers are kept alive as outputs.
C7: full MoE without route outputs
Target shape. x_norm/recv_weights/route_mask/recv_x/recv_y/sh are internal tensors; the kernel returns only final MoE outputs and a few upstream debug outputs. Route buffers are materialized internally before real expert/combine.
Purpose: target path. This fails while C6 passes, making route-buffer liveness/top-level-output behavior a key suspect.
Approach B results
- C0: recv_weights_pre/post, route_mask_pre/post, recv_y_out/sh_out/ffn_out PASS
- C1: x_norm_out PASS, fake_out PASS, recv_weights_pre/post, route_mask_pre/post FAIL
- C2: x_norm_out/fake_out PASS, route snapshot FAIL. This x_norm path also fails, but is not the only trigger.
- Route-only (C3 to C5): x_norm_out PASS, recv_weights_out/route_mask_out FAIL
- C6: ffn_out/post_ffn/comb_ffn/x_norm_out/recv_x_out/recv_y_out/sh_out/recv_weights_out/route_mask_out all PASS
- C7: post_ffn/comb_ffn PASS, ffn_out FAIL, max_abs_diff=2.281
Typical route snapshot failure:
- recv_weights has many zero-expected slots with actual=0.5.
- route_mask has many zero-expected slots with actual=1.0.
This looks like a data versioning/lifetime problem for route tensors or expanded route buffers, not a numerical precision issue.
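The failure signature above (slots that should be zero but hold 0.5 or 1.0) can be counted mechanically when comparing a snapshot against the CPU golden. The following is a minimal host-side NumPy sketch; `stale_slot_report` and the arrays in the usage lines are hypothetical illustrations, not part of the repro scripts, and the sample values are synthetic rather than measured.

```python
import numpy as np

def stale_slot_report(expected, actual, atol=1e-6):
    """Summarize slots that should be zero but hold a non-zero value."""
    expected = np.asarray(expected, dtype=np.float32)
    actual = np.asarray(actual, dtype=np.float32)
    zero_expected = np.abs(expected) <= atol
    # A "stale" slot is zero in the golden but non-zero in the device snapshot.
    stale = zero_expected & (np.abs(actual) > atol)
    values, counts = np.unique(actual[stale], return_counts=True)
    return {
        "zero_expected_slots": int(zero_expected.sum()),
        "stale_slots": int(stale.sum()),
        "stale_values": dict(zip(values.tolist(), counts.tolist())),
        "max_abs_diff": float(np.max(np.abs(actual - expected))),
    }

# Synthetic example (assumed shapes and values, for illustration only).
expected = np.array([[0.5, 0.0, 0.0],
                     [0.0, 0.5, 0.0]], dtype=np.float32)
actual = np.array([[0.5, 0.5, 0.0],
                   [0.5, 0.5, 0.0]], dtype=np.float32)
report = stale_slot_report(expected, actual)
print(report)
```

A histogram of stale values that collapses onto one or two constants, rather than a spread of small errors, is consistent with the versioning/lifetime reading rather than a precision issue.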
Broader implementation options
Option 1: Data-dependent scatter
Plan: use indices[t, k] to directly choose recv_x/recv_weights/route_mask destinations and implement compact per-expert dispatch.
Expected: correct semantics, no non-routed slots, no mask overhead.
Actual: current PyPTO tensor APIs cannot express this directly. pl.assemble requires literal offsets, and Approach A's scalar pl.write workaround is unreliable across inline boundaries on NPU.
Open issue: provide reliable data-dependent scatter/gather APIs, or define and fix/reject scalar pl.write cross-inline visibility.
Option 2: Packed dispatch
Plan: store all (token, topk) pairs as [T*TOPK, D], with per-row expert id and route weight. This avoids data-dependent writes.
Expected: fixed write offsets, every row is valid, no mask.
Actual: not implemented end-to-end.
Open issue: expert kernels then need to select expert weights by per-row recv_expert_id, which is a data-dependent read source and is not currently expressible with static pl.slice style indexing.
Option 3: Broadcast + static-mask dispatch
Plan: current main path. Broadcast recv_x to all experts and generate recv_weights/route_mask with fixed-offset pl.assemble.
Expected: avoids data-dependent writes, while expert weight selection remains static by expert id.
Actual: sim passes, but NPU behavior depends on route source and top-level output exposure. Route-only snapshots fail when routes are produced by the router inside the same JIT.
Open issue: route tensor and route buffer dependency/lifetime across producer/consumer stages.
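As a host-side reference for what Option 3 is expected to compute, the sketch below expands compact routes into dense per-expert buffers using only loop-index offsets, then checks that combine matches the ideal data-dependent dispatch. It is plain NumPy, not PyPTO. The [E, T] layout for recv_weights/route_mask, the toy linear experts, and all function names are assumptions made for illustration; the real kernel layout may differ.

```python
import numpy as np

def broadcast_static_mask_dispatch(x_norm, indices, weights, num_experts):
    """Expand compact routes into dense per-expert buffers with fixed offsets."""
    num_tokens, topk = indices.shape
    # Broadcast: every expert receives every token.
    recv_x = np.broadcast_to(x_norm, (num_experts,) + x_norm.shape).copy()
    recv_weights = np.zeros((num_experts, num_tokens), dtype=np.float32)
    route_mask = np.zeros((num_experts, num_tokens), dtype=np.float32)
    for e in range(num_experts):
        for t in range(num_tokens):
            for k in range(topk):
                # Router output is used as a value (compare), never as an address.
                hit = np.float32(indices[t, k] == e)
                recv_weights[e, t] += hit * weights[t, k]
                route_mask[e, t] = max(route_mask[e, t], hit)
    return recv_x, recv_weights, route_mask

def combine_broadcast(recv_x, recv_weights, route_mask, expert_w):
    recv_y = np.einsum("etd,edf->etf", recv_x, expert_w)  # toy linear experts
    return np.einsum("et,etf->tf", recv_weights * route_mask, recv_y)

def ideal_dispatch_reference(x_norm, indices, weights, expert_w):
    """Data-dependent reference: each token visits only its routed experts."""
    out = np.zeros((x_norm.shape[0], expert_w.shape[2]), dtype=np.float32)
    for t in range(indices.shape[0]):
        for k in range(indices.shape[1]):
            out[t] += weights[t, k] * (x_norm[t] @ expert_w[indices[t, k]])
    return out

T, TOPK, E, D = 4, 2, 3, 5
rng = np.random.default_rng(0)
x_norm = rng.standard_normal((T, D)).astype(np.float32)
indices = np.stack([rng.permutation(E)[:TOPK] for _ in range(T)])  # distinct experts per token
weights = rng.random((T, TOPK)).astype(np.float32)
expert_w = rng.standard_normal((E, D, D)).astype(np.float32)

recv_x, recv_weights, route_mask = broadcast_static_mask_dispatch(x_norm, indices, weights, E)
ffn_broadcast = combine_broadcast(recv_x, recv_weights, route_mask, expert_w)
ffn_reference = ideal_dispatch_reference(x_norm, indices, weights, expert_w)
print(float(np.max(np.abs(ffn_broadcast - ffn_reference))))
```

Because every non-routed slot contributes through a zero weight and zero mask, a single stale non-zero slot in either buffer changes ffn_out directly, which is why the route snapshots are the right place to check first.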
Option 4: All-experts broadcast
Plan: compute every expert for every token and combine with gate weights.
Expected: avoids dispatch and route-buffer issues.
Actual: works for small expert counts.
Open issue: computationally unacceptable for DeepSeek V4 scale.
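To make the trade-off concrete, here is a small NumPy sketch of the all-experts form: every expert runs on every token, and a dense gate (zero for non-routed experts) does the selection. The dense gate construction, the toy linear experts, and the function names are assumptions for illustration; the sizes are toy values, not DeepSeek V4 dimensions.

```python
import numpy as np

def dense_gate(indices, weights, num_experts):
    """Build a [T, E] gate by comparison, without data-dependent writes."""
    onehot = indices[:, :, None] == np.arange(num_experts)  # [T, TOPK, E]
    return (onehot * weights[:, :, None]).sum(axis=1).astype(np.float32)

def all_experts_forward(x, gate, expert_w):
    # Every expert processes every token, then the gate zeroes non-routed results.
    y_all = np.einsum("td,edf->tef", x, expert_w)
    return np.einsum("te,tef->tf", gate, y_all)

def routed_reference(x, indices, weights, expert_w):
    out = np.zeros((x.shape[0], expert_w.shape[2]), dtype=np.float32)
    for t in range(indices.shape[0]):
        for k in range(indices.shape[1]):
            out[t] += weights[t, k] * (x[t] @ expert_w[indices[t, k]])
    return out

T, TOPK, E, D = 4, 2, 6, 5
rng = np.random.default_rng(1)
x = rng.standard_normal((T, D)).astype(np.float32)
indices = np.stack([rng.permutation(E)[:TOPK] for _ in range(T)])
weights = (rng.random((T, TOPK)) + 0.1).astype(np.float32)
expert_w = rng.standard_normal((E, D, D)).astype(np.float32)

gate = dense_gate(indices, weights, E)
out_all = all_experts_forward(x, gate, expert_w)
out_ref = routed_reference(x, indices, weights, expert_w)
# Expert compute grows with E instead of TOPK.
cost_ratio = E / TOPK
print(cost_ratio)
```

The expert matmul cost scales with E rather than TOPK, so the overhead factor is E / TOPK, which is the reason this option stops being viable as the expert count grows.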
Option 5: mscatter hybrid
Plan: keep fixed-offset broadcast for recv_x, but write recv_weights/route_mask with tile-level mscatter to GM.
Expected: closer to real scatter semantics and may avoid fixed-offset route-buffer lifetime issues.
Actual: no complete runnable implementation yet.
Open issue: no high-level API, index/source tile construction is complex, GM writes must be validated across later inline reads, aliasing must be avoided, dtype constraints apply.
Option 6: Borrow flat read / SSA style from sparse attention block_table
Plan: sparse attention successfully uses host-provided block_table with flat reads and local SSA-style assembly.
This pattern was tested in MoE via compact route and explicit route-buffer return variants.
Expected: if the issue is only padded small-table transfer or missing InOut dependency, compact route or explicit return should help.
Actual: both variants fail. x_norm still passes, but recv_weights/route_mask snapshots fail. The compact-touch path may still report touch variables as unused.
Why sparse attention works but MoE does not directly inherit it:
- block_table is a host readonly input; MoE indices/weights are small tensors dynamically produced by router inside the same JIT.
- Sparse attention reads block_table and immediately assembles local scratch in the same logic; MoE passes small route tensors across router -> dispatch -> expert/combine.
- [...] recv_weights/route_mask for later stages; MoE does.
- [...] block_table path.
Open issue: flat read / explicit return is not sufficient. The part worth borrowing is the "produce local scratch and consume it immediately" pattern, which suggests the fused router-dispatch option below.
Option 7: Fuse router and dispatch route-table generation
Plan: fuse router top-k and dispatch route-table expansion into the same inline/stage, avoiding indices_pad/weights_pad as cross-inline small tensors.
Expected: shorter route tensor lifetime and fewer producer/consumer boundaries.
Actual: not implemented yet.
Open issue: if this passes, the current router/dispatch boundary is the trigger; if it fails, the issue points more directly at fixed-offset route-table outputs and consumer liveness.
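The shape of the fused stage can be sketched on the host as a single function that takes router scores and returns only the expanded route buffers, so compact indices/weights never cross a function boundary. This is a NumPy illustration under assumptions, not the PyPTO implementation: the router is simplified to plain top-k over positive scores with sum-normalized weights, the [E, T] layout is assumed, and `fused_route_tables` is a hypothetical name.

```python
import numpy as np

def fused_route_tables(scores, topk):
    """scores [T, E] -> (recv_weights [E, T], route_mask [E, T]) in one stage."""
    num_experts = scores.shape[1]
    # Top-k selection; indices/weights stay local to this function.
    indices = np.argsort(-scores, axis=1, kind="stable")[:, :topk]
    picked = np.take_along_axis(scores, indices, axis=1)
    weights = picked / picked.sum(axis=1, keepdims=True)  # assumed normalization
    # Expansion by comparison against static expert ids (no data-dependent writes).
    onehot = indices[:, :, None] == np.arange(num_experts)  # [T, TOPK, E]
    recv_weights = (onehot * weights[:, :, None]).sum(axis=1).T
    route_mask = onehot.any(axis=1).T.astype(np.float32)
    return recv_weights.astype(np.float32), route_mask

T, TOPK, E = 4, 2, 6
rng = np.random.default_rng(2)
scores = (rng.random((T, E)) + 0.1).astype(np.float32)  # positive toy scores
recv_weights, route_mask = fused_route_tables(scores, TOPK)
print(recv_weights.shape, route_mask.shape)
```

If a fused kernel with this structure produces correct snapshots on NPU, that would point at the router/dispatch boundary; the same golden check used for the route-only cases applies unchanged.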
Current assessment
- pl.write cross-inline visibility needs a defined semantic contract; current NPU behavior should not silently compute wrong results.
- Approach B avoids pl.write, but still fails when router-produced route tensors feed fixed-offset route buffers, so the issue is broader than the old pl.write workaround.
Questions for PyPTO/PTOAS/backend owners
- [...] pl.assemble in one inline/stage is visible with correct dependency and lifetime to later inline consumers?
- [...] indices/weights and recv_weights/route_mask?
- Should pl.write be visible across @pl.jit.inline boundaries inside the same JIT kernel? If yes, the current NPU behavior needs a fix. If no, the compiler should reject or warn instead of silently producing incorrect results.
Suggested issue attachments
recv_weights/route_mask.