diff --git a/models/deepseek/v4/compressor_ratio4.py b/models/deepseek/v4/compressor_ratio4.py index efa492a..aaa2382 100644 --- a/models/deepseek/v4/compressor_ratio4.py +++ b/models/deepseek/v4/compressor_ratio4.py @@ -22,10 +22,14 @@ COMPRESS_RATIO = 4 HEAD_DIM = 512 -ROTATE = False +ROTATE = True +MAX_SEQ_LEN = 4096 +IDX_KV_LEN = MAX_SEQ_LEN // COMPRESS_RATIO +FP32_NEG_INF = -3.4028234663852886e38 D = 4096 ROPE_HEAD_DIM = 64 +ROPE_CHUCK = 32 NOPE_HEAD_DIM = HEAD_DIM - ROPE_HEAD_DIM OVERLAP = COMPRESS_RATIO == 4 COFF = 1 + int(OVERLAP) @@ -41,331 +45,384 @@ HEAD_DIM_INV = 1.0 / HEAD_DIM K_CHUNK = 512 -OUT_CHUNK = 64 +OUT_CHUNK = 128 K_BLOCKS = D // K_CHUNK OUT_BLOCKS = OUT_DIM // OUT_CHUNK HEAD_CHUNK = 128 HEAD_BLOCKS = HEAD_DIM // HEAD_CHUNK +HEAD_DIM_CHUCK = 128 @pl.jit.inline def compressor( x: pl.Tensor[[B, S, D], pl.BF16], + kv: pl.Tensor[[B, S, HEAD_DIM], pl.FP32], kv_state: pl.Tensor[[B, STATE_LEN, OUT_DIM], pl.FP32], score_state: pl.Tensor[[B, STATE_LEN, OUT_DIM], pl.FP32], - wkv: pl.Tensor[[OUT_DIM, D], pl.BF16], - wgate: pl.Tensor[[OUT_DIM, D], pl.BF16], + wkv: pl.Tensor[[D, OUT_DIM], pl.BF16], + wgate: pl.Tensor[[D, OUT_DIM], pl.BF16], ape: pl.Tensor[[COMPRESS_RATIO, OUT_DIM], pl.FP32], - norm_w: pl.Tensor[[HEAD_DIM], pl.BF16], - cos: pl.Tensor[[1, ROPE_HEAD_DIM // 2], pl.BF16], - sin: pl.Tensor[[1, ROPE_HEAD_DIM // 2], pl.BF16], + norm_w: pl.Tensor[[HEAD_DIM], pl.FP32], + cos: pl.Tensor[[1, ROPE_HEAD_DIM // 2], pl.FP32], + sin: pl.Tensor[[1, ROPE_HEAD_DIM // 2], pl.FP32], + even_select: pl.Tensor[[ROPE_HEAD_DIM, ROPE_HEAD_DIM // 2], pl.BF16], + odd_select: pl.Tensor[[ROPE_HEAD_DIM, ROPE_HEAD_DIM // 2], pl.BF16], hadamard: pl.Tensor[[HEAD_DIM, HEAD_DIM], pl.BF16], + kv_cache: pl.Tensor[[B, IDX_KV_LEN, HEAD_DIM], pl.BF16], start_pos: pl.Scalar[pl.INT32], - out: pl.Tensor[[B, HEAD_DIM], pl.BF16], + rotate: pl.Scalar[pl.BOOL], ): - x_flat = pl.reshape(x, [B, D]) + x_flat = pl.reshape(x, [B * S, D]) + kv_proj = pl.create_tensor([B * S, OUT_DIM], dtype=pl.FP32) + score_proj = pl.create_tensor([B * S, OUT_DIM], dtype=pl.FP32) + ape_row = pl.cast(start_pos % COMPRESS_RATIO, target_type=pl.INDEX) + compress_rem = (start_pos + 1) % COMPRESS_RATIO + score_ape = pl.create_tensor([B * S, OUT_DIM], dtype=pl.FP32) + scatter_slot = COMPRESS_RATIO + ape_row + state_col0 = scatter_slot * OUT_DIM kv_state_flat = pl.reshape(kv_state, [B, STATE_LEN * OUT_DIM]) score_state_flat = pl.reshape(score_state, [B, STATE_LEN * OUT_DIM]) + for o0 in pl.parallel(0, OUT_DIM, OUT_CHUNK): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_score_proj"): + x_tile = x_flat[:, 0 : K_CHUNK] + wkv_tile = wkv[0 : K_CHUNK, o0 : o0 + OUT_CHUNK] + wgate_tile = wgate[0 : K_CHUNK, o0 : o0 + OUT_CHUNK] + kv_acc = pl.matmul(x_tile, wkv_tile, out_dtype=pl.FP32) + score_acc = pl.matmul(x_tile, wgate_tile, out_dtype=pl.FP32) + + for k0 in pl.range(K_CHUNK, D, K_CHUNK): + x_tile = x_flat[:, k0 : k0 + K_CHUNK] + wkv_tile = wkv[k0 : k0 + K_CHUNK, o0 : o0 + OUT_CHUNK] + wgate_tile = wgate[k0 : k0 + K_CHUNK, o0 : o0 + OUT_CHUNK] + kv_acc = pl.matmul_acc(kv_acc, x_tile, wkv_tile) + score_acc = pl.matmul_acc(score_acc, x_tile, wgate_tile) + + kv_proj = pl.assemble(kv_proj, kv_acc, [0, o0]) + score_proj = pl.assemble(score_proj, score_acc, [0, o0]) + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="score_ape"): + score_tile = score_proj[:, o0 : o0 + OUT_CHUNK] + ape_tile = ape[ape_row : ape_row + 1, o0 : o0 + OUT_CHUNK] + ape_base = pl.full([B * S, OUT_CHUNK], dtype=pl.FP32, value=0.0) + score_tile = pl.add(score_tile, 
pl.col_expand(ape_base, ape_tile)) + score_ape = pl.assemble(score_ape, score_tile, [0, o0]) - kv_fp32 = pl.create_tensor([B, OUT_DIM], dtype=pl.FP32) - score_fp32 = pl.create_tensor([B, OUT_DIM], dtype=pl.FP32) - slot_off = SCATTER_SLOT * OUT_DIM - - for ob in pl.parallel(0, OUT_BLOCKS, 1): - oc0 = ob * OUT_CHUNK - # Block 1a (Cube): kv = x @ wkv.T - with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_proj"): - a0 = pl.slice(x_flat, [B, K_CHUNK], [0, 0]) - b0 = pl.slice(wkv, [OUT_CHUNK, K_CHUNK], [oc0, 0]) - kv_acc = pl.matmul(a0, b0, out_dtype=pl.FP32, b_trans=True) - for kb in pl.range(1, K_BLOCKS): - a_i = pl.slice(x_flat, [B, K_CHUNK], [0, kb * K_CHUNK]) - b_i = pl.slice(wkv, [OUT_CHUNK, K_CHUNK], [oc0, kb * K_CHUNK]) - kv_acc = pl.matmul_acc(kv_acc, a_i, b_i, b_trans=True) - kv_fp32 = pl.assemble(kv_fp32, kv_acc, [0, oc0]) - - # Block 1b (Cube): score = x @ wgate.T - with pl.at(level=pl.Level.CORE_GROUP, name_hint="score_proj"): - a0g = pl.slice(x_flat, [B, K_CHUNK], [0, 0]) - b0g = pl.slice(wgate, [OUT_CHUNK, K_CHUNK], [oc0, 0]) - sc_acc = pl.matmul(a0g, b0g, out_dtype=pl.FP32, b_trans=True) - for kb in pl.range(1, K_BLOCKS): - a_ig = pl.slice(x_flat, [B, K_CHUNK], [0, kb * K_CHUNK]) - b_ig = pl.slice(wgate, [OUT_CHUNK, K_CHUNK], [oc0, kb * K_CHUNK]) - sc_acc = pl.matmul_acc(sc_acc, a_ig, b_ig, b_trans=True) - score_fp32 = pl.assemble(score_fp32, sc_acc, [0, oc0]) - - # Block 2 (Vector): score += ape[APE_ROW] - with pl.at(level=pl.Level.CORE_GROUP, name_hint="ape_add"): - sc = pl.slice(score_fp32, [B, OUT_CHUNK], [0, oc0]) - ape_row = pl.slice(ape, [1, OUT_CHUNK], [APE_ROW, oc0]) - ones_b = pl.full([B, OUT_CHUNK], dtype=pl.FP32, value=1.0) - ape_broadcast = pl.col_expand_mul(ones_b, ape_row) - sc = pl.add(sc, ape_broadcast) - score_fp32 = pl.assemble(score_fp32, sc, [0, oc0]) - - # Block 3 (Vector): scatter current kv/score into state with pl.at(level=pl.Level.CORE_GROUP, name_hint="state_scatter"): - kv_chunk = pl.slice(kv_fp32, [B, OUT_CHUNK], [0, oc0]) - kv_state_flat = pl.assemble(kv_state_flat, kv_chunk, [0, slot_off + oc0]) - sc_chunk = pl.slice(score_fp32, [B, OUT_CHUNK], [0, oc0]) - score_state_flat = pl.assemble(score_state_flat, sc_chunk, [0, slot_off + oc0]) - - # Reshape state to per-state-row 2D views - kv_state_per_row = pl.reshape(kv_state_flat, [B * STATE_LEN, OUT_DIM]) - score_state_per_row = pl.reshape(score_state_flat, [B * STATE_LEN, OUT_DIM]) - - # Block 5+6 (Vector): softmax+pool over STATE_LEN slots via tree reduction. 
- pooled = pl.create_tensor([B, HEAD_DIM], dtype=pl.FP32) - for b_idx in pl.parallel(0, B, 1): - row_b = b_idx * STATE_LEN - with pl.at(level=pl.Level.CORE_GROUP, name_hint="softmax_pool"): - for hb in pl.range(HEAD_BLOCKS): + kv_tile = kv_proj[:, o0 : o0 + OUT_CHUNK] + score_tile = score_ape[:, o0 : o0 + OUT_CHUNK] + kv_state_flat = pl.assemble(kv_state_flat, kv_tile, [0, state_col0 + o0]) + score_state_flat = pl.assemble(score_state_flat, score_tile, [0, state_col0 + o0]) + + pooled_kv = pl.create_tensor([B * S, HEAD_DIM], dtype=pl.FP32) + normed_kv = pl.create_tensor([B * S, HEAD_DIM], dtype=pl.BF16) + kv_flat = pl.reshape(kv, [B * S, HEAD_DIM]) + + if compress_rem == 0: + for hb in pl.parallel(0, HEAD_BLOCKS, 1): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="softmax_pool"): h0 = hb * HEAD_CHUNK - - s0 = pl.slice(score_state_per_row, [1, HEAD_CHUNK], [row_b + 0, h0]) - s1 = pl.slice(score_state_per_row, [1, HEAD_CHUNK], [row_b + 1, h0]) - s2 = pl.slice(score_state_per_row, [1, HEAD_CHUNK], [row_b + 2, h0]) - s3 = pl.slice(score_state_per_row, [1, HEAD_CHUNK], [row_b + 3, h0]) - s4 = pl.slice(score_state_per_row, [1, HEAD_CHUNK], [row_b + 4, HEAD_DIM + h0]) - s5 = pl.slice(score_state_per_row, [1, HEAD_CHUNK], [row_b + 5, HEAD_DIM + h0]) - s6 = pl.slice(score_state_per_row, [1, HEAD_CHUNK], [row_b + 6, HEAD_DIM + h0]) - s7 = pl.slice(score_state_per_row, [1, HEAD_CHUNK], [row_b + 7, HEAD_DIM + h0]) - - # Max via tree of pl.maximum. - m01 = pl.maximum(s0, s1) - m23 = pl.maximum(s2, s3) - m45 = pl.maximum(s4, s5) - m67 = pl.maximum(s6, s7) - m0123 = pl.maximum(m01, m23) - m4567 = pl.maximum(m45, m67) - s_max = pl.maximum(m0123, m4567) - - # Exp(s - max) tree. - e0 = pl.exp(pl.sub(s0, s_max)) - e1 = pl.exp(pl.sub(s1, s_max)) - e2 = pl.exp(pl.sub(s2, s_max)) - e3 = pl.exp(pl.sub(s3, s_max)) - e4 = pl.exp(pl.sub(s4, s_max)) - e5 = pl.exp(pl.sub(s5, s_max)) - e6 = pl.exp(pl.sub(s6, s_max)) - e7 = pl.exp(pl.sub(s7, s_max)) - - es01 = pl.add(e0, e1) - es23 = pl.add(e2, e3) - es45 = pl.add(e4, e5) - es67 = pl.add(e6, e7) - es0123 = pl.add(es01, es23) - es4567 = pl.add(es45, es67) - e_sum = pl.add(es0123, es4567) - - # Weighted kv tree. 
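# The replacement later in this hunk folds the separate max / exp / sum trees into one
# streaming (m, l, o) update per chunk (flash-attention-style online softmax). A minimal
# torch sketch of that recurrence; it is mathematically equivalent to the tree reduction
# removed here:
import torch

def online_pool(scores, values):
    # scores / values: sequences of per-slot tiles with matching shapes, visited once each.
    m = scores[0]
    l = torch.ones_like(m)          # exp(m - m) for the first slot
    o = values[0].clone()
    for s, v in zip(scores[1:], values[1:]):
        m_new = torch.maximum(m, s)
        alpha = torch.exp(m - m_new)   # rescales the running accumulators
        beta = torch.exp(s - m_new)    # weight of the incoming slot
        l = alpha * l + beta
        o = alpha * o + beta * v
        m = m_new
    return o / l                    # == (softmax over slots * values).sum over slots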
- kv_s0 = pl.slice(kv_state_per_row, [1, HEAD_CHUNK], [row_b + 0, h0]) - kv_s1 = pl.slice(kv_state_per_row, [1, HEAD_CHUNK], [row_b + 1, h0]) - kv_s2 = pl.slice(kv_state_per_row, [1, HEAD_CHUNK], [row_b + 2, h0]) - kv_s3 = pl.slice(kv_state_per_row, [1, HEAD_CHUNK], [row_b + 3, h0]) - kv_s4 = pl.slice(kv_state_per_row, [1, HEAD_CHUNK], [row_b + 4, HEAD_DIM + h0]) - kv_s5 = pl.slice(kv_state_per_row, [1, HEAD_CHUNK], [row_b + 5, HEAD_DIM + h0]) - kv_s6 = pl.slice(kv_state_per_row, [1, HEAD_CHUNK], [row_b + 6, HEAD_DIM + h0]) - kv_s7 = pl.slice(kv_state_per_row, [1, HEAD_CHUNK], [row_b + 7, HEAD_DIM + h0]) - - w0 = pl.mul(e0, kv_s0) - w1 = pl.mul(e1, kv_s1) - w2 = pl.mul(e2, kv_s2) - w3 = pl.mul(e3, kv_s3) - w4 = pl.mul(e4, kv_s4) - w5 = pl.mul(e5, kv_s5) - w6 = pl.mul(e6, kv_s6) - w7 = pl.mul(e7, kv_s7) - - ws01 = pl.add(w0, w1) - ws23 = pl.add(w2, w3) - ws45 = pl.add(w4, w5) - ws67 = pl.add(w6, w7) - ws0123 = pl.add(ws01, ws23) - ws4567 = pl.add(ws45, ws67) - pooled_acc = pl.add(ws0123, ws4567) - - pooled_chunk = pl.div(pooled_acc, e_sum) - pooled = pl.assemble(pooled, pooled_chunk, [b_idx, h0]) - - # Block 7 (Vector): shift state down -- state[:, :ratio] = state[:, ratio:] - for b_sh in pl.parallel(0, B, 1): - row_sh = b_sh * STATE_LEN - with pl.at(level=pl.Level.CORE_GROUP, name_hint="state_shift"): - kv_src = pl.slice(kv_state_per_row, [COMPRESS_RATIO, OUT_DIM], [row_sh + COMPRESS_RATIO, 0]) - kv_state_per_row = pl.assemble(kv_state_per_row, kv_src, [row_sh, 0]) - sc_src = pl.slice(score_state_per_row, [COMPRESS_RATIO, OUT_DIM], [row_sh + COMPRESS_RATIO, 0]) - score_state_per_row = pl.assemble(score_state_per_row, sc_src, [row_sh, 0]) - - # Reshape state back to 3D - kv_state = pl.reshape(kv_state_per_row, [B, STATE_LEN, OUT_DIM]) - score_state = pl.reshape(score_state_per_row, [B, STATE_LEN, OUT_DIM]) - - # Block 8 (Vector): RMSNorm pooled with norm_w over HEAD_DIM. - normed_pooled = pl.create_tensor([B, HEAD_DIM], dtype=pl.FP32) - norm_w_2d = pl.reshape(norm_w, [1, HEAD_DIM]) - with pl.at(level=pl.Level.CORE_GROUP, name_hint="rmsnorm"): - partial_sq = pl.full([1, B], dtype=pl.FP32, value=0.0) - for hb in pl.range(HEAD_BLOCKS): - h0 = hb * HEAD_CHUNK - pc = pl.slice(pooled, [B, HEAD_CHUNK], [0, h0]) - partial_sq = pl.add( - partial_sq, - pl.reshape(pl.row_sum(pl.mul(pc, pc)), [1, B]), - ) - inv_rms = pl.reshape( - pl.recip(pl.sqrt(pl.add(pl.mul(partial_sq, HEAD_DIM_INV), EPS))), - [B, 1], - ) - for hb in pl.range(HEAD_BLOCKS): - h0 = hb * HEAD_CHUNK - nc = pl.slice(pooled, [B, HEAD_CHUNK], [0, h0]) - nw_chunk = pl.cast( - pl.slice(norm_w_2d, [1, HEAD_CHUNK], [0, h0]), - target_type=pl.FP32, - ) - normed = pl.col_expand_mul(pl.row_expand_mul(nc, inv_rms), nw_chunk) - normed_pooled = pl.assemble(normed_pooled, normed, [0, h0]) - - # Block 11a (Vector): cast non-rope range to BF16 and store to out. - with pl.at(level=pl.Level.CORE_GROUP, name_hint="store_nope"): - nope_chunk = pl.slice(normed_pooled, [B, NOPE_HEAD_DIM], [0, 0]) - out = pl.assemble(out, pl.cast(nope_chunk, target_type=pl.BF16), [0, 0]) - - # Block 9 + 11b (Vector): half-vector RoPE on the last ROPE_HEAD_DIM cols, then store. 
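# The block removed below applied RoPE to the first/second halves of the rope range
# (x_lo / x_hi), while the updated golden rotates interleaved (even, odd) pairs. A short
# torch sketch of the pair form used by the new golden; layout is the only difference,
# the rotation itself is identical:
import torch

def rope_pairs(x, cos_v, sin_v):
    # x: [..., ROPE_HEAD_DIM]; cos_v / sin_v: [ROPE_HEAD_DIM // 2]
    pairs = x.unflatten(-1, (-1, 2))
    x0, x1 = pairs[..., 0], pairs[..., 1]
    y0 = x0 * cos_v - x1 * sin_v
    y1 = x0 * sin_v + x1 * cos_v
    return torch.stack([y0, y1], dim=-1).flatten(-2)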
- HALF_RD = ROPE_HEAD_DIM // 2 - with pl.at(level=pl.Level.CORE_GROUP, name_hint="rope_store"): - x_lo = pl.slice(normed_pooled, [B, HALF_RD], [0, NOPE_HEAD_DIM]) - x_hi = pl.slice(normed_pooled, [B, HALF_RD], [0, NOPE_HEAD_DIM + HALF_RD]) - cos_fp32 = pl.cast(cos, target_type=pl.FP32) - sin_fp32 = pl.cast(sin, target_type=pl.FP32) - y_lo = pl.sub(pl.col_expand_mul(x_lo, cos_fp32), pl.col_expand_mul(x_hi, sin_fp32)) - y_hi = pl.add(pl.col_expand_mul(x_lo, sin_fp32), pl.col_expand_mul(x_hi, cos_fp32)) - out = pl.assemble(out, pl.cast(y_lo, target_type=pl.BF16), [0, NOPE_HEAD_DIM]) - out = pl.assemble(out, pl.cast(y_hi, target_type=pl.BF16), [0, NOPE_HEAD_DIM + HALF_RD]) - - return out + last_col0 = (STATE_LEN - 1) * OUT_DIM + HEAD_DIM + h0 + mi = score_state_flat[:, last_col0 : last_col0 + HEAD_CHUNK] + li = pl.exp(pl.sub(mi, mi)) + oi = kv_state_flat[:, last_col0 : last_col0 + HEAD_CHUNK] + + for s in pl.range(0, COMPRESS_RATIO): + front_col0 = s * OUT_DIM + h0 + front_score = score_state_flat[:, front_col0 : front_col0 + HEAD_CHUNK] + front_kv = kv_state_flat[:, front_col0 : front_col0 + HEAD_CHUNK] + mi_next_front = pl.maximum(mi, front_score) + alpha_front = pl.exp(pl.sub(mi, mi_next_front)) + beta_front = pl.exp(pl.sub(front_score, mi_next_front)) + li = pl.add(pl.mul(alpha_front, li), beta_front) + oi = pl.add(pl.mul(oi, alpha_front), pl.mul(front_kv, beta_front)) + mi = mi_next_front + + for s in pl.range(COMPRESS_RATIO, STATE_LEN - 1): + back_col0 = s * OUT_DIM + HEAD_DIM + h0 + back_score = score_state_flat[:, back_col0 : back_col0 + HEAD_CHUNK] + back_kv = kv_state_flat[:, back_col0 : back_col0 + HEAD_CHUNK] + mi_next_back = pl.maximum(mi, back_score) + alpha_back = pl.exp(pl.sub(mi, mi_next_back)) + beta_back = pl.exp(pl.sub(back_score, mi_next_back)) + li = pl.add(pl.mul(alpha_back, li), beta_back) + oi = pl.add(pl.mul(oi, alpha_back), pl.mul(back_kv, beta_back)) + mi = mi_next_back + + pooled_chunk = pl.div(oi, li) + pooled_kv = pl.assemble(pooled_kv, pooled_chunk, [0, h0]) + + for s in pl.parallel(0, COMPRESS_RATIO, 1): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="state_shift"): + src_col0 = (COMPRESS_RATIO + s) * OUT_DIM + dst_col0 = s * OUT_DIM + for o0 in pl.range(0, OUT_DIM, OUT_CHUNK): + dep_col0 = o0 % HEAD_DIM + dep_tile = pooled_kv[:, dep_col0 : dep_col0 + OUT_CHUNK] + dep_zero = pl.sub(dep_tile, dep_tile) + kv_tile = kv_state_flat[:, src_col0 + o0 : src_col0 + o0 + OUT_CHUNK] + score_tile = score_state_flat[:, src_col0 + o0 : src_col0 + o0 + OUT_CHUNK] + kv_tile = pl.add(kv_tile, dep_zero) + score_tile = pl.add(score_tile, dep_zero) + kv_state_flat = pl.assemble(kv_state_flat, kv_tile, [0, dst_col0 + o0]) + score_state_flat = pl.assemble(score_state_flat, score_tile, [0, dst_col0 + o0]) + + norm_w_2d = pl.reshape(norm_w, [1, HEAD_DIM]) + with pl.at(level=pl.Level.CORE_GROUP, name_hint="rmsnorm"): + partial_sq = pl.full([1, B * S], dtype=pl.FP32, value=0.0) + for k0 in pl.range(0, HEAD_DIM, HEAD_CHUNK): + # Golden applies rmsnorm to kv.to(torch.bfloat16), then casts to FP32 inside rmsnorm. 
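# Reference for the rounding the comment above describes (mirrors the golden's rmsnorm
# helper further down; eps = 1e-6 is assumed here for illustration): the input is rounded
# to bf16 first, the norm runs in fp32, and the result is stored as bf16 again, which is
# what the double pl.cast below reproduces.
import torch

def rmsnorm_bf16(x, w, eps=1e-6):
    x = x.to(torch.bfloat16).float()                 # bf16 rounding before the norm
    var = x.square().mean(-1, keepdim=True)
    return (w * (x * torch.rsqrt(var + eps))).to(torch.bfloat16)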
+ kv_rms_chunk = pl.cast( + pl.cast(pooled_kv[:, k0 : k0 + HEAD_CHUNK], target_type=pl.BF16), + target_type=pl.FP32, + ) + partial_sq = pl.add( + partial_sq, + pl.reshape(pl.row_sum(pl.mul(kv_rms_chunk, kv_rms_chunk)), [1, B * S]), + ) + + variance = pl.reshape(pl.add(pl.mul(partial_sq, HEAD_DIM_INV), EPS), [B * S, 1]) + inv_rms = pl.recip(pl.sqrt(variance)) + for k0 in pl.range(0, HEAD_DIM, HEAD_CHUNK): + kv_norm_chunk = pl.cast( + pl.cast(pooled_kv[:, k0 : k0 + HEAD_CHUNK], target_type=pl.BF16), + target_type=pl.FP32, + ) + gamma = norm_w_2d[:, k0 : k0 + HEAD_CHUNK] + normed_chunk = pl.col_expand_mul(pl.row_expand_mul(kv_norm_chunk, inv_rms), gamma) + normed_kv = pl.assemble(normed_kv, pl.cast(normed_chunk, target_type=pl.BF16), [0, k0]) + + kv_rope = pl.create_tensor([B * S, ROPE_HEAD_DIM], dtype=pl.BF16) + kv_proj_even = pl.create_tensor([B * S, ROPE_HEAD_DIM // 2], dtype=pl.FP32) + kv_proj_odd = pl.create_tensor([B * S, ROPE_HEAD_DIM // 2], dtype=pl.FP32) + rope_even = pl.create_tensor([B * S, ROPE_HEAD_DIM // 2], dtype=pl.BF16) + rope_odd = pl.create_tensor([B * S, ROPE_HEAD_DIM // 2], dtype=pl.BF16) + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_rope_slice"): + kv_rope = pl.assemble(kv_rope, normed_kv[:, NOPE_HEAD_DIM : HEAD_DIM], [0, 0]) + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="rope_slice"): + kv_rope_tile = kv_rope[:, 0 : ROPE_CHUCK] + even_select_tile = even_select[0 : ROPE_CHUCK, :] + odd_select_tile = odd_select[0 : ROPE_CHUCK, :] + even_acc = pl.matmul(kv_rope_tile, even_select_tile, out_dtype=pl.FP32) + odd_acc = pl.matmul(kv_rope_tile, odd_select_tile, out_dtype=pl.FP32) + + for r0 in pl.range(ROPE_CHUCK, ROPE_HEAD_DIM, ROPE_CHUCK): + kv_rope_tile = kv_rope[:, r0 : r0 + ROPE_CHUCK] + even_select_tile = even_select[r0 : r0 + ROPE_CHUCK, :] + odd_select_tile = odd_select[r0 : r0 + ROPE_CHUCK, :] + even_acc = pl.matmul_acc(even_acc, kv_rope_tile, even_select_tile) + odd_acc = pl.matmul_acc(odd_acc, kv_rope_tile, odd_select_tile) + kv_proj_even = pl.assemble(kv_proj_even, even_acc, [0, 0]) + kv_proj_odd = pl.assemble(kv_proj_odd, odd_acc, [0, 0]) + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="rope_apply"): + even_tile = kv_proj_even[:, :] + odd_tile = kv_proj_odd[:, :] + rope_even_acc = pl.cast(pl.sub(pl.col_expand_mul(even_tile, cos), pl.col_expand_mul(odd_tile, sin)), target_type=pl.BF16) + rope_odd_acc = pl.cast(pl.add(pl.col_expand_mul(even_tile, sin), pl.col_expand_mul(odd_tile, cos)), target_type=pl.BF16) + rope_even = pl.assemble(rope_even, rope_even_acc, [0, 0]) + rope_odd = pl.assemble(rope_odd, rope_odd_acc, [0, 0]) + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="rope_assemble"): + rope_even_tile = rope_even[:, 0 : ROPE_CHUCK] + rope_odd_tile = rope_odd[:, 0 : ROPE_CHUCK] + even_select_tile_t = even_select[:, 0 : ROPE_CHUCK] + odd_select_tile_t = odd_select[:, 0 : ROPE_CHUCK] + rope_acc = pl.matmul(rope_even_tile, even_select_tile_t, out_dtype=pl.FP32, b_trans=True) + rope_acc = pl.matmul_acc(rope_acc, rope_odd_tile, odd_select_tile_t, b_trans=True) + + for r0 in pl.range(ROPE_CHUCK, ROPE_HEAD_DIM // 2, ROPE_CHUCK): + rope_even_tile = rope_even[:, r0 : r0 + ROPE_CHUCK] + rope_odd_tile = rope_odd[:, r0 : r0 + ROPE_CHUCK] + even_select_tile_t = even_select[:, r0 : r0 + ROPE_CHUCK] + odd_select_tile_t = odd_select[:, r0 : r0 + ROPE_CHUCK] + rope_acc = pl.matmul_acc(rope_acc, rope_even_tile, even_select_tile_t, b_trans=True) + rope_acc = pl.matmul_acc(rope_acc, rope_odd_tile, odd_select_tile_t, b_trans=True) + + with 
pl.at(level=pl.Level.CORE_GROUP, name_hint="rope_write"): + normed_kv = pl.assemble(normed_kv, pl.cast(rope_acc, target_type=pl.BF16), [0, NOPE_HEAD_DIM]) + + if rotate: + for o0 in pl.range(0, HEAD_DIM, OUT_CHUNK): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_hadamard"): + kv_proj_tile = normed_kv[:, 0 : HEAD_DIM] + hadamard_tile = hadamard[0 : HEAD_DIM, o0 : o0 + OUT_CHUNK] + kv_hadamard_acc = pl.matmul(kv_proj_tile, hadamard_tile, out_dtype=pl.FP32) + kv_flat = pl.assemble(kv_flat, kv_hadamard_acc, [0, o0]) + else: + for o0 in pl.parallel(0, HEAD_DIM, OUT_CHUNK): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_write"): + kv_out_tile = normed_kv[:, o0 : o0 + OUT_CHUNK] + kv_flat = pl.assemble(kv_flat, pl.cast(kv_out_tile, target_type=pl.FP32), [0, o0]) + + kv_cache_flat = pl.reshape(kv_cache, [B * IDX_KV_LEN, HEAD_DIM]) + cache_col = start_pos // COMPRESS_RATIO + for b_idx in pl.parallel(B): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_cache_write"): + cache_row = b_idx * IDX_KV_LEN + cache_col + cache_kv_row = kv_flat[b_idx : b_idx + 1, 0 : HEAD_DIM] + kv_cache_flat = pl.assemble(kv_cache_flat, pl.cast(cache_kv_row, target_type=pl.BF16), [cache_row, 0]) + kv_cache = pl.reshape(kv_cache_flat, [B, IDX_KV_LEN, HEAD_DIM]) + + kv_state = pl.reshape(kv_state_flat, [B, STATE_LEN, OUT_DIM]) + score_state = pl.reshape(score_state_flat, [B, STATE_LEN, OUT_DIM]) + kv = pl.reshape(kv_flat, [B, S, HEAD_DIM]) + return kv, kv_state, score_state, kv_cache @pl.jit def compressor_test( x: pl.Tensor[[B, S, D], pl.BF16], + kv: pl.Out[pl.Tensor[[B, S, HEAD_DIM], pl.FP32]], kv_state: pl.Out[pl.Tensor[[B, STATE_LEN, OUT_DIM], pl.FP32]], score_state: pl.Out[pl.Tensor[[B, STATE_LEN, OUT_DIM], pl.FP32]], - wkv: pl.Tensor[[OUT_DIM, D], pl.BF16], - wgate: pl.Tensor[[OUT_DIM, D], pl.BF16], + wkv: pl.Tensor[[D, OUT_DIM], pl.BF16], + wgate: pl.Tensor[[D, OUT_DIM], pl.BF16], ape: pl.Tensor[[COMPRESS_RATIO, OUT_DIM], pl.FP32], - norm_w: pl.Tensor[[HEAD_DIM], pl.BF16], - cos: pl.Tensor[[1, ROPE_HEAD_DIM // 2], pl.BF16], - sin: pl.Tensor[[1, ROPE_HEAD_DIM // 2], pl.BF16], + norm_w: pl.Tensor[[HEAD_DIM], pl.FP32], + cos: pl.Tensor[[1, ROPE_HEAD_DIM // 2], pl.FP32], + sin: pl.Tensor[[1, ROPE_HEAD_DIM // 2], pl.FP32], + even_select: pl.Tensor[[ROPE_HEAD_DIM, ROPE_HEAD_DIM // 2], pl.BF16], + odd_select: pl.Tensor[[ROPE_HEAD_DIM, ROPE_HEAD_DIM // 2], pl.BF16], hadamard: pl.Tensor[[HEAD_DIM, HEAD_DIM], pl.BF16], + kv_cache: pl.Out[pl.Tensor[[B, IDX_KV_LEN, HEAD_DIM], pl.BF16]], start_pos: pl.Scalar[pl.INT32], - out: pl.Out[pl.Tensor[[B, HEAD_DIM], pl.BF16]], + rotate: pl.Scalar[pl.BOOL], ): - out = compressor( - x, kv_state, score_state, wkv, wgate, ape, norm_w, cos, sin, hadamard, start_pos, out, + kv, kv_state, score_state, kv_cache = compressor( + x, kv, kv_state, score_state, wkv, wgate, ape, norm_w, cos, sin, even_select, odd_select, hadamard, kv_cache, start_pos, rotate ) - return out + return kv, kv_state, score_state, kv_cache def golden_compressor(tensors): """Torch reference for Compressor.forward (decode branch, ratio=4 overlap).""" import torch - x = tensors["x"] + x = tensors["x"].float() kv_state = tensors["kv_state"] score_state = tensors["score_state"] wkv = tensors["wkv"].float() wgate = tensors["wgate"].float() - ape = tensors["ape"].float() - norm_w = tensors["norm_w"].float() - cos = tensors["cos"].float() - sin = tensors["sin"].float() + ape = tensors["ape"] + norm_w = tensors["norm_w"] + cos = tensors["cos"] + sin = tensors["sin"] hadamard = tensors["hadamard"].float() - + kv_cache 
= tensors["kv_cache"] + start_pos = int(tensors["start_pos"]) + rotate = bool(tensors["rotate"]) bsz, _, _ = x.shape - ratio, overlap, rotate, d, rd = COMPRESS_RATIO, OVERLAP, ROTATE, HEAD_DIM, ROPE_HEAD_DIM - dtype = x.dtype - x = x.float() - kv = x.view(bsz, -1) @ wkv.T - score = x.view(bsz, -1) @ wgate.T - - should_compress = (START_POS + 1) % ratio == 0 - score = score + ape[START_POS % ratio] - if overlap: - kv_state[:bsz, ratio + START_POS % ratio] = kv - score_state[:bsz, ratio + START_POS % ratio] = score - if should_compress: - kvs = torch.cat([kv_state[:bsz, :ratio, :d], kv_state[:bsz, ratio:, d:]], dim=1) - scs = torch.cat([score_state[:bsz, :ratio, :d], score_state[:bsz, ratio:, d:]], dim=1) - kv = (kvs * scs.softmax(dim=1)).sum(dim=1, keepdim=True) - kv_state[:bsz, :ratio] = kv_state[:bsz, ratio:] - score_state[:bsz, :ratio] = score_state[:bsz, ratio:] + ratio, d, rd = COMPRESS_RATIO, HEAD_DIM, ROPE_HEAD_DIM + + kv = x @ wkv + score = x @ wgate + + should_compress = (start_pos + 1) % ratio == 0 + score = score + ape[start_pos % ratio] + + kv_state[:bsz, ratio + start_pos % ratio] = kv.squeeze(1) + score_state[:bsz, ratio + start_pos % ratio] = score.squeeze(1) + if should_compress: + kvs = torch.cat([kv_state[:bsz, :ratio, :d], kv_state[:bsz, ratio:, d:]], dim=1) + scs = torch.cat([score_state[:bsz, :ratio, :d], score_state[:bsz, ratio:, d:]], dim=1) + kv = (kvs * scs.softmax(dim=1)).sum(dim=1, keepdim=True) + kv_state[:bsz, :ratio] = kv_state[:bsz, ratio:] + score_state[:bsz, :ratio] = score_state[:bsz, ratio:] + + tensors["kv_state"][:] = kv_state + tensors["score_state"][:] = score_state if not should_compress: - tensors["out"][:] = torch.zeros(B, HEAD_DIM, dtype=torch.bfloat16) return - kv_c = kv.squeeze(1) - kv_c = kv_c * torch.rsqrt(kv_c.square().mean(-1, keepdim=True) + EPS) * norm_w + def rmsnorm(x, w): + x = x.float() + var = x.square().mean(-1, keepdim=True) + x = x * torch.rsqrt(var + EPS) + return (w * x).to(torch.bfloat16) - half_rd = rd // 2 - x_lo = kv_c[..., -rd:-half_rd] - x_hi = kv_c[..., -half_rd:] + kv = rmsnorm(kv.to(torch.bfloat16), norm_w) + + x_pair = kv[..., -rd:].unflatten(-1, (-1, 2)) + x0, x1 = x_pair[..., 0], x_pair[..., 1] cos_v, sin_v = cos.view(-1), sin.view(-1) - y_lo = x_lo * cos_v - x_hi * sin_v - y_hi = x_lo * sin_v + x_hi * cos_v - kv_c = torch.cat([kv_c[..., :-rd], y_lo, y_hi], dim=-1) + y0 = (x0 * cos_v - x1 * sin_v).to(torch.bfloat16) + y1 = (x0 * sin_v + x1 * cos_v).to(torch.bfloat16) + + kv = torch.cat([kv[..., :-rd], torch.stack([y0, y1], dim=-1).flatten(-2)], dim=-1).float() if rotate: - kv_c = (kv_c @ hadamard).to(torch.bfloat16).float() - else: - pass + kv = kv @ hadamard + tensors["kv"][:] = kv + + kv_cache[:bsz, start_pos // ratio] = kv.squeeze(1) - tensors["out"][:] = kv_c.to(torch.bfloat16) + tensors["kv_cache"][:] = kv_cache def build_tensor_specs(): import torch # type: ignore[import] from golden import ScalarSpec, TensorSpec - torch.manual_seed(42) - def init_x(): - return torch.randn(B, S, D) - 0.5 + return torch.randn(B, S, D) * 0.1 def init_kv_state(): return torch.zeros(B, STATE_LEN, OUT_DIM) def init_score_state(): - return torch.full((B, STATE_LEN, OUT_DIM), float("-inf")) + return torch.full((B, STATE_LEN, OUT_DIM), FP32_NEG_INF) def init_wkv(): - return (torch.randn(OUT_DIM, D) - 0.5) / (D ** 0.5) + return torch.randn(D, OUT_DIM) / D ** 0.5 def init_wgate(): - return (torch.randn(OUT_DIM, D) - 0.5) / (D ** 0.5) + return torch.randn(D, OUT_DIM) / D ** 0.5 def init_ape(): - return torch.randn(COMPRESS_RATIO, OUT_DIM) * 
0.01 + return torch.randn(COMPRESS_RATIO, OUT_DIM) * 0.1 def init_norm_w(): return torch.ones(HEAD_DIM) def init_cos(): return torch.cos(torch.arange(ROPE_HEAD_DIM // 2).reshape(1, ROPE_HEAD_DIM // 2) * 1e-3) def init_sin(): return torch.sin(torch.arange(ROPE_HEAD_DIM // 2).reshape(1, ROPE_HEAD_DIM // 2) * 1e-3) + def init_odd_select(): + M = torch.zeros((ROPE_HEAD_DIM, ROPE_HEAD_DIM // 2)) + for i in range(ROPE_HEAD_DIM // 2): + M[2*i+1, i] = 1 + return M + def init_even_select(): + M = torch.zeros((ROPE_HEAD_DIM, ROPE_HEAD_DIM // 2)) + for i in range(ROPE_HEAD_DIM // 2): + M[2*i, i] = 1 + return M def init_hadamard(): - return torch.eye(HEAD_DIM) + H = torch.ones((1, 1)) + while H.shape[0] < HEAD_DIM: + H = torch.cat([ + torch.cat([H, H], dim=1), + torch.cat([H, -H], dim=1), + ], dim=0) + return H / (HEAD_DIM ** 0.5) + def init_kv_cache(): + return torch.zeros(B, IDX_KV_LEN, HEAD_DIM) + return [ TensorSpec("x", [B, S, D], torch.bfloat16, init_value=init_x), + TensorSpec("kv", [B, S, HEAD_DIM], torch.float32, is_output=True), TensorSpec("kv_state", [B, STATE_LEN, OUT_DIM], torch.float32, init_value=init_kv_state, is_output=True), TensorSpec("score_state", [B, STATE_LEN, OUT_DIM], torch.float32, init_value=init_score_state, is_output=True), - TensorSpec("wkv", [OUT_DIM, D], torch.bfloat16, init_value=init_wkv), - TensorSpec("wgate", [OUT_DIM, D], torch.bfloat16, init_value=init_wgate), + TensorSpec("wkv", [D, OUT_DIM], torch.bfloat16, init_value=init_wkv), + TensorSpec("wgate", [D, OUT_DIM], torch.bfloat16, init_value=init_wgate), TensorSpec("ape", [COMPRESS_RATIO, OUT_DIM], torch.float32, init_value=init_ape), - TensorSpec("norm_w", [HEAD_DIM], torch.bfloat16, init_value=init_norm_w), - TensorSpec("cos", [1, ROPE_HEAD_DIM // 2], torch.bfloat16, init_value=init_cos), - TensorSpec("sin", [1, ROPE_HEAD_DIM // 2], torch.bfloat16, init_value=init_sin), + TensorSpec("norm_w", [HEAD_DIM], torch.float32, init_value=init_norm_w), + TensorSpec("cos", [1, ROPE_HEAD_DIM // 2], torch.float32, init_value=init_cos), + TensorSpec("sin", [1, ROPE_HEAD_DIM // 2], torch.float32, init_value=init_sin), + TensorSpec("even_select", [ROPE_HEAD_DIM, ROPE_HEAD_DIM // 2], torch.bfloat16, init_value=init_even_select), + TensorSpec("odd_select", [ROPE_HEAD_DIM, ROPE_HEAD_DIM // 2], torch.bfloat16, init_value=init_odd_select), TensorSpec("hadamard", [HEAD_DIM, HEAD_DIM], torch.bfloat16, init_value=init_hadamard), + TensorSpec("kv_cache", [B, IDX_KV_LEN, HEAD_DIM], torch.bfloat16, init_value=init_kv_cache, is_output=True), ScalarSpec("start_pos", torch.int32, START_POS), - TensorSpec("out", [B, HEAD_DIM], torch.bfloat16, is_output=True), + ScalarSpec("rotate", torch.bool, ROTATE), ] @@ -385,8 +442,8 @@ def init_hadamard(): specs=build_tensor_specs(), golden_fn=golden_compressor, config=RunConfig( - rtol=4e-3, - atol=4e-3, + rtol=1e-3, + atol=1e-3, compile=dict(dump_passes=True), runtime=dict( platform=args.platform, diff --git a/models/deepseek/v4/indexer.py b/models/deepseek/v4/indexer.py new file mode 100644 index 0000000..0246fa1 --- /dev/null +++ b/models/deepseek/v4/indexer.py @@ -0,0 +1,436 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- +"""DeepSeek-V4 Indexer (decode). Mirrors model.py Indexer (line 380-433); +golden is a port of forward's decode branch (prefill `start_pos == 0` path is omitted). +The inner Compressor is invoked via golden_compressor (placeholder).""" + + +import pypto.language as pl + +B = 16 # demo 4 +S = 1 +T = B * S + +D = 4096 # flash:4096 pro:7168 +Q_LORA = 1024 # flash:1024 pro:1536 +ROPE_HEAD_DIM = 64 + +IDX_N_HEADS = 64 +IDX_HEAD_DIM = 128 +IDX_NOPE_HEAD_DIM = IDX_HEAD_DIM - ROPE_HEAD_DIM +IDX_TOPK = 16 # v4-pro 1024 +IDX_SOFTMAX_SCALE = IDX_HEAD_DIM ** -0.5 +WEIGHTS_SCALE = IDX_SOFTMAX_SCALE * IDX_N_HEADS ** -0.5 + +COMPRESS_RATIO = 4 + +MAX_SEQ_LEN = 4096 +IDX_KV_LEN = MAX_SEQ_LEN // COMPRESS_RATIO +SCORE_LEN = IDX_KV_LEN +SORT_LEN = 2048 +CACHE_TILE = 32 +MAX_CACHE_BLOCKS = SCORE_LEN // CACHE_TILE +FP32_NEG_INF = -3.4028234663852886e38 + +START_POS = 256 # default for ScalarSpec; >0 (decode) and (START_POS+1)%COMPRESS_RATIO==0 to cover the full inner-compressor path +OFFSET = 128 # default for ScalarSpec; = win in attention orch; added to topk_idxs (model.py:432) + +Q_CHUCK = 128 +Q_OUT_CHUCK = 128 +ROPE_CHUCK = 16 +HEAD_DIM_CHUCK = 32 +D_CHUCK = 32 +HEAD_CHUCK = 16 + +@pl.jit.inline +def indexer( + x: pl.Tensor[[B, S, D], pl.BF16], + qr: pl.Tensor[[B, S, Q_LORA], pl.BF16], + wq_b: pl.Tensor[[Q_LORA, IDX_N_HEADS * IDX_HEAD_DIM], pl.BF16], + weights_proj: pl.Tensor[[D, IDX_N_HEADS], pl.BF16], + cos: pl.Tensor[[1, ROPE_HEAD_DIM // 2], pl.FP32], # caller passes freqs_cis[start_pos] + sin: pl.Tensor[[1, ROPE_HEAD_DIM // 2], pl.FP32], + even_select: pl.Tensor[[ROPE_HEAD_DIM, ROPE_HEAD_DIM // 2], pl.BF16], + odd_select: pl.Tensor[[ROPE_HEAD_DIM, ROPE_HEAD_DIM // 2], pl.BF16], + hadamard: pl.Tensor[[IDX_HEAD_DIM, IDX_HEAD_DIM], pl.BF16], # shared by q rotation and inner Compressor + idx_kv_cache: pl.Tensor[[B, IDX_KV_LEN, IDX_HEAD_DIM], pl.BF16], + score: pl.Tensor[[B, S, SCORE_LEN], pl.FP32], + topk_idxs: pl.Tensor[[B, S, SCORE_LEN], pl.INT32], + start_pos: pl.Scalar[pl.INT32], # decode step; varies per call + offset: pl.Scalar[pl.INT32], # added to topk_idxs (= win from attention orch) +): + # TODO: kernel implementation + cache_len = (start_pos + S) // COMPRESS_RATIO + cache_blocks = (cache_len + CACHE_TILE - 1) // CACHE_TILE + + qr_proj = pl.create_tensor([T, IDX_N_HEADS * IDX_HEAD_DIM], dtype=pl.BF16) + qr_flat = pl.reshape(qr, [T, Q_LORA]) + for o0 in pl.parallel(0, IDX_N_HEADS * IDX_HEAD_DIM, Q_OUT_CHUCK): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="qr_proj"): + qr_tile = qr_flat[:, 0 : Q_CHUCK] + wq_tile = wq_b[0 : Q_CHUCK, o0 : o0 + Q_OUT_CHUCK] + qr_acc = pl.matmul(qr_tile, wq_tile, out_dtype=pl.FP32) + + for q0 in pl.range(Q_CHUCK, Q_LORA, Q_CHUCK): + qr_tile = qr_flat[:, q0 : q0 + Q_CHUCK] + wq_tile = wq_b[q0 : q0 + Q_CHUCK, o0 : o0 + Q_OUT_CHUCK] + qr_acc =pl.matmul_acc(qr_acc, qr_tile, wq_tile) + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="qr_proj_write"): + qr_proj = pl.assemble(qr_proj, pl.cast(qr_acc, target_type=pl.BF16), [0, o0]) + + qr_proj_flat = pl.reshape(qr_proj, [T * IDX_N_HEADS, IDX_HEAD_DIM]) + qr_rope = pl.create_tensor([T * IDX_N_HEADS, ROPE_HEAD_DIM], 
dtype=pl.BF16) + qr_proj_even = pl.create_tensor([T * IDX_N_HEADS, ROPE_HEAD_DIM // 2], dtype=pl.FP32) + qr_proj_odd = pl.create_tensor([T * IDX_N_HEADS, ROPE_HEAD_DIM // 2], dtype=pl.FP32) + rope_even = pl.create_tensor([T * IDX_N_HEADS, ROPE_HEAD_DIM // 2], dtype=pl.BF16) + rope_odd = pl.create_tensor([T * IDX_N_HEADS, ROPE_HEAD_DIM // 2], dtype=pl.BF16) + qr_hadamard = pl.create_tensor([T * IDX_N_HEADS, IDX_HEAD_DIM], dtype=pl.BF16) + + for o0 in pl.parallel(0, T * IDX_N_HEADS, IDX_N_HEADS): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="qr_slice"): + qr_rope = pl.assemble(qr_rope, qr_proj_flat[o0 : o0 + IDX_N_HEADS, IDX_NOPE_HEAD_DIM :], [o0, 0]) + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="rope_slice"): + qr_proj_rope_tile = qr_rope[o0 : o0 + IDX_N_HEADS, 0 : ROPE_CHUCK] + even_select_tile = even_select[0 : ROPE_CHUCK, :] + odd_select_tile = odd_select[0 : ROPE_CHUCK, :] + even_acc = pl.matmul(qr_proj_rope_tile, even_select_tile, out_dtype=pl.FP32) + odd_acc = pl.matmul(qr_proj_rope_tile, odd_select_tile, out_dtype=pl.FP32) + + for r0 in pl.range(ROPE_CHUCK, ROPE_HEAD_DIM, ROPE_CHUCK): + qr_proj_rope_tile = qr_rope[o0 : o0 + IDX_N_HEADS, r0 : r0 + ROPE_CHUCK] + even_select_tile = even_select[r0 : r0 + ROPE_CHUCK, :] + odd_select_tile = odd_select[r0 : r0 + ROPE_CHUCK, :] + even_acc = pl.matmul_acc(even_acc, qr_proj_rope_tile, even_select_tile) + odd_acc = pl.matmul_acc(odd_acc, qr_proj_rope_tile, odd_select_tile) + qr_proj_even = pl.assemble(qr_proj_even, even_acc, [o0, 0]) + qr_proj_odd = pl.assemble(qr_proj_odd, odd_acc, [o0, 0]) + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="rope_apply"): + even_tile = qr_proj_even[o0 : o0 + IDX_N_HEADS, :] + odd_tile = qr_proj_odd[o0 : o0 + IDX_N_HEADS, :] + rope_even_acc = pl.cast(pl.sub(pl.col_expand_mul(even_tile, cos), pl.col_expand_mul(odd_tile, sin)), target_type=pl.BF16) + rope_odd_acc = pl.cast(pl.add(pl.col_expand_mul(even_tile, sin), pl.col_expand_mul(odd_tile, cos)), target_type=pl.BF16) + rope_even = pl.assemble(rope_even, rope_even_acc, [o0, 0]) + rope_odd = pl.assemble(rope_odd, rope_odd_acc, [o0, 0]) + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="rope_assemble"): + rope_even_tile = rope_even[o0 : o0 + IDX_N_HEADS, 0 : ROPE_CHUCK] + rope_odd_tile = rope_odd[o0 : o0 + IDX_N_HEADS, 0 : ROPE_CHUCK] + even_select_tile_t = even_select[:, 0 : ROPE_CHUCK] + odd_select_tile_t = odd_select[:, 0 : ROPE_CHUCK] + rope_acc = pl.matmul(rope_even_tile, even_select_tile_t, out_dtype=pl.FP32, b_trans=True) + rope_acc = pl.matmul_acc(rope_acc, rope_odd_tile, odd_select_tile_t, b_trans=True) + + for r0 in pl.range(ROPE_CHUCK, ROPE_HEAD_DIM // 2, ROPE_CHUCK): + rope_even_tile = rope_even[o0 : o0 + IDX_N_HEADS, r0 : r0 + ROPE_CHUCK] + rope_odd_tile = rope_odd[o0 : o0 + IDX_N_HEADS, r0 : r0 + ROPE_CHUCK] + even_select_tile_t = even_select[:, r0 : r0 + ROPE_CHUCK] + odd_select_tile_t = odd_select[:, r0 : r0 + ROPE_CHUCK] + rope_acc = pl.matmul_acc(rope_acc, rope_even_tile, even_select_tile_t, b_trans=True) + rope_acc = pl.matmul_acc(rope_acc, rope_odd_tile, odd_select_tile_t, b_trans=True) + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="rope_write"): + qr_rope = pl.assemble(qr_rope, pl.cast(rope_acc, target_type=pl.BF16), [o0, 0]) + + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="qr_assemble"): + qr_proj_flat = pl.assemble(qr_proj_flat, qr_rope[o0 : o0 + IDX_N_HEADS, :], [o0, IDX_NOPE_HEAD_DIM]) + + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="qr_hadamard"): + qr_proj_tile = qr_proj_flat[o0 : o0 + IDX_N_HEADS, 0 : 
HEAD_DIM_CHUCK] + hadamard_tile = hadamard[0 : HEAD_DIM_CHUCK, :] + qr_hadamard_acc = pl.matmul(qr_proj_tile, hadamard_tile, out_dtype=pl.FP32) + + for h0 in pl.range(HEAD_DIM_CHUCK, IDX_HEAD_DIM, HEAD_DIM_CHUCK): + qr_proj_tile = qr_proj_flat[o0 : o0 + IDX_N_HEADS, h0 : h0 + HEAD_DIM_CHUCK] + hadamard_tile = hadamard[h0 : h0 + HEAD_DIM_CHUCK, :] + qr_hadamard_acc = pl.matmul_acc(qr_hadamard_acc, qr_proj_tile, hadamard_tile) + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="qr_hadamard_write"): + qr_hadamard = pl.assemble(qr_hadamard, pl.cast(qr_hadamard_acc, target_type=pl.BF16), [o0, 0]) + + + weights = pl.create_tensor([T, IDX_N_HEADS], dtype=pl.FP32) + x_flat = pl.reshape(x, [T, D]) + for h0 in pl.parallel(0, IDX_N_HEADS, HEAD_CHUCK): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="weights_proj"): + x_tile = x_flat[:, 0 : D_CHUCK] + weights_proj_tile = weights_proj[0 : D_CHUCK, h0 : h0 + HEAD_CHUCK] + weights_acc = pl.matmul(x_tile, weights_proj_tile, out_dtype=pl.FP32) + + for d0 in pl.range(D_CHUCK, D, D_CHUCK): + x_tile = x_flat[:, d0 : d0 + D_CHUCK] + weights_proj_tile = weights_proj[d0 : d0 + D_CHUCK, h0 : h0 + HEAD_CHUCK] + weights_acc = pl.matmul_acc(weights_acc, x_tile, weights_proj_tile) + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="weights_write"): + weights_scale = pl.mul(weights_acc, WEIGHTS_SCALE) + weights = pl.assemble(weights, weights_scale, [0, h0]) + + + kv_cache_flat = pl.reshape(idx_kv_cache, [B * IDX_KV_LEN, IDX_HEAD_DIM]) + score_logits = pl.create_tensor([B * MAX_CACHE_BLOCKS * CACHE_TILE, IDX_N_HEADS], dtype=pl.FP32) + weighted_score_tiles = pl.create_tensor([B * MAX_CACHE_BLOCKS * S, CACHE_TILE], dtype=pl.FP32) + score_flat = pl.reshape(score, [T, SCORE_LEN]) + + for b in pl.parallel(B): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="score"): + q0 = b * S * IDX_N_HEADS + kv0 = b * IDX_KV_LEN + for cb in pl.range(cache_blocks): + qr_hadamard_tile = qr_hadamard[q0 : q0 + S * IDX_N_HEADS, :] + kv_cache_tile = kv_cache_flat[kv0 + cb * CACHE_TILE : kv0 + (cb + 1) * CACHE_TILE, :] + score_logits_tile = pl.matmul(kv_cache_tile, qr_hadamard_tile, out_dtype=pl.FP32, b_trans=True) + score_logits = pl.assemble( + score_logits, + score_logits_tile, + [(b * MAX_CACHE_BLOCKS + cb) * CACHE_TILE, 0], + ) + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="score_weighted_reduce"): + t0 = b * S + neg_inf_score = pl.full([S, SCORE_LEN], dtype=pl.FP32, value=FP32_NEG_INF) + score_flat = pl.assemble(score_flat, neg_inf_score, [t0, 0]) + for cb in pl.range(cache_blocks): + cache0 = cb * CACHE_TILE + valid_len = pl.min(CACHE_TILE, cache_len - cache0) + logits_row0 = (b * MAX_CACHE_BLOCKS + cb) * CACHE_TILE + score_tile = score_logits[logits_row0 : logits_row0 + CACHE_TILE, :] + relu_score = pl.maximum(score_tile, pl.mul(score_tile, 0.0)) + weights_tile = weights[t0 : t0 + S, :] + weighted_score_t = pl.col_expand_mul(relu_score, weights_tile) + weighted_score = pl.reshape(pl.row_sum(weighted_score_t), [S, CACHE_TILE]) + score_row0 = (b * MAX_CACHE_BLOCKS + cb) * S + weighted_score_tiles = pl.assemble(weighted_score_tiles, weighted_score, [score_row0, 0]) + weighted_score_valid = pl.slice( + weighted_score_tiles, + [S, CACHE_TILE], + [score_row0, 0], + valid_shape=[S, valid_len], + ) + score_flat = pl.assemble(score_flat, weighted_score_valid, [t0, cache0]) + + topk_idxs_flat = pl.reshape(topk_idxs, [T, SCORE_LEN]) + for t in pl.parallel(T): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="topk"): + offset_i32 = pl.cast(offset, target_type=pl.INT32) + score_row = 
score_flat[t : t + 1, :] + neg_inf_sort = pl.full([1, SORT_LEN - SCORE_LEN], dtype=pl.FP32, value=FP32_NEG_INF) + sort_row = pl.concat(score_row, neg_inf_sort) + idx_init = pl.tensor.arange(0, [1, SORT_LEN], dtype=pl.UINT32) + sorted_score_tile = pl.tensor.sort32(sort_row, idx_init) + sorted_score_tile = pl.tensor.mrgsort(sorted_score_tile, block_len=64) + sorted_score_tile = pl.tensor.mrgsort(sorted_score_tile, block_len=256) + sorted_score_tile = pl.tensor.mrgsort(sorted_score_tile, block_len=1024) + invalid_idxs = pl.full([1, SCORE_LEN], dtype=pl.INT32, value=-1) + topk_idxs_flat = pl.assemble(topk_idxs_flat, invalid_idxs, [t, 0]) + topk_pairs = sorted_score_tile[:, 0 : 2 * IDX_TOPK] + topk_idxs_tile = pl.tensor.gather(topk_pairs, mask_pattern=pl.tile.MaskPattern.P1010, output_dtype=pl.INT32) + raw_topk_idxs = pl.create_tensor([1, IDX_TOPK], dtype=pl.INT32) + raw_topk_idxs = pl.assemble(raw_topk_idxs, topk_idxs_tile, [0, 0]) + valid_topk = pl.min(IDX_TOPK, cache_len) + topk_idxs_valid = pl.slice( + raw_topk_idxs, + [1, IDX_TOPK], + [0, 0], + valid_shape=[1, valid_topk], + ) + topk_idxs_flat = pl.assemble(topk_idxs_flat, pl.add(topk_idxs_valid, offset_i32), [t, 0]) + + return topk_idxs + + +@pl.jit +def indexer_test( + x: pl.Tensor[[B, S, D], pl.BF16], + qr: pl.Tensor[[B, S, Q_LORA], pl.BF16], + wq_b: pl.Tensor[[Q_LORA, IDX_N_HEADS * IDX_HEAD_DIM], pl.BF16], + weights_proj: pl.Tensor[[D, IDX_N_HEADS], pl.BF16], + cos: pl.Tensor[[1, ROPE_HEAD_DIM // 2], pl.FP32], + sin: pl.Tensor[[1, ROPE_HEAD_DIM // 2], pl.FP32], + even_select: pl.Tensor[[ROPE_HEAD_DIM, ROPE_HEAD_DIM // 2], pl.BF16], + odd_select: pl.Tensor[[ROPE_HEAD_DIM, ROPE_HEAD_DIM // 2], pl.BF16], + hadamard: pl.Tensor[[IDX_HEAD_DIM, IDX_HEAD_DIM], pl.BF16], + idx_kv_cache: pl.Tensor[[B, IDX_KV_LEN, IDX_HEAD_DIM], pl.BF16], + score: pl.Out[pl.Tensor[[B, S, SCORE_LEN], pl.FP32]], + topk_idxs: pl.Out[pl.Tensor[[B, S, SCORE_LEN], pl.INT32]], + start_pos: pl.Scalar[pl.INT32], + offset: pl.Scalar[pl.INT32], +): + topk_idxs = indexer( + x, + qr, + wq_b, + weights_proj, + cos, + sin, + even_select, + odd_select, + hadamard, + idx_kv_cache, + score, + topk_idxs, + start_pos, + offset, + ) + return topk_idxs + + +def golden_indexer(tensors): + """Torch reference for Indexer.forward (decode branch; prefill omitted; W8A8C16 quant ops are identity in golden).""" + import torch + + x = tensors["x"].float() + qr = tensors["qr"].float() + wq_b = tensors["wq_b"].float() + weights_proj = tensors["weights_proj"].float() + cos = tensors["cos"] + sin = tensors["sin"] + hadamard = tensors["hadamard"].float() + idx_kv_cache = tensors["idx_kv_cache"].float() + + start_pos = int(tensors["start_pos"]) + offset = int(tensors["offset"]) + + bsz, seqlen, _ = x.shape + ratio, rd = COMPRESS_RATIO, ROPE_HEAD_DIM + end_pos = start_pos + seqlen + + if start_pos == 0: + return + + # W8A8C16: wq_b W8 per-channel int8; qr A8 per-token int8. + q = (qr @ wq_b).view(B, S, IDX_N_HEADS, IDX_HEAD_DIM) + + x_pair = q[..., -rd:].unflatten(-1, (-1, 2)) + x0, x1 = x_pair[..., 0], x_pair[..., 1] + cos_v, sin_v = cos.view(-1), sin.view(-1) + y0 = (x0 * cos_v - x1 * sin_v).to(torch.bfloat16) + y1 = (x0 * sin_v + x1 * cos_v).to(torch.bfloat16) + + q = torch.cat([q[..., :-rd], torch.stack([y0, y1], dim=-1).flatten(-2)], dim=-1) + + q = q @ hadamard + # W8A8C16: A8 per-token-head int8 quant of q here (consumed by LI batch_matmul below). + # flash: fp4_act_quant on q (FP4 simulation). 
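# Note on even_select / odd_select: the kernels above implement this interleaved RoPE with
# plain matmuls so it can run on the cube. Multiplying by the one-hot selection matrices
# (see init_even_select / init_odd_select below) gathers the even / odd lanes, and
# multiplying by their transposes scatters the rotated lanes back. A torch sketch of the
# equivalence (illustration only, not the kernel tiling):
import torch

def rope_via_select(x, cos_v, sin_v, even_select, odd_select):
    # x: [N, ROPE_HEAD_DIM]; even_select / odd_select: [ROPE_HEAD_DIM, ROPE_HEAD_DIM // 2]
    x0 = x @ even_select            # == x[:, 0::2]
    x1 = x @ odd_select             # == x[:, 1::2]
    y0 = x0 * cos_v - x1 * sin_v
    y1 = x0 * sin_v + x1 * cos_v
    return y0 @ even_select.T + y1 @ odd_select.T   # re-interleave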
+ + weights = (x @ weights_proj) * WEIGHTS_SCALE + + cache_len = end_pos // ratio + + kv_view = idx_kv_cache[:bsz, :cache_len] + # W8A8C16: LI batch_matmul Int8. q A8 per-token-head int8; kv_view (Indexer Cache) C8 per-token-head int8. + # flash: q/kv via FP4 simulation (full Hadamard rotation + fp4_act_quant). + score = torch.einsum("bshd,btd->bsht", q, kv_view) + score = (torch.relu(score) * weights.unsqueeze(-1)).sum(dim=2) + score_full = torch.full((bsz, seqlen, SCORE_LEN), FP32_NEG_INF, dtype=torch.float32) + score_full[..., :cache_len] = score.to(torch.float32) + tensors["score"][:] = score_full + + k = min(IDX_TOPK, cache_len) + _, idx = score.topk(k, dim=-1) + topk_idxs = torch.full((bsz, seqlen, SCORE_LEN), -1, dtype=torch.int32) + topk_idxs[..., :k] = idx.to(torch.int32) + topk_idxs[..., :k] += offset + + tensors["topk_idxs"][:] = topk_idxs.view(B, S, SCORE_LEN) + + +def build_tensor_specs(): + import torch # type: ignore[import] + from golden import ScalarSpec, TensorSpec + + def init_x(): + return torch.randn(B, S, D) * 0.1 + def init_qr(): + return torch.randn(B, S, Q_LORA) * 0.1 + def init_wq_b(): + return torch.randn(Q_LORA, IDX_N_HEADS * IDX_HEAD_DIM) / Q_LORA ** 0.5 + def init_weights_proj(): + return torch.randn(D, IDX_N_HEADS) / D ** 0.5 + def init_cos(): + return torch.cos(torch.arange(ROPE_HEAD_DIM // 2).reshape(1, ROPE_HEAD_DIM // 2) * 1e-3) + def init_sin(): + return torch.sin(torch.arange(ROPE_HEAD_DIM // 2).reshape(1, ROPE_HEAD_DIM // 2) * 1e-3) + def init_odd_select(): + M = torch.zeros((ROPE_HEAD_DIM, ROPE_HEAD_DIM // 2)) + for i in range(ROPE_HEAD_DIM // 2): + M[2*i+1, i] = 1 + return M + def init_even_select(): + M = torch.zeros((ROPE_HEAD_DIM, ROPE_HEAD_DIM // 2)) + for i in range(ROPE_HEAD_DIM // 2): + M[2*i, i] = 1 + return M + def init_hadamard(): + H = torch.ones((1, 1)) + while H.shape[0] < IDX_HEAD_DIM: + H = torch.cat([ + torch.cat([H, H], dim=1), + torch.cat([H, -H], dim=1), + ], dim=0) + return H / (IDX_HEAD_DIM ** 0.5) + def init_idx_kv_cache(): + return torch.randn(B, IDX_KV_LEN, IDX_HEAD_DIM) + + return [ + TensorSpec("x", [B, S, D], torch.bfloat16, init_value=init_x), + TensorSpec("qr", [B, S, Q_LORA], torch.bfloat16, init_value=init_qr), + TensorSpec("wq_b", [Q_LORA, IDX_N_HEADS * IDX_HEAD_DIM], torch.bfloat16, init_value=init_wq_b), + TensorSpec("weights_proj", [D, IDX_N_HEADS], torch.bfloat16, init_value=init_weights_proj), + TensorSpec("cos", [1, ROPE_HEAD_DIM // 2], torch.float32, init_value=init_cos), + TensorSpec("sin", [1, ROPE_HEAD_DIM // 2], torch.float32, init_value=init_sin), + TensorSpec("even_select", [ROPE_HEAD_DIM, ROPE_HEAD_DIM // 2], torch.bfloat16, init_value=init_even_select), + TensorSpec("odd_select", [ROPE_HEAD_DIM, ROPE_HEAD_DIM // 2], torch.bfloat16, init_value=init_odd_select), + TensorSpec("hadamard", [IDX_HEAD_DIM, IDX_HEAD_DIM], torch.bfloat16, init_value=init_hadamard), + TensorSpec("idx_kv_cache", [B, IDX_KV_LEN, IDX_HEAD_DIM], torch.bfloat16, init_value=init_idx_kv_cache), + # Outputs are fixed to SCORE_LEN; positions past cache_len are -inf for score and -1 for topk_idxs. 
+ TensorSpec("score", [B, S, SCORE_LEN], torch.float32, is_output=True), + TensorSpec("topk_idxs", [B, S, SCORE_LEN], torch.int32, is_output=True), + ScalarSpec("start_pos", torch.int32, START_POS), + ScalarSpec("offset", torch.int32, OFFSET), + ] + + +if __name__ == "__main__": + import argparse + from golden import RunConfig, run_jit, topk_pair_compare + + parser = argparse.ArgumentParser() + parser.add_argument("-p", "--platform", type=str, default="a2a3", + choices=["a2a3", "a2a3sim", "a5", "a5sim"]) + parser.add_argument("-d", "--device", type=int, default=0) + parser.add_argument("--runtime-profiling", action="store_true", default=False) + args = parser.parse_args() + + result = run_jit( + fn=indexer_test, + specs=build_tensor_specs(), + golden_fn=golden_indexer, + config=RunConfig( + rtol=1e-3, + atol=1e-3, + compile=dict(dump_passes=True), + compare_fn={ + "topk_idxs": topk_pair_compare("score"), + }, + runtime=dict( + platform=args.platform, + device_id=args.device, + runtime_profiling=args.runtime_profiling, + ), + ), + ) + if not result.passed: + if result.error: + print(result.error) + raise SystemExit(1) diff --git a/models/deepseek/v4/indexer_draft.py b/models/deepseek/v4/indexer_draft.py deleted file mode 100644 index 65d6189..0000000 --- a/models/deepseek/v4/indexer_draft.py +++ /dev/null @@ -1,276 +0,0 @@ -# Copyright (c) PyPTO Contributors. -# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. -# ----------------------------------------------------------------------------------------------------------- -"""DeepSeek-V4 Indexer (decode). Mirrors model.py Indexer (line 380-433); -golden is a port of forward's decode branch (prefill `start_pos == 0` path is omitted). 
-The inner Compressor is invoked via golden_compressor (placeholder).""" - - -import pypto.language as pl - -from compressor_draft import golden_compressor - - -B = 16 # demo 4 -S = 1 -T = B * S -EPS = 1e-6 - -D = 4096 # flash:4096 pro:7168 -Q_LORA = 1024 # flash:1024 pro:1536 -ROPE_HEAD_DIM = 64 - -IDX_N_HEADS = 64 -IDX_HEAD_DIM = 128 -IDX_NOPE_HEAD_DIM = IDX_HEAD_DIM - ROPE_HEAD_DIM -IDX_TOPK = 512 # flash:512 pro:1024 -IDX_SOFTMAX_SCALE = IDX_HEAD_DIM ** -0.5 - -COMPRESS_RATIO = 4 -ROTATE = True # inner compressor always uses rotate=True (model.py:398) -OVERLAP = COMPRESS_RATIO == 4 -COFF = 1 + int(OVERLAP) -INNER_OUT_DIM = COFF * IDX_HEAD_DIM -STATE_LEN = COFF * COMPRESS_RATIO - -MAX_SEQ_LEN = 4096 -IDX_KV_LEN = MAX_SEQ_LEN // COMPRESS_RATIO - -START_POS = 3 # default for ScalarSpec; >0 (decode) and (START_POS+1)%COMPRESS_RATIO==0 to cover the full inner-compressor path -SHOULD_COMPRESS = COMPRESS_RATIO != 0 and ((START_POS + 1) % COMPRESS_RATIO) == 0 -OFFSET = 128 # default for ScalarSpec; = win in attention orch; added to topk_idxs (model.py:432) - - -@pl.jit.inline -def indexer( - x: pl.Tensor[[B, S, D], pl.BF16], - qr: pl.Tensor[[T, Q_LORA], pl.BF16], - wq_b: pl.Tensor[[Q_LORA, IDX_N_HEADS * IDX_HEAD_DIM], pl.BF16], - weights_proj: pl.Tensor[[D, IDX_N_HEADS], pl.BF16], - cos: pl.Tensor[[1, ROPE_HEAD_DIM], pl.BF16], # caller passes freqs_cis[start_pos] - sin: pl.Tensor[[1, ROPE_HEAD_DIM], pl.BF16], - hadamard: pl.Tensor[[IDX_HEAD_DIM, IDX_HEAD_DIM], pl.BF16], # shared by q rotation and inner Compressor - inner_wkv: pl.Tensor[[INNER_OUT_DIM, D], pl.BF16], - inner_wgate: pl.Tensor[[INNER_OUT_DIM, D], pl.BF16], - inner_ape: pl.Tensor[[COMPRESS_RATIO, INNER_OUT_DIM], pl.FP32], - inner_norm_w: pl.Tensor[[IDX_HEAD_DIM], pl.BF16], - inner_cos: pl.Tensor[[1, ROPE_HEAD_DIM], pl.BF16], # caller passes freqs_cis[start_pos+1-ratio] - inner_sin: pl.Tensor[[1, ROPE_HEAD_DIM], pl.BF16], - inner_kv_state: pl.Tensor[[B, STATE_LEN, INNER_OUT_DIM], pl.FP32], - inner_score_state: pl.Tensor[[B, STATE_LEN, INNER_OUT_DIM], pl.FP32], - idx_kv_cache: pl.Tensor[[B, IDX_KV_LEN, IDX_HEAD_DIM], pl.BF16], - start_pos: pl.Scalar[pl.INT32], # decode step; varies per call - offset: pl.Scalar[pl.INT32], # added to topk_idxs (= win from attention orch) - topk_idxs: pl.Tensor[[T, IDX_TOPK], pl.INT32], -): - # TODO: kernel implementation - return topk_idxs - - -@pl.jit -def indexer_test( - x: pl.Tensor[[B, S, D], pl.BF16], - qr: pl.Tensor[[T, Q_LORA], pl.BF16], - wq_b: pl.Tensor[[Q_LORA, IDX_N_HEADS * IDX_HEAD_DIM], pl.BF16], - weights_proj: pl.Tensor[[D, IDX_N_HEADS], pl.BF16], - cos: pl.Tensor[[1, ROPE_HEAD_DIM], pl.BF16], - sin: pl.Tensor[[1, ROPE_HEAD_DIM], pl.BF16], - hadamard: pl.Tensor[[IDX_HEAD_DIM, IDX_HEAD_DIM], pl.BF16], - inner_wkv: pl.Tensor[[INNER_OUT_DIM, D], pl.BF16], - inner_wgate: pl.Tensor[[INNER_OUT_DIM, D], pl.BF16], - inner_ape: pl.Tensor[[COMPRESS_RATIO, INNER_OUT_DIM], pl.FP32], - inner_norm_w: pl.Tensor[[IDX_HEAD_DIM], pl.BF16], - inner_cos: pl.Tensor[[1, ROPE_HEAD_DIM], pl.BF16], - inner_sin: pl.Tensor[[1, ROPE_HEAD_DIM], pl.BF16], - inner_kv_state: pl.InOut[pl.Tensor[[B, STATE_LEN, INNER_OUT_DIM], pl.FP32]], - inner_score_state: pl.InOut[pl.Tensor[[B, STATE_LEN, INNER_OUT_DIM], pl.FP32]], - idx_kv_cache: pl.InOut[pl.Tensor[[B, IDX_KV_LEN, IDX_HEAD_DIM], pl.BF16]], - start_pos: pl.Scalar[pl.INT32], - offset: pl.Scalar[pl.INT32], - topk_idxs: pl.Out[pl.Tensor[[T, IDX_TOPK], pl.INT32]], -): - topk_idxs = indexer( - x, qr, wq_b, weights_proj, cos, sin, hadamard, - inner_wkv, inner_wgate, 
inner_ape, inner_norm_w, - inner_cos, inner_sin, - inner_kv_state, inner_score_state, idx_kv_cache, - start_pos, offset, - topk_idxs, - ) - return topk_idxs - - -def golden_indexer(tensors): - """Torch reference for Indexer.forward (decode branch; prefill omitted; W8A8C16 quant ops are identity in golden).""" - import torch - - x = tensors["x"] - qr = tensors["qr"].float() - wq_b = tensors["wq_b"].float() - weights_proj = tensors["weights_proj"].float() - cos = tensors["cos"].float() - sin = tensors["sin"].float() - hadamard = tensors["hadamard"].float() - idx_kv_cache = tensors["idx_kv_cache"] - - start_pos = int(tensors["start_pos"]) - compress_ratio = COMPRESS_RATIO - offset = int(tensors["offset"]) - - bsz, seqlen, _ = x.shape - ratio, rd = compress_ratio, ROPE_HEAD_DIM - end_pos = start_pos + seqlen - - if start_pos == 0: - return - - # W8A8C16: wq_b W8 per-channel int8; qr A8 per-token int8. - q = (qr @ wq_b).view(T, IDX_N_HEADS, IDX_HEAD_DIM) - - x_pair = q[..., -rd:].unflatten(-1, (-1, 2)) - x0, x1 = x_pair[..., 0], x_pair[..., 1] - cos_v, sin_v = cos.view(-1), sin.view(-1) - y0 = x0 * cos_v - x1 * sin_v - y1 = x0 * sin_v + x1 * cos_v - q = torch.cat([q[..., :-rd], torch.stack([y0, y1], dim=-1).flatten(-2)], dim=-1) - - q = (q.view(-1, IDX_HEAD_DIM) @ hadamard).view(T, IDX_N_HEADS, IDX_HEAD_DIM) - # W8A8C16: A8 per-token-head int8 quant of q here (consumed by LI batch_matmul below). - # flash: fp4_act_quant on q (FP4 simulation). - - inner_out = torch.zeros(bsz, IDX_HEAD_DIM, dtype=torch.bfloat16) - inner_tensors = { - "x": x, - "kv_state": tensors["inner_kv_state"], - "score_state": tensors["inner_score_state"], - "wkv": tensors["inner_wkv"], - "wgate": tensors["inner_wgate"], - "ape": tensors["inner_ape"], - "norm_w": tensors["inner_norm_w"], - "cos": tensors["inner_cos"], - "sin": tensors["inner_sin"], - "hadamard": tensors["hadamard"], - "start_pos": tensors["start_pos"], - "out": inner_out, - } - # Placeholder call — compressor's golden currently uses module-level constants - # (HEAD_DIM=512, ROTATE=False), so this won't run end-to-end without refactor. - golden_compressor(inner_tensors) - should_compress = compress_ratio != 0 and ((start_pos + 1) % compress_ratio) == 0 - if should_compress: - idx_kv_cache[:bsz, start_pos // ratio] = inner_out - - weights = (x.float().view(bsz, -1) @ weights_proj) * (IDX_SOFTMAX_SCALE * IDX_N_HEADS ** -0.5) - weights = weights.view(T, IDX_N_HEADS) - - cache_len = end_pos // ratio - kv_view = idx_kv_cache[:bsz, :cache_len].float() - # W8A8C16: LI batch_matmul Int8. q A8 per-token-head int8; kv_view (Indexer Cache) C8 per-token-head int8. - # flash: q/kv via FP4 simulation (full Hadamard rotation + fp4_act_quant). 
- score = torch.einsum("thd,btd->bht", q, kv_view) - score = (torch.relu(score) * weights.view(bsz, IDX_N_HEADS, 1)).sum(dim=1) - - k = min(IDX_TOPK, cache_len) - _, idx = score.topk(k, dim=-1) - topk_idxs = torch.full((bsz, IDX_TOPK), -1, dtype=torch.int32) - topk_idxs[:, :k] = idx.to(torch.int32) - topk_idxs[:, :k] += offset - - tensors["topk_idxs"][:] = topk_idxs - - -def build_tensor_specs(): - import torch # type: ignore[import] - from golden import ScalarSpec, TensorSpec - - def init_x(): - return torch.randn(B, S, D) * 0.1 - def init_qr(): - return torch.randn(T, Q_LORA) * 0.1 - def init_wq_b(): - return torch.randn(Q_LORA, IDX_N_HEADS * IDX_HEAD_DIM) / Q_LORA ** 0.5 - def init_weights_proj(): - return torch.randn(D, IDX_N_HEADS) / D ** 0.5 - def init_cos(): - return torch.cos(torch.arange(ROPE_HEAD_DIM).reshape(1, ROPE_HEAD_DIM) * 1e-3) - def init_sin(): - return torch.sin(torch.arange(ROPE_HEAD_DIM).reshape(1, ROPE_HEAD_DIM) * 1e-3) - def init_hadamard(): - return torch.eye(IDX_HEAD_DIM) - def init_inner_wkv(): - return torch.randn(INNER_OUT_DIM, D) / D ** 0.5 - def init_inner_wgate(): - return torch.randn(INNER_OUT_DIM, D) / D ** 0.5 - def init_inner_ape(): - return torch.randn(COMPRESS_RATIO, INNER_OUT_DIM) * 0.01 - def init_inner_norm_w(): - return torch.ones(IDX_HEAD_DIM) - def init_inner_cos(): - return torch.cos(torch.arange(ROPE_HEAD_DIM).reshape(1, ROPE_HEAD_DIM) * 1e-3) - def init_inner_sin(): - return torch.sin(torch.arange(ROPE_HEAD_DIM).reshape(1, ROPE_HEAD_DIM) * 1e-3) - def init_inner_kv_state(): - return torch.zeros(B, STATE_LEN, INNER_OUT_DIM) - def init_inner_score_state(): - return torch.full((B, STATE_LEN, INNER_OUT_DIM), float("-inf")) - def init_idx_kv_cache(): - return torch.zeros(B, IDX_KV_LEN, IDX_HEAD_DIM) - - return [ - TensorSpec("x", [B, S, D], torch.bfloat16, init_value=init_x), - TensorSpec("qr", [T, Q_LORA], torch.bfloat16, init_value=init_qr), - TensorSpec("wq_b", [Q_LORA, IDX_N_HEADS * IDX_HEAD_DIM], torch.bfloat16, init_value=init_wq_b), - TensorSpec("weights_proj", [D, IDX_N_HEADS], torch.bfloat16, init_value=init_weights_proj), - TensorSpec("cos", [1, ROPE_HEAD_DIM], torch.bfloat16, init_value=init_cos), - TensorSpec("sin", [1, ROPE_HEAD_DIM], torch.bfloat16, init_value=init_sin), - TensorSpec("hadamard", [IDX_HEAD_DIM, IDX_HEAD_DIM], torch.bfloat16, init_value=init_hadamard), - TensorSpec("inner_wkv", [INNER_OUT_DIM, D], torch.bfloat16, init_value=init_inner_wkv), - TensorSpec("inner_wgate", [INNER_OUT_DIM, D], torch.bfloat16, init_value=init_inner_wgate), - TensorSpec("inner_ape", [COMPRESS_RATIO, INNER_OUT_DIM], torch.float32, init_value=init_inner_ape), - TensorSpec("inner_norm_w", [IDX_HEAD_DIM], torch.bfloat16, init_value=init_inner_norm_w), - TensorSpec("inner_cos", [1, ROPE_HEAD_DIM], torch.bfloat16, init_value=init_inner_cos), - TensorSpec("inner_sin", [1, ROPE_HEAD_DIM], torch.bfloat16, init_value=init_inner_sin), - TensorSpec("inner_kv_state", [B, STATE_LEN, INNER_OUT_DIM], torch.float32, init_value=init_inner_kv_state), - TensorSpec("inner_score_state", [B, STATE_LEN, INNER_OUT_DIM], torch.float32, init_value=init_inner_score_state), - TensorSpec("idx_kv_cache", [B, IDX_KV_LEN, IDX_HEAD_DIM], torch.bfloat16, init_value=init_idx_kv_cache), - ScalarSpec("start_pos", torch.int32, START_POS), - ScalarSpec("offset", torch.int32, OFFSET), - TensorSpec("topk_idxs", [T, IDX_TOPK], torch.int32, is_output=True), - ] - - -if __name__ == "__main__": - import argparse - from golden import RunConfig, run_jit - - parser = argparse.ArgumentParser() 
- parser.add_argument("-p", "--platform", type=str, default="a2a3", - choices=["a2a3", "a2a3sim", "a5", "a5sim"]) - parser.add_argument("-d", "--device", type=int, default=0) - parser.add_argument("--runtime-profiling", action="store_true", default=False) - args = parser.parse_args() - - result = run_jit( - fn=indexer_test, - specs=build_tensor_specs(), - golden_fn=golden_indexer, - config=RunConfig( - rtol=1e-3, - atol=1e-3, - compile=dict(dump_passes=True), - runtime=dict( - platform=args.platform, - device_id=args.device, - runtime_profiling=args.runtime_profiling, - ), - ), - ) - if not result.passed: - if result.error: - print(result.error) - raise SystemExit(1)
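# Why the new indexer main compares topk_idxs with compare_fn={"topk_idxs": topk_pair_compare("score")}:
# top-k indices are not unique when scores tie or differ within tolerance, so the check should
# validate the scores selected by the kernel's indices rather than require identical index order.
# A hypothetical stand-in with that behaviour (topk_pair_compare's real signature lives in the
# golden helper module and is not shown in this diff); IDX_TOPK = 16 and OFFSET = 128 are assumed
# from this file's defaults:
import torch

def topk_scores_match(idx_out, idx_ref, score_ref, k=16, offset=128, rtol=1e-3, atol=1e-3):
    # Gather the golden scores at both index sets (undoing the +offset shift), then compare
    # them order-insensitively.
    a = torch.take_along_dim(score_ref, (idx_out[..., :k] - offset).long(), dim=-1)
    b = torch.take_along_dim(score_ref, (idx_ref[..., :k] - offset).long(), dim=-1)
    return torch.allclose(a.sort(dim=-1).values, b.sort(dim=-1).values, rtol=rtol, atol=atol)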