Update DeepSeek V4 indexer W8A8C16 quantization path #268
## Summary

- Quantize qr, wq_b, qr_hadamard, and cache tiles in the runtime indexer path so the projection and score matmuls run through the W8A8C16 flow with explicit dequant scales.
- Update the indexer test signature and tensor specs to accept INT8 wq_b plus per-output-channel wq_b_scale.
- Align golden_indexer with the kernel by adding per-row and per-channel INT8 quant helpers and simulating the dequantized W8A8C16 score path.

## Related Issues

N/A
📝 Walkthrough

This PR extends the DeepSeek-V4 decode indexer to use INT8 quantization for weights and activations in the score path. The indexer kernel signature is updated to accept INT8 quantized weights and a dequant scale tensor; QR and KV cache tiles are quantized per-row/per-tile before INT8 matmuls, and the golden reference implements matching quantization helpers. Test comparison functions are updated to use ULP-tolerant BF16 comparison for quantized outputs.
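As a concrete illustration of what such per-row and per-channel INT8 helpers compute, here is a minimal NumPy sketch of symmetric absmax quantization. The names, the 127 scale bound, and the epsilon clamp are assumptions for illustration, not the actual golden_indexer code.

```python
import numpy as np

INT8_SCALE_MAX = 127.0
AMAX_EPS = 1e-4  # hypothetical clamp so all-zero rows don't divide by zero

def quant_rowwise_int8(x: np.ndarray):
    """Symmetric per-row INT8 quantization (activations): one dequant scale per row."""
    amax = np.maximum(np.abs(x).max(axis=1, keepdims=True), AMAX_EPS)
    x_i8 = np.rint(x * (INT8_SCALE_MAX / amax)).astype(np.int8)
    return x_i8, amax / INT8_SCALE_MAX  # dequant scale, shape [rows, 1]

def quant_colwise_int8(w: np.ndarray):
    """Symmetric per-output-channel INT8 quantization (weights): one dequant scale per column."""
    amax = np.maximum(np.abs(w).max(axis=0, keepdims=True), AMAX_EPS)
    w_i8 = np.rint(w * (INT8_SCALE_MAX / amax)).astype(np.int8)
    return w_i8, amax / INT8_SCALE_MAX  # dequant scale, shape [1, cols]

# INT8 matmul with INT32 accumulation, then dequantize with both scales:
x_i8, sx = quant_rowwise_int8(np.random.randn(4, 16).astype(np.float32))
w_i8, sw = quant_colwise_int8(np.random.randn(16, 8).astype(np.float32))
y = (x_i8.astype(np.int32) @ w_i8.astype(np.int32)).astype(np.float32) * sx * sw
```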
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
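The ULP-tolerant BF16 comparison mentioned in the walkthrough can be sketched as follows; this is a generic illustration of the technique, not the test suite's actual helper.

```python
import numpy as np

def bf16_ulp_distance(a, b) -> np.ndarray:
    """Distance between a and b in BF16 units-in-the-last-place (NaN/Inf handling omitted)."""
    # BF16 is the top 16 bits of an FP32 pattern (plain truncation shown for brevity).
    ia = (np.atleast_1d(np.asarray(a, np.float32)).view(np.uint32) >> 16).astype(np.int32)
    ib = (np.atleast_1d(np.asarray(b, np.float32)).view(np.uint32) >> 16).astype(np.int32)
    # Fold the sign-magnitude bit patterns onto one monotonic integer line so that
    # adjacent representable BF16 values are exactly 1 apart.
    ia = np.where(ia & 0x8000, -(ia & 0x7FFF), ia)
    ib = np.where(ib & 0x8000, -(ib & 0x7FFF), ib)
    return np.abs(ia - ib)

# 1.0 and 1 + 2**-7 are adjacent BF16 values, i.e. 1 ULP apart.
assert bf16_ulp_distance(1.0, 1.0078125) == 1
```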
Code Review
This pull request implements W8A8C16 quantization for the indexer module, transitioning weight and activation tensors to INT8 with INT32 accumulation. Key changes include the addition of per-row and per-channel quantization logic for query projections, Hadamard-transformed queries, and KV cache tiles, followed by appropriate dequantization. Review feedback identifies a redundant allocation of the qr_scale_dq tensor, recommending the use of pl.assemble to populate the pre-allocated tensor rather than reassigning it to avoid unnecessary memory overhead.
```python
qr_proj = pl.create_tensor([T, IDX_N_HEADS * IDX_HEAD_DIM], dtype=pl.BF16)
qr_flat = pl.reshape(qr, [T, Q_LORA])
qr_i8 = pl.create_tensor([T, Q_LORA], dtype=pl.INT8)
qr_scale_dq = pl.create_tensor([T, 1], dtype=pl.FP32)
```
The tensor qr_scale_dq is pre-allocated here using pl.create_tensor, but it is later reassigned at line 103 using pl.reshape, so this allocation is redundant. If the intention was to use the pre-allocated tensor, use pl.assemble (or direct slice assignment, if the compiler supports it) to replace its contents in full. Otherwise, remove this line and let line 103 handle the allocation.
```python
qr_a_max = pl.reshape(pl.row_max(qr_a_abs), [1, T])
qr_amax = pl.maximum(qr_amax, qr_a_max)
qr_scale_quant_row = pl.div(pl.full([1, T], dtype=pl.FP32, value=INT8_SCALE_MAX), qr_amax)
qr_scale_dq = pl.reshape(pl.recip(qr_scale_quant_row), [T, 1])
```
Reassigning qr_scale_dq here with the result of pl.reshape makes the initial allocation at line 93 redundant. To maintain consistency with how other tensors like qr_hadamard_scale_dq are handled in this file, consider using pl.assemble to fill the pre-allocated tensor, or simply remove the redundant allocation at line 93.
```diff
- qr_scale_dq = pl.reshape(pl.recip(qr_scale_quant_row), [T, 1])
+ qr_scale_dq = pl.assemble(qr_scale_dq, pl.reshape(pl.recip(qr_scale_quant_row), [T, 1]), [0, 0])
```
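For readers unfamiliar with the DSL: from its usage in this diff, pl.assemble(dst, src, [r, c]) appears to write src into dst at row/column offset [r, c] and return the updated tensor. A rough NumPy analogy (an assumption about the semantics, not the DSL's documented behavior):

```python
import numpy as np

def assemble(dst: np.ndarray, src: np.ndarray, offset: tuple[int, int]) -> np.ndarray:
    """Write src into dst at the given [row, col] offset and return dst."""
    r, c = offset
    dst[r : r + src.shape[0], c : c + src.shape[1]] = src
    return dst
```

Under that reading, the suggested fix fills the pre-allocated buffer in place instead of rebinding the name to a freshly reshaped tensor, which is why it avoids the extra allocation.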
Actionable comments posted: 1
📒 Files selected for processing (2)

- models/deepseek/v4/indexer.py
- models/deepseek/v4/indexer_compressor.py
```python
score_logits = pl.create_tensor([B * MAX_CACHE_BLOCKS * CACHE_TILE, IDX_N_HEADS], dtype=pl.FP32)
score_kv_scale = pl.create_tensor([B * MAX_CACHE_BLOCKS * CACHE_TILE, 1], dtype=pl.FP32)
weighted_score_tiles = pl.create_tensor([B * MAX_CACHE_BLOCKS * S, CACHE_TILE], dtype=pl.FP32)
score_flat = pl.reshape(score, [T, SCORE_LEN])

for b in pl.parallel(B):
    with pl.at(level=pl.Level.CORE_GROUP, name_hint="score"):
        q0 = b * S * IDX_N_HEADS
        kv0 = b * IDX_KV_LEN
        for cb in pl.range(cache_blocks):
            qr_hadamard_tile = qr_hadamard[q0 : q0 + S * IDX_N_HEADS, :]
            kv_cache_tile = kv_cache_flat[kv0 + cb * CACHE_TILE : kv0 + (cb + 1) * CACHE_TILE, :]
            score_logits_tile = pl.matmul(kv_cache_tile, qr_hadamard_tile, out_dtype=pl.FP32, b_trans=True)
            score_logits = pl.assemble(
                score_logits,
                score_logits_tile,
                [(b * MAX_CACHE_BLOCKS + cb) * CACHE_TILE, 0],
            )
    t0 = b * S
    q0 = b * S * IDX_N_HEADS
    kv0 = b * IDX_KV_LEN

    with pl.at(level=pl.Level.CORE_GROUP, name_hint="score_weighted_reduce"):
        t0 = b * S
    with pl.at(level=pl.Level.CORE_GROUP, name_hint="score_init"):
        neg_inf_score = pl.full([S, SCORE_LEN], dtype=pl.FP32, value=FP32_NEG_INF)
        score_flat = pl.assemble(score_flat, neg_inf_score, [t0, 0])
    for cb in pl.range(cache_blocks):
        cache0 = cb * CACHE_TILE
        valid_len = pl.min(CACHE_TILE, cache_len - cache0)

    for cb in pl.range(cache_blocks):
        cache0 = cb * CACHE_TILE
        valid_len = pl.min(CACHE_TILE, cache_len - cache0)
        score_row0 = (b * MAX_CACHE_BLOCKS + cb) * CACHE_TILE

        with pl.at(level=pl.Level.CORE_GROUP, name_hint="score_quant"):
            kv_cache_tile_i8 = pl.create_tensor([CACHE_TILE, IDX_HEAD_DIM], dtype=pl.INT8)
            kv_amax = pl.full([1, CACHE_TILE], dtype=pl.FP32, value=INT8_AMAX_EPS)
            for h0 in pl.range(0, IDX_HEAD_DIM, HEAD_DIM_CHUCK):
                kv_a_tile = kv_cache_flat[kv0 + cache0 : kv0 + cache0 + CACHE_TILE, h0 : h0 + HEAD_DIM_CHUCK]
                kv_a_f32 = pl.cast(kv_a_tile, target_type=pl.FP32)
                kv_a_abs = pl.maximum(kv_a_f32, pl.neg(kv_a_f32))
                kv_a_max = pl.reshape(pl.row_max(kv_a_abs), [1, CACHE_TILE])
                kv_amax = pl.maximum(kv_amax, kv_a_max)
            kv_scale_quant_row = pl.div(pl.full([1, CACHE_TILE], dtype=pl.FP32, value=INT8_SCALE_MAX), kv_amax)
            kv_cache_scale_dq = pl.reshape(pl.recip(kv_scale_quant_row), [CACHE_TILE, 1])
            kv_scale_quant = pl.reshape(kv_scale_quant_row, [CACHE_TILE, 1])
            for h1 in pl.range(0, IDX_HEAD_DIM, HEAD_DIM_CHUCK):
                kv_q_tile = kv_cache_flat[kv0 + cache0 : kv0 + cache0 + CACHE_TILE, h1 : h1 + HEAD_DIM_CHUCK]
                kv_q_f32 = pl.cast(kv_q_tile, target_type=pl.FP32)
                kv_q_scaled = pl.row_expand_mul(kv_q_f32, kv_scale_quant)
                kv_q_i32 = pl.cast(kv_q_scaled, target_type=pl.INT32, mode="round")
                kv_q_half = pl.cast(kv_q_i32, target_type=pl.FP16, mode="round")
                kv_q_i8 = pl.cast(kv_q_half, target_type=pl.INT8, mode="trunc")
                kv_cache_tile_i8 = pl.assemble(kv_cache_tile_i8, kv_q_i8, [0, h1])
            score_kv_scale = pl.assemble(score_kv_scale, kv_cache_scale_dq, [score_row0, 0])

        with pl.at(level=pl.Level.CORE_GROUP, name_hint="score_accum"):
            qr_hadamard_tile = qr_hadamard_i8[q0 : q0 + S * IDX_N_HEADS, :]
            score_acc = pl.matmul(kv_cache_tile_i8, qr_hadamard_tile, out_dtype=pl.INT32, b_trans=True)

        with pl.at(level=pl.Level.CORE_GROUP, name_hint="score_store"):
            kv_cache_scale_dq = score_kv_scale[score_row0 : score_row0 + CACHE_TILE, :]
            qh_scale = pl.reshape(qr_hadamard_scale_dq[q0 : q0 + S * IDX_N_HEADS, :], [1, S * IDX_N_HEADS])
            score_tile = pl.cast(score_acc, target_type=pl.FP32, mode="none")
            score_tile = pl.col_expand_mul(pl.row_expand_mul(score_tile, kv_cache_scale_dq), qh_scale)
            score_logits = pl.assemble(score_logits, score_tile, [score_row0, 0])
```
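Numerically, the score_quant / score_accum / score_store sequence above reduces to per-row absmax quantization of the KV tile, an INT8×INT8→INT32 matmul, and a rank-1 dequant by the two scale vectors. A NumPy sketch with made-up tile shapes (the kernel's FP16 round-trip in the cast chain is omitted):

```python
import numpy as np

CACHE_TILE, IDX_HEAD_DIM, N_HEADS = 4, 8, 2  # toy shapes for illustration
rng = np.random.default_rng(0)
kv_tile = rng.standard_normal((CACHE_TILE, IDX_HEAD_DIM)).astype(np.float32)
qh_tile = rng.standard_normal((N_HEADS, IDX_HEAD_DIM)).astype(np.float32)

# score_quant: per-row symmetric absmax quantization of the KV tile.
kv_amax = np.maximum(np.abs(kv_tile).max(axis=1, keepdims=True), 1e-4)
kv_i8 = np.rint(kv_tile * (127.0 / kv_amax)).astype(np.int8)

# Assume the Hadamard-transformed queries were quantized the same way upstream.
qh_amax = np.maximum(np.abs(qh_tile).max(axis=1, keepdims=True), 1e-4)
qh_i8 = np.rint(qh_tile * (127.0 / qh_amax)).astype(np.int8)

# score_accum: INT8 x INT8 with INT32 accumulation (b_trans=True in the kernel).
acc = kv_i8.astype(np.int32) @ qh_i8.astype(np.int32).T

# score_store: dequantize with the per-row KV scale and the per-column query scale.
score = acc.astype(np.float32) * (kv_amax / 127.0) * (qh_amax / 127.0).T

# The dequantized result tracks the unquantized matmul up to quantization noise.
assert np.allclose(score, kv_tile @ qh_tile.T, atol=0.5)
```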
Encode the single-token decode assumption in the INT8 score path.
score_acc and qh_scale are computed for S * IDX_N_HEADS, but score_logits is still allocated with only IDX_N_HEADS columns. This only works when DECODE_SEQ == 1; any multi-token decode config will mis-shape at score_store / score_weighted_reduce. Please either guard this kernel explicitly or carry S * IDX_N_HEADS through the score buffers.
Minimal guard if multi-token decode is intentionally unsupported:

```diff
 @pl.jit.inline
 def indexer(
     x: pl.Tensor[[B, S, D], pl.BF16],
 @@
     inner_rotate: pl.Scalar[pl.BOOL],
 ):
+    assert S == 1, "DeepSeek-V4 decode indexer currently assumes DECODE_SEQ == 1"
     # TODO: kernel implementation
     cache_len = (start_pos + S) // COMPRESS_RATIO
```

🤖 Prompt for AI Agents
```text
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In @models/deepseek/v4/indexer.py around lines 262 - 312: The INT8 score path
assumes DECODE_SEQ==1: score_acc and qh_scale are computed for S * IDX_N_HEADS
but score_logits is allocated with only IDX_N_HEADS, causing a shape mismatch at
score_store/score_weighted_reduce; either add an explicit guard/assert (e.g.
check DECODE_SEQ == 1 at the top of the kernel and raise/log if not) or
propagate S * IDX_N_HEADS through the score buffers by changing the allocation
of score_logits (and any downstream buffers/reshapes that consume it) to use
columns = S * IDX_N_HEADS so the pl.assemble into score_logits in score_store
matches the computed score_tile; reference symbols: score_logits, score_acc,
qh_scale, score_store, score_weighted_reduce, and DECODE_SEQ.
```