diff --git a/models/deepseek/v4/compressor_ratio128.py b/models/deepseek/v4/compressor_ratio128.py index 3e5daa1..9c2057e 100644 --- a/models/deepseek/v4/compressor_ratio128.py +++ b/models/deepseek/v4/compressor_ratio128.py @@ -8,8 +8,8 @@ # ----------------------------------------------------------------------------------------------------------- """DeepSeek-V4 KV Compressor (decode incremental, ratio=128 non-overlap). -Uses non-overlapping state layout with 128 slots. Loop-based softmax+pool -over all slots. No state shift needed.""" +Uses non-overlapping state layout with 128 slots. +Online softmax+pool over all slots. No state shift needed.""" import pypto.language as pl @@ -27,289 +27,378 @@ # kernel-local (ratio-128 non-overlap compressor) COMPRESS_RATIO = 128 -ROTATE = False -OVERLAP = False +MAX_SEQ_LEN = 4096 +IDX_KV_LEN = MAX_SEQ_LEN // COMPRESS_RATIO +FP32_NEG_INF = -3.4028234663852886e38 + COFF = 1 OUT_DIM = COFF * HEAD_DIM # 512 STATE_LEN = COFF * COMPRESS_RATIO # 128 START_POS = 127 # ScalarSpec default; (START_POS+1)%COMPRESS_RATIO==0 -SHOULD_COMPRESS = COMPRESS_RATIO != 0 and ((START_POS + 1) % COMPRESS_RATIO) == 0 -APE_ROW = START_POS % COMPRESS_RATIO # 127 -SCATTER_SLOT = APE_ROW # 127 (no overlap) # tiling +ROPE_CHUNK = 32 K_CHUNK = 512 -OUT_CHUNK = 64 +OUT_CHUNK = 128 + HEAD_CHUNK = 128 K_BLOCKS = D // K_CHUNK # 8 -OUT_BLOCKS = OUT_DIM // OUT_CHUNK # 8 +OUT_BLOCKS = OUT_DIM // OUT_CHUNK # 4 HEAD_BLOCKS = HEAD_DIM // HEAD_CHUNK # 4 @pl.jit.inline def compressor( x: pl.Tensor[[B, S, D], pl.BF16], + kv: pl.Tensor[[B, S, HEAD_DIM], pl.FP32], kv_state: pl.Tensor[[B, STATE_LEN, OUT_DIM], pl.FP32], score_state: pl.Tensor[[B, STATE_LEN, OUT_DIM], pl.FP32], - wkv: pl.Tensor[[OUT_DIM, D], pl.BF16], - wgate: pl.Tensor[[OUT_DIM, D], pl.BF16], + wkv: pl.Tensor[[D, OUT_DIM], pl.BF16], + wgate: pl.Tensor[[D, OUT_DIM], pl.BF16], ape: pl.Tensor[[COMPRESS_RATIO, OUT_DIM], pl.FP32], - norm_w: pl.Tensor[[HEAD_DIM], pl.BF16], - cos: pl.Tensor[[1, ROPE_HEAD_DIM // 2], pl.BF16], - sin: pl.Tensor[[1, ROPE_HEAD_DIM // 2], pl.BF16], + norm_w: pl.Tensor[[HEAD_DIM], pl.FP32], + cos: pl.Tensor[[1, ROPE_HEAD_DIM // 2], pl.FP32], + sin: pl.Tensor[[1, ROPE_HEAD_DIM // 2], pl.FP32], + even_select: pl.Tensor[[ROPE_HEAD_DIM, ROPE_HEAD_DIM // 2], pl.BF16], + odd_select: pl.Tensor[[ROPE_HEAD_DIM, ROPE_HEAD_DIM // 2], pl.BF16], hadamard: pl.Tensor[[HEAD_DIM, HEAD_DIM], pl.BF16], + kv_cache: pl.Tensor[[B, IDX_KV_LEN, HEAD_DIM], pl.BF16], start_pos: pl.Scalar[pl.INT32], - out: pl.Tensor[[B, HEAD_DIM], pl.BF16], + rotate: pl.Scalar[pl.BOOL], ): - x_flat = pl.reshape(x, [B, D]) + x_flat = pl.reshape(x, [B * S, D]) + kv_proj = pl.create_tensor([B * S, OUT_DIM], dtype=pl.FP32) + score_proj = pl.create_tensor([B * S, OUT_DIM], dtype=pl.FP32) + ape_row = pl.cast(start_pos % COMPRESS_RATIO, target_type=pl.INDEX) + compress_rem = (start_pos + 1) % COMPRESS_RATIO + score_ape = pl.create_tensor([B * S, OUT_DIM], dtype=pl.FP32) + # Non-overlap: scatter into slot = ape_row + state_col0 = ape_row * OUT_DIM kv_state_flat = pl.reshape(kv_state, [B, STATE_LEN * OUT_DIM]) score_state_flat = pl.reshape(score_state, [B, STATE_LEN * OUT_DIM]) + for o0 in pl.parallel(0, OUT_DIM, OUT_CHUNK): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_score_proj"): + x_tile = x_flat[:, 0 : K_CHUNK] + wkv_tile = wkv[0 : K_CHUNK, o0 : o0 + OUT_CHUNK] + wgate_tile = wgate[0 : K_CHUNK, o0 : o0 + OUT_CHUNK] + kv_acc = pl.matmul(x_tile, wkv_tile, out_dtype=pl.FP32) + score_acc = pl.matmul(x_tile, wgate_tile, out_dtype=pl.FP32) + + for k0 in 
pl.range(K_CHUNK, D, K_CHUNK): + x_tile = x_flat[:, k0 : k0 + K_CHUNK] + wkv_tile = wkv[k0 : k0 + K_CHUNK, o0 : o0 + OUT_CHUNK] + wgate_tile = wgate[k0 : k0 + K_CHUNK, o0 : o0 + OUT_CHUNK] + kv_acc = pl.matmul_acc(kv_acc, x_tile, wkv_tile) + score_acc = pl.matmul_acc(score_acc, x_tile, wgate_tile) + + kv_proj = pl.assemble(kv_proj, kv_acc, [0, o0]) + score_proj = pl.assemble(score_proj, score_acc, [0, o0]) + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="score_ape"): + score_tile = score_proj[:, o0 : o0 + OUT_CHUNK] + ape_tile = ape[ape_row : ape_row + 1, o0 : o0 + OUT_CHUNK] + ape_base = pl.full([B * S, OUT_CHUNK], dtype=pl.FP32, value=0.0) + score_tile = pl.add(score_tile, pl.col_expand(ape_base, ape_tile)) + score_ape = pl.assemble(score_ape, score_tile, [0, o0]) - kv_fp32 = pl.create_tensor([B, OUT_DIM], dtype=pl.FP32) - score_fp32 = pl.create_tensor([B, OUT_DIM], dtype=pl.FP32) - slot_off = SCATTER_SLOT * OUT_DIM - - for ob in pl.parallel(0, OUT_BLOCKS, 1): - oc0 = ob * OUT_CHUNK - # Block 1a (Cube): kv = x @ wkv.T - with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_proj"): - a0 = pl.slice(x_flat, [B, K_CHUNK], [0, 0]) - b0 = pl.slice(wkv, [OUT_CHUNK, K_CHUNK], [oc0, 0]) - kv_acc = pl.matmul(a0, b0, out_dtype=pl.FP32, b_trans=True) - for kb in pl.range(1, K_BLOCKS): - a_i = pl.slice(x_flat, [B, K_CHUNK], [0, kb * K_CHUNK]) - b_i = pl.slice(wkv, [OUT_CHUNK, K_CHUNK], [oc0, kb * K_CHUNK]) - kv_acc = pl.matmul_acc(kv_acc, a_i, b_i, b_trans=True) - kv_fp32 = pl.assemble(kv_fp32, kv_acc, [0, oc0]) - - # Block 1b (Cube): score = x @ wgate.T - with pl.at(level=pl.Level.CORE_GROUP, name_hint="score_proj"): - a0g = pl.slice(x_flat, [B, K_CHUNK], [0, 0]) - b0g = pl.slice(wgate, [OUT_CHUNK, K_CHUNK], [oc0, 0]) - sc_acc = pl.matmul(a0g, b0g, out_dtype=pl.FP32, b_trans=True) - for kb in pl.range(1, K_BLOCKS): - a_ig = pl.slice(x_flat, [B, K_CHUNK], [0, kb * K_CHUNK]) - b_ig = pl.slice(wgate, [OUT_CHUNK, K_CHUNK], [oc0, kb * K_CHUNK]) - sc_acc = pl.matmul_acc(sc_acc, a_ig, b_ig, b_trans=True) - score_fp32 = pl.assemble(score_fp32, sc_acc, [0, oc0]) - - # Block 2 (Vector): score += ape[APE_ROW] - with pl.at(level=pl.Level.CORE_GROUP, name_hint="ape_add"): - sc = pl.slice(score_fp32, [B, OUT_CHUNK], [0, oc0]) - ape_row = pl.slice(ape, [1, OUT_CHUNK], [APE_ROW, oc0]) - ones_b = pl.full([B, OUT_CHUNK], dtype=pl.FP32, value=1.0) - ape_broadcast = pl.col_expand_mul(ones_b, ape_row) - sc = pl.add(sc, ape_broadcast) - score_fp32 = pl.assemble(score_fp32, sc, [0, oc0]) - - # Block 3 (Vector): scatter current kv/score into state with pl.at(level=pl.Level.CORE_GROUP, name_hint="state_scatter"): - kv_chunk = pl.slice(kv_fp32, [B, OUT_CHUNK], [0, oc0]) - kv_state_flat = pl.assemble(kv_state_flat, kv_chunk, [0, slot_off + oc0]) - sc_chunk = pl.slice(score_fp32, [B, OUT_CHUNK], [0, oc0]) - score_state_flat = pl.assemble(score_state_flat, sc_chunk, [0, slot_off + oc0]) - - # Reshape state to per-state-row 2D views - kv_state_per_row = pl.reshape(kv_state_flat, [B * STATE_LEN, OUT_DIM]) - score_state_per_row = pl.reshape(score_state_flat, [B * STATE_LEN, OUT_DIM]) - - # Block 5+6 (Vector): softmax+pool over 128 slots via loop. 
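+            # Each decode step owns exactly one of the 128 slots
+            # (slot = start_pos % COMPRESS_RATIO), so writing the new kv/score
+            # rows is a plain scatter; the other 127 slots stay untouched and,
+            # unlike the overlapping layout, no shift pass is needed afterwards.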
- pooled = pl.create_tensor([B, HEAD_DIM], dtype=pl.FP32) - for b_idx in pl.parallel(0, B, 1): - with pl.at(level=pl.Level.CORE_GROUP, name_hint="softmax_pool_loop"): - row_b = b_idx * STATE_LEN - for hb in pl.range(HEAD_BLOCKS): + kv_tile = kv_proj[:, o0 : o0 + OUT_CHUNK] + score_tile = score_ape[:, o0 : o0 + OUT_CHUNK] + kv_state_flat = pl.assemble(kv_state_flat, kv_tile, [0, state_col0 + o0]) + score_state_flat = pl.assemble(score_state_flat, score_tile, [0, state_col0 + o0]) + + pooled_kv = pl.create_tensor([B * S, HEAD_DIM], dtype=pl.FP32) + normed_kv = pl.create_tensor([B * S, HEAD_DIM], dtype=pl.BF16) + kv_flat = pl.reshape(kv, [B * S, HEAD_DIM]) + + if compress_rem == 0: + for hb in pl.parallel(0, HEAD_BLOCKS, 1): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="softmax_pool"): h0 = hb * HEAD_CHUNK - # Max via loop over all 128 slots. - s_max = pl.slice(score_state_per_row, [1, HEAD_CHUNK], [row_b, h0]) - for s in pl.range(1, STATE_LEN): - sr_max = pl.slice(score_state_per_row, [1, HEAD_CHUNK], [row_b + s, h0]) - s_max = pl.maximum(s_max, sr_max) - - # Exp sum and weighted sum via loop. - e_sum = pl.full([1, HEAD_CHUNK], dtype=pl.FP32, value=0.0) - weighted = pl.full([1, HEAD_CHUNK], dtype=pl.FP32, value=0.0) - for s in pl.range(STATE_LEN): - sr_exp = pl.slice(score_state_per_row, [1, HEAD_CHUNK], [row_b + s, h0]) - kv_row = pl.slice(kv_state_per_row, [1, HEAD_CHUNK], [row_b + s, h0]) - e_row = pl.exp(pl.sub(sr_exp, s_max)) - e_sum = pl.add(e_sum, e_row) - weighted = pl.add(weighted, pl.mul(e_row, kv_row)) - - pooled_chunk = pl.div(weighted, e_sum) - pooled = pl.assemble(pooled, pooled_chunk, [b_idx, h0]) - - # No block 7 (no shift in non-overlap mode). - - # Reshape state back to 3D - kv_state = pl.reshape(kv_state_per_row, [B, STATE_LEN, OUT_DIM]) - score_state = pl.reshape(score_state_per_row, [B, STATE_LEN, OUT_DIM]) - - # Block 8 (Vector): RMSNorm pooled with norm_w over HEAD_DIM. - normed_pooled = pl.create_tensor([B, HEAD_DIM], dtype=pl.FP32) - norm_w_2d = pl.reshape(norm_w, [1, HEAD_DIM]) - with pl.at(level=pl.Level.CORE_GROUP, name_hint="rmsnorm"): - partial_sq = pl.full([1, B], dtype=pl.FP32, value=0.0) - for hb in pl.range(HEAD_BLOCKS): - h0 = hb * HEAD_CHUNK - pc = pl.slice(pooled, [B, HEAD_CHUNK], [0, h0]) - partial_sq = pl.add( - partial_sq, - pl.reshape(pl.row_sum(pl.mul(pc, pc)), [1, B]), - ) - inv_rms = pl.reshape( - pl.recip(pl.sqrt(pl.add(pl.mul(partial_sq, HEAD_DIM_INV), EPS))), - [B, 1], - ) - for hb in pl.range(HEAD_BLOCKS): - h0 = hb * HEAD_CHUNK - nc = pl.slice(pooled, [B, HEAD_CHUNK], [0, h0]) - nw_chunk = pl.cast( - pl.slice(norm_w_2d, [1, HEAD_CHUNK], [0, h0]), - target_type=pl.FP32, - ) - normed = pl.col_expand_mul(pl.row_expand_mul(nc, inv_rms), nw_chunk) - normed_pooled = pl.assemble(normed_pooled, normed, [0, h0]) - - # Block 11a (Vector): cast non-rope range to BF16 and store to out. - with pl.at(level=pl.Level.CORE_GROUP, name_hint="store_nope"): - nope_chunk = pl.slice(normed_pooled, [B, NOPE_HEAD_DIM], [0, 0]) - out = pl.assemble(out, pl.cast(nope_chunk, target_type=pl.BF16), [0, 0]) - - # Block 9 + 11b (Vector): half-vector RoPE on the last ROPE_HEAD_DIM cols, then store. 
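+                # Flash-attention-style online softmax over the 128 slots,
+                # carrying a running max m, normalizer l, and weighted sum o:
+                #   m' = max(m, s);  l' = exp(m - m') * l + exp(s - m')
+                #   o' = exp(m - m') * o + exp(s - m') * kv
+                # Seeding from the just-written last slot gives l = exp(m - m) = 1
+                # and o = kv; still-empty slots hold the finite FP32_NEG_INF
+                # sentinel, so their exp(s - m') underflows to zero weight.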
- HALF_RD = ROPE_HEAD_DIM // 2 - with pl.at(level=pl.Level.CORE_GROUP, name_hint="rope_store"): - x_lo = pl.slice(normed_pooled, [B, HALF_RD], [0, NOPE_HEAD_DIM]) - x_hi = pl.slice(normed_pooled, [B, HALF_RD], [0, NOPE_HEAD_DIM + HALF_RD]) - cos_fp32 = pl.cast(cos, target_type=pl.FP32) - sin_fp32 = pl.cast(sin, target_type=pl.FP32) - y_lo = pl.sub(pl.col_expand_mul(x_lo, cos_fp32), pl.col_expand_mul(x_hi, sin_fp32)) - y_hi = pl.add(pl.col_expand_mul(x_lo, sin_fp32), pl.col_expand_mul(x_hi, cos_fp32)) - out = pl.assemble(out, pl.cast(y_lo, target_type=pl.BF16), [0, NOPE_HEAD_DIM]) - out = pl.assemble(out, pl.cast(y_hi, target_type=pl.BF16), [0, NOPE_HEAD_DIM + HALF_RD]) - - return out + # Initialize m/l/o from last slot + last_col0 = (STATE_LEN - 1) * OUT_DIM + h0 + mi = score_state_flat[:, last_col0 : last_col0 + HEAD_CHUNK] + li = pl.exp(pl.sub(mi, mi)) + oi = kv_state_flat[:, last_col0 : last_col0 + HEAD_CHUNK] + + # Online softmax over all remaining slots + for s in pl.range(0, STATE_LEN - 1): + col0 = s * OUT_DIM + h0 + slot_score = score_state_flat[:, col0 : col0 + HEAD_CHUNK] + slot_kv = kv_state_flat[:, col0 : col0 + HEAD_CHUNK] + mi_next = pl.maximum(mi, slot_score) + alpha = pl.exp(pl.sub(mi, mi_next)) + beta = pl.exp(pl.sub(slot_score, mi_next)) + li = pl.add(pl.mul(alpha, li), beta) + oi = pl.add(pl.mul(oi, alpha), pl.mul(slot_kv, beta)) + mi = mi_next + + pooled_chunk = pl.div(oi, li) + pooled_kv = pl.assemble(pooled_kv, pooled_chunk, [0, h0]) + + # No state shift for non-overlap + + # RMSNorm with BF16 intermediate + norm_w_2d = pl.reshape(norm_w, [1, HEAD_DIM]) + with pl.at(level=pl.Level.CORE_GROUP, name_hint="rmsnorm"): + partial_sq = pl.full([1, B * S], dtype=pl.FP32, value=0.0) + for k0 in pl.range(0, HEAD_DIM, HEAD_CHUNK): + kv_rms_chunk = pl.cast( + pl.cast(pooled_kv[:, k0 : k0 + HEAD_CHUNK], target_type=pl.BF16), + target_type=pl.FP32, + ) + partial_sq = pl.add( + partial_sq, + pl.reshape(pl.row_sum(pl.mul(kv_rms_chunk, kv_rms_chunk)), [1, B * S]), + ) + + variance = pl.reshape(pl.add(pl.mul(partial_sq, HEAD_DIM_INV), EPS), [B * S, 1]) + inv_rms = pl.recip(pl.sqrt(variance)) + for k0 in pl.range(0, HEAD_DIM, HEAD_CHUNK): + kv_norm_chunk = pl.cast( + pl.cast(pooled_kv[:, k0 : k0 + HEAD_CHUNK], target_type=pl.BF16), + target_type=pl.FP32, + ) + gamma = norm_w_2d[:, k0 : k0 + HEAD_CHUNK] + normed_chunk = pl.col_expand_mul(pl.row_expand_mul(kv_norm_chunk, inv_rms), gamma) + normed_kv = pl.assemble(normed_kv, pl.cast(normed_chunk, target_type=pl.BF16), [0, k0]) + + # Selector-based RoPE + kv_rope = pl.create_tensor([B * S, ROPE_HEAD_DIM], dtype=pl.BF16) + kv_proj_even = pl.create_tensor([B * S, ROPE_HEAD_DIM // 2], dtype=pl.FP32) + kv_proj_odd = pl.create_tensor([B * S, ROPE_HEAD_DIM // 2], dtype=pl.FP32) + rope_even = pl.create_tensor([B * S, ROPE_HEAD_DIM // 2], dtype=pl.BF16) + rope_odd = pl.create_tensor([B * S, ROPE_HEAD_DIM // 2], dtype=pl.BF16) + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_rope_slice"): + kv_rope = pl.assemble(kv_rope, normed_kv[:, NOPE_HEAD_DIM : HEAD_DIM], [0, 0]) + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="rope_slice"): + kv_rope_tile = kv_rope[:, 0 : ROPE_CHUNK] + even_select_tile = even_select[0 : ROPE_CHUNK, :] + odd_select_tile = odd_select[0 : ROPE_CHUNK, :] + even_acc = pl.matmul(kv_rope_tile, even_select_tile, out_dtype=pl.FP32) + odd_acc = pl.matmul(kv_rope_tile, odd_select_tile, out_dtype=pl.FP32) + + for r0 in pl.range(ROPE_CHUNK, ROPE_HEAD_DIM, ROPE_CHUNK): + kv_rope_tile = kv_rope[:, r0 : r0 + ROPE_CHUNK] + 
even_select_tile = even_select[r0 : r0 + ROPE_CHUNK, :]
+            odd_select_tile = odd_select[r0 : r0 + ROPE_CHUNK, :]
+            even_acc = pl.matmul_acc(even_acc, kv_rope_tile, even_select_tile)
+            odd_acc = pl.matmul_acc(odd_acc, kv_rope_tile, odd_select_tile)
+        kv_proj_even = pl.assemble(kv_proj_even, even_acc, [0, 0])
+        kv_proj_odd = pl.assemble(kv_proj_odd, odd_acc, [0, 0])
+
+    with pl.at(level=pl.Level.CORE_GROUP, name_hint="rope_apply"):
+        even_tile = kv_proj_even[:, :]
+        odd_tile = kv_proj_odd[:, :]
+        rope_even_acc = pl.cast(pl.sub(pl.col_expand_mul(even_tile, cos), pl.col_expand_mul(odd_tile, sin)), target_type=pl.BF16)
+        rope_odd_acc = pl.cast(pl.add(pl.col_expand_mul(even_tile, sin), pl.col_expand_mul(odd_tile, cos)), target_type=pl.BF16)
+        rope_even = pl.assemble(rope_even, rope_even_acc, [0, 0])
+        rope_odd = pl.assemble(rope_odd, rope_odd_acc, [0, 0])
+
+    with pl.at(level=pl.Level.CORE_GROUP, name_hint="rope_assemble"):
+        rope_even_tile = rope_even[:, 0 : ROPE_CHUNK]
+        rope_odd_tile = rope_odd[:, 0 : ROPE_CHUNK]
+        even_select_tile_t = even_select[:, 0 : ROPE_CHUNK]
+        odd_select_tile_t = odd_select[:, 0 : ROPE_CHUNK]
+        rope_acc = pl.matmul(rope_even_tile, even_select_tile_t, out_dtype=pl.FP32, b_trans=True)
+        rope_acc = pl.matmul_acc(rope_acc, rope_odd_tile, odd_select_tile_t, b_trans=True)
+
+        for r0 in pl.range(ROPE_CHUNK, ROPE_HEAD_DIM // 2, ROPE_CHUNK):
+            rope_even_tile = rope_even[:, r0 : r0 + ROPE_CHUNK]
+            rope_odd_tile = rope_odd[:, r0 : r0 + ROPE_CHUNK]
+            even_select_tile_t = even_select[:, r0 : r0 + ROPE_CHUNK]
+            odd_select_tile_t = odd_select[:, r0 : r0 + ROPE_CHUNK]
+            rope_acc = pl.matmul_acc(rope_acc, rope_even_tile, even_select_tile_t, b_trans=True)
+            rope_acc = pl.matmul_acc(rope_acc, rope_odd_tile, odd_select_tile_t, b_trans=True)
+
+    with pl.at(level=pl.Level.CORE_GROUP, name_hint="rope_write"):
+        normed_kv = pl.assemble(normed_kv, pl.cast(rope_acc, target_type=pl.BF16), [0, NOPE_HEAD_DIM])
+
+    if rotate:
+        for o0 in pl.range(0, HEAD_DIM, OUT_CHUNK):
+            with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_hadamard"):
+                kv_proj_tile = normed_kv[:, 0 : HEAD_DIM]
+                hadamard_tile = hadamard[0 : HEAD_DIM, o0 : o0 + OUT_CHUNK]
+                kv_hadamard_acc = pl.matmul(kv_proj_tile, hadamard_tile, out_dtype=pl.FP32)
+                kv_flat = pl.assemble(kv_flat, kv_hadamard_acc, [0, o0])
+    else:
+        for o0 in pl.parallel(0, HEAD_DIM, OUT_CHUNK):
+            with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_write"):
+                kv_out_tile = normed_kv[:, o0 : o0 + OUT_CHUNK]
+                kv_flat = pl.assemble(kv_flat, pl.cast(kv_out_tile, target_type=pl.FP32), [0, o0])
+
+    kv_cache_flat = pl.reshape(kv_cache, [B * IDX_KV_LEN, HEAD_DIM])
+    cache_col = start_pos // COMPRESS_RATIO
+    for b_idx in pl.parallel(0, B, 1):
+        with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_cache_write"):
+            cache_row = b_idx * IDX_KV_LEN + cache_col
+            cache_kv_row = kv_flat[b_idx : b_idx + 1, 0 : HEAD_DIM]
+            kv_cache_flat = pl.assemble(kv_cache_flat, pl.cast(cache_kv_row, target_type=pl.BF16), [cache_row, 0])
+    kv_cache = pl.reshape(kv_cache_flat, [B, IDX_KV_LEN, HEAD_DIM])
+
+    kv_state = pl.reshape(kv_state_flat, [B, STATE_LEN, OUT_DIM])
+    score_state = pl.reshape(score_state_flat, [B, STATE_LEN, OUT_DIM])
+    kv = pl.reshape(kv_flat, [B, S, HEAD_DIM])
+    return kv, kv_state, score_state, kv_cache


 @pl.jit
 def compressor_test(
     x: pl.Tensor[[B, S, D], pl.BF16],
+    kv: pl.Out[pl.Tensor[[B, S, HEAD_DIM], pl.FP32]],
     kv_state: pl.Out[pl.Tensor[[B, STATE_LEN, OUT_DIM], pl.FP32]],
     score_state: pl.Out[pl.Tensor[[B, STATE_LEN, OUT_DIM], pl.FP32]],
-    wkv: pl.Tensor[[OUT_DIM, D], 
pl.BF16], - wgate: pl.Tensor[[OUT_DIM, D], pl.BF16], + wkv: pl.Tensor[[D, OUT_DIM], pl.BF16], + wgate: pl.Tensor[[D, OUT_DIM], pl.BF16], ape: pl.Tensor[[COMPRESS_RATIO, OUT_DIM], pl.FP32], - norm_w: pl.Tensor[[HEAD_DIM], pl.BF16], - cos: pl.Tensor[[1, ROPE_HEAD_DIM // 2], pl.BF16], - sin: pl.Tensor[[1, ROPE_HEAD_DIM // 2], pl.BF16], + norm_w: pl.Tensor[[HEAD_DIM], pl.FP32], + cos: pl.Tensor[[1, ROPE_HEAD_DIM // 2], pl.FP32], + sin: pl.Tensor[[1, ROPE_HEAD_DIM // 2], pl.FP32], + even_select: pl.Tensor[[ROPE_HEAD_DIM, ROPE_HEAD_DIM // 2], pl.BF16], + odd_select: pl.Tensor[[ROPE_HEAD_DIM, ROPE_HEAD_DIM // 2], pl.BF16], hadamard: pl.Tensor[[HEAD_DIM, HEAD_DIM], pl.BF16], + kv_cache: pl.Out[pl.Tensor[[B, IDX_KV_LEN, HEAD_DIM], pl.BF16]], start_pos: pl.Scalar[pl.INT32], - out: pl.Out[pl.Tensor[[B, HEAD_DIM], pl.BF16]], + rotate: pl.Scalar[pl.BOOL], ): - out = compressor( - x, kv_state, score_state, wkv, wgate, ape, norm_w, cos, sin, hadamard, start_pos, out, + kv, kv_state, score_state, kv_cache = compressor( + x, kv, kv_state, score_state, wkv, wgate, ape, norm_w, cos, sin, even_select, odd_select, hadamard, kv_cache, start_pos, rotate ) - return out + return kv, kv_state, score_state, kv_cache def golden_compressor(tensors): - """Torch reference for Compressor.forward (decode branch, ratio=128).""" + """Torch reference for Compressor.forward (decode branch, ratio=128 non-overlap).""" import torch - x = tensors["x"] + x = tensors["x"].float() kv_state = tensors["kv_state"] score_state = tensors["score_state"] wkv = tensors["wkv"].float() wgate = tensors["wgate"].float() - ape = tensors["ape"].float() - norm_w = tensors["norm_w"].float() - cos = tensors["cos"].float() - sin = tensors["sin"].float() + ape = tensors["ape"] + norm_w = tensors["norm_w"] + cos = tensors["cos"] + sin = tensors["sin"] hadamard = tensors["hadamard"].float() - + kv_cache = tensors["kv_cache"] + start_pos = int(tensors["start_pos"]) + rotate = bool(tensors["rotate"]) bsz, _, _ = x.shape - ratio, overlap, rotate, d, rd = COMPRESS_RATIO, OVERLAP, ROTATE, HEAD_DIM, ROPE_HEAD_DIM - dtype = x.dtype - x = x.float() - kv = x.view(bsz, -1) @ wkv.T - score = x.view(bsz, -1) @ wgate.T - - should_compress = (START_POS + 1) % ratio == 0 - score = score + ape[START_POS % ratio] - # Non-overlap path - kv_state[:bsz, START_POS % ratio] = kv - score_state[:bsz, START_POS % ratio] = score + ratio, rd = COMPRESS_RATIO, ROPE_HEAD_DIM + + kv = x @ wkv + score = x @ wgate + + should_compress = (start_pos + 1) % ratio == 0 + score = score + ape[start_pos % ratio] + + # Non-overlap: scatter into slot start_pos % ratio + kv_state[:bsz, start_pos % ratio] = kv.squeeze(1) + score_state[:bsz, start_pos % ratio] = score.squeeze(1) if should_compress: kv = (kv_state[:bsz] * score_state[:bsz].softmax(dim=1)).sum(dim=1, keepdim=True) + # No shift for non-overlap + + tensors["kv_state"][:] = kv_state + tensors["score_state"][:] = score_state if not should_compress: - tensors["out"][:] = torch.zeros(B, HEAD_DIM, dtype=torch.bfloat16) return - kv_c = kv.squeeze(1) - kv_c = kv_c * torch.rsqrt(kv_c.square().mean(-1, keepdim=True) + EPS) * norm_w + def rmsnorm(x, w): + x = x.float() + var = x.square().mean(-1, keepdim=True) + x = x * torch.rsqrt(var + EPS) + return (w * x).to(torch.bfloat16) + + kv = rmsnorm(kv.to(torch.bfloat16), norm_w) - half_rd = rd // 2 - x_lo = kv_c[..., -rd:-half_rd] - x_hi = kv_c[..., -half_rd:] + x_pair = kv[..., -rd:].unflatten(-1, (-1, 2)) + x0, x1 = x_pair[..., 0], x_pair[..., 1] cos_v, sin_v = cos.view(-1), sin.view(-1) - y_lo = 
x_lo * cos_v - x_hi * sin_v - y_hi = x_lo * sin_v + x_hi * cos_v - kv_c = torch.cat([kv_c[..., :-rd], y_lo, y_hi], dim=-1) + y0 = (x0 * cos_v - x1 * sin_v).to(torch.bfloat16) + y1 = (x0 * sin_v + x1 * cos_v).to(torch.bfloat16) + + kv = torch.cat([kv[..., :-rd], torch.stack([y0, y1], dim=-1).flatten(-2)], dim=-1).float() if rotate: - kv_c = (kv_c @ hadamard).to(torch.bfloat16).float() - else: - pass + kv = kv @ hadamard + tensors["kv"][:] = kv - tensors["out"][:] = kv_c.to(torch.bfloat16) + kv_cache[:bsz, start_pos // ratio] = kv.squeeze(1) + + tensors["kv_cache"][:] = kv_cache def build_tensor_specs(): import torch # type: ignore[import] from golden import ScalarSpec, TensorSpec - torch.manual_seed(42) - def init_x(): - return torch.randn(B, S, D) - 0.5 + return torch.randn(B, S, D) * 0.1 def init_kv_state(): return torch.zeros(B, STATE_LEN, OUT_DIM) def init_score_state(): - return torch.full((B, STATE_LEN, OUT_DIM), float("-inf")) + return torch.full((B, STATE_LEN, OUT_DIM), FP32_NEG_INF) def init_wkv(): - return (torch.randn(OUT_DIM, D) - 0.5) / (D ** 0.5) + return torch.randn(D, OUT_DIM) / D ** 0.5 def init_wgate(): - return (torch.randn(OUT_DIM, D) - 0.5) / (D ** 0.5) + return torch.randn(D, OUT_DIM) / D ** 0.5 def init_ape(): - return torch.randn(COMPRESS_RATIO, OUT_DIM) * 0.01 + return torch.randn(COMPRESS_RATIO, OUT_DIM) * 0.1 def init_norm_w(): return torch.ones(HEAD_DIM) def init_cos(): return torch.cos(torch.arange(ROPE_HEAD_DIM // 2).reshape(1, ROPE_HEAD_DIM // 2) * 1e-3) def init_sin(): return torch.sin(torch.arange(ROPE_HEAD_DIM // 2).reshape(1, ROPE_HEAD_DIM // 2) * 1e-3) + def init_odd_select(): + M = torch.zeros((ROPE_HEAD_DIM, ROPE_HEAD_DIM // 2)) + for i in range(ROPE_HEAD_DIM // 2): + M[2*i+1, i] = 1 + return M + def init_even_select(): + M = torch.zeros((ROPE_HEAD_DIM, ROPE_HEAD_DIM // 2)) + for i in range(ROPE_HEAD_DIM // 2): + M[2*i, i] = 1 + return M def init_hadamard(): - return torch.eye(HEAD_DIM) + H = torch.ones((1, 1)) + while H.shape[0] < HEAD_DIM: + H = torch.cat([ + torch.cat([H, H], dim=1), + torch.cat([H, -H], dim=1), + ], dim=0) + return H / (HEAD_DIM ** 0.5) + def init_kv_cache(): + return torch.zeros(B, IDX_KV_LEN, HEAD_DIM) + return [ TensorSpec("x", [B, S, D], torch.bfloat16, init_value=init_x), + TensorSpec("kv", [B, S, HEAD_DIM], torch.float32, is_output=True), TensorSpec("kv_state", [B, STATE_LEN, OUT_DIM], torch.float32, init_value=init_kv_state, is_output=True), TensorSpec("score_state", [B, STATE_LEN, OUT_DIM], torch.float32, init_value=init_score_state, is_output=True), - TensorSpec("wkv", [OUT_DIM, D], torch.bfloat16, init_value=init_wkv), - TensorSpec("wgate", [OUT_DIM, D], torch.bfloat16, init_value=init_wgate), + TensorSpec("wkv", [D, OUT_DIM], torch.bfloat16, init_value=init_wkv), + TensorSpec("wgate", [D, OUT_DIM], torch.bfloat16, init_value=init_wgate), TensorSpec("ape", [COMPRESS_RATIO, OUT_DIM], torch.float32, init_value=init_ape), - TensorSpec("norm_w", [HEAD_DIM], torch.bfloat16, init_value=init_norm_w), - TensorSpec("cos", [1, ROPE_HEAD_DIM // 2], torch.bfloat16, init_value=init_cos), - TensorSpec("sin", [1, ROPE_HEAD_DIM // 2], torch.bfloat16, init_value=init_sin), + TensorSpec("norm_w", [HEAD_DIM], torch.float32, init_value=init_norm_w), + TensorSpec("cos", [1, ROPE_HEAD_DIM // 2], torch.float32, init_value=init_cos), + TensorSpec("sin", [1, ROPE_HEAD_DIM // 2], torch.float32, init_value=init_sin), + TensorSpec("even_select", [ROPE_HEAD_DIM, ROPE_HEAD_DIM // 2], torch.bfloat16, init_value=init_even_select), + 
TensorSpec("odd_select", [ROPE_HEAD_DIM, ROPE_HEAD_DIM // 2], torch.bfloat16, init_value=init_odd_select), TensorSpec("hadamard", [HEAD_DIM, HEAD_DIM], torch.bfloat16, init_value=init_hadamard), + TensorSpec("kv_cache", [B, IDX_KV_LEN, HEAD_DIM], torch.bfloat16, init_value=init_kv_cache, is_output=True), ScalarSpec("start_pos", torch.int32, START_POS), - TensorSpec("out", [B, HEAD_DIM], torch.bfloat16, is_output=True), + ScalarSpec("rotate", torch.bool, True), ] if __name__ == "__main__": import argparse - from golden import RunConfig, run_jit + from golden import RunConfig, bf16_allclose_or_ulp, run_jit parser = argparse.ArgumentParser() parser.add_argument("-p", "--platform", type=str, default="a2a3", @@ -323,14 +412,15 @@ def init_hadamard(): specs=build_tensor_specs(), golden_fn=golden_compressor, config=RunConfig( - rtol=2e-3, - atol=2e-3, + rtol=1e-3, + atol=1e-3, compile=dict(dump_passes=True), runtime=dict( platform=args.platform, device_id=args.device, runtime_profiling=args.runtime_profiling, ), + compare_fn={"kv_cache": bf16_allclose_or_ulp()}, ), ) if not result.passed: