
Update DeepSeek V4 indexer W8A8C16 quantization path #268

Merged

zhangqi-chen merged 1 commit into hw-native-sys:main from wuzhf9:dsv4_index on May 13, 2026
Conversation

@wuzhf9
Contributor

@wuzhf9 wuzhf9 commented May 13, 2026

Summary

  • Quantize qr, wq_b, qr_hadamard, and cache tiles in the runtime indexer path so the projection and score matmuls run through the W8A8C16 flow with explicit dequant scales (a minimal sketch of this flow follows this list).
  • Update the indexer test signature and tensor specs to accept INT8 wq_b plus per-output-channel wq_b_scale.
  • Align golden_indexer with the kernel by adding per-row and per-channel INT8 quant helpers and simulating the dequantized W8A8C16 score path.
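
A minimal PyTorch sketch of the quantize → INT8 matmul → dequantize flow summarized above; the helper names, tensor shapes, and constant values are assumptions for illustration only, since the actual kernel is written in the pl DSL:

```python
import torch

INT8_SCALE_MAX = 127.0  # assumed value of the kernel's constant
INT8_AMAX_EPS = 1e-4    # assumed value; floors amax so the quant scale stays finite

def quant_rowwise(x):
    """Symmetric per-row INT8 quantization; returns the INT8 tensor and a per-row dequant scale."""
    amax = x.abs().amax(dim=-1, keepdim=True).clamp_min(INT8_AMAX_EPS)
    x_i8 = torch.clamp(torch.round(x * (INT8_SCALE_MAX / amax)), -127, 127).to(torch.int8)
    return x_i8, amax / INT8_SCALE_MAX  # dequant scale, shape [rows, 1]

def quant_per_out_channel(w):
    """Symmetric per-output-channel INT8 quantization for a [out, in] weight."""
    amax = w.abs().amax(dim=1, keepdim=True).clamp_min(INT8_AMAX_EPS)
    w_i8 = torch.clamp(torch.round(w * (INT8_SCALE_MAX / amax)), -127, 127).to(torch.int8)
    return w_i8, (amax / INT8_SCALE_MAX).reshape(1, -1)  # dequant scale, shape [1, out]

# W8A8C16: INT8 x INT8 matmul with integer accumulation, FP32 dequant, BF16 output.
qr = torch.randn(4, 1536)       # [T, Q_LORA], illustrative shape
wq_b = torch.randn(4096, 1536)  # [IDX_N_HEADS * IDX_HEAD_DIM, Q_LORA], illustrative shape
qr_i8, qr_scale_dq = quant_rowwise(qr)
wq_b_i8, wq_b_scale = quant_per_out_channel(wq_b)
acc = qr_i8.to(torch.int64) @ wq_b_i8.to(torch.int64).T  # int64 stands in for the kernel's INT32 accumulator
qr_proj = (acc.to(torch.float32) * qr_scale_dq * wq_b_scale).to(torch.bfloat16)
```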

Related Issues

N/A

@coderabbitai

coderabbitai Bot commented May 13, 2026


📝 Walkthrough

This PR extends the DeepSeek-V4 decode indexer to use INT8 quantization for weights and activations in the score path. The indexer kernel signature is updated to accept INT8 quantized weights and a dequant scale tensor; QR and KV cache tiles are quantized per-row/per-tile before INT8 matmuls, and the golden reference implements matching quantization helpers. Test comparison functions are updated to use ULP-tolerant BF16 comparison for quantized outputs.
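
Concretely, the score-path dequantization described above amounts to scaling the integer accumulator by the per-tile KV scale along rows and by the q-Hadamard scale along columns. A rough PyTorch sketch with illustrative names and shapes (the kernel itself uses the pl DSL, and its exact rounding sequence is not reproduced here):

```python
import torch

def quant_rows(x, qmax=127.0, eps=1e-4):
    # Symmetric per-row INT8 quantization; qmax/eps stand in for INT8_SCALE_MAX / INT8_AMAX_EPS.
    amax = x.abs().amax(dim=-1, keepdim=True).clamp_min(eps)
    x_i8 = torch.clamp(torch.round(x * (qmax / amax)), -127, 127).to(torch.int8)
    return x_i8, amax / qmax  # per-row dequant scale

kv_tile = torch.randn(128, 128)    # [CACHE_TILE, IDX_HEAD_DIM], illustrative
q_hadamard = torch.randn(64, 128)  # [S * IDX_N_HEADS, IDX_HEAD_DIM], illustrative
kv_i8, kv_scale = quant_rows(kv_tile)     # kv_scale: [CACHE_TILE, 1]
qh_i8, qh_scale = quant_rows(q_hadamard)  # qh_scale: [S * IDX_N_HEADS, 1]

acc = kv_i8.to(torch.int64) @ qh_i8.to(torch.int64).T       # integer accumulation (INT32 in the kernel)
score_tile = acc.to(torch.float32) * kv_scale * qh_scale.T  # dequantize with both scales
```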

Changes

DeepSeek-V4 Indexer INT8 Quantization

| Layer / File(s) | Summary |
|---|---|
| **Quantization Constants and Comparison Function Setup**<br>models/deepseek/v4/indexer.py, models/deepseek/v4/indexer_compressor.py | INT8 quantization constants (INT8_SCALE_MAX, INT8_AMAX_EPS) and the bf16_allclose_or_ulp comparison function are imported to support INT8 data validation in both the indexer and compressor test harnesses. |
| **Indexer Kernel Signature Updates**<br>models/deepseek/v4/indexer.py | The indexer and indexer_test kernel signatures change to accept INT8-quantized wq_b weights and an FP32 per-output-channel wq_b_scale dequant tensor; indexer_test is wired to pass the scale through to the kernel. |
| **QR and Weight Quantization in Indexer**<br>models/deepseek/v4/indexer.py | Within the indexer kernel, qr is reshaped and quantized to INT8 per row, then INT8-matmulled against INT8 wq_b, with dequantization using both the qr scales and wq_b_scale to produce BF16 qr_proj. Hadamard outputs are quantized to INT8 using per-head amax-based scales. |
| **Score Path INT8 Quantization**<br>models/deepseek/v4/indexer.py | The score computation pipeline is rewritten to quantize KV cache tiles to INT8 with per-tile scales, perform INT8×INT8 matmul against the quantized Hadamard q, and dequantize using the KV and q-Hadamard scales before assembling score_logits. |
| **Golden Reference INT8 Implementation**<br>models/deepseek/v4/indexer.py | The PyTorch golden_indexer adds INT8 quantization helper functions and implements matching quantization logic: q is computed via quantized qr and INT8 wq_b with dequantization, and score is computed via INT8 matmul on per-row quantized q and KV cache, followed by dequantization. |
| **Tensor Specs and Runtime Test Configuration**<br>models/deepseek/v4/indexer.py, models/deepseek/v4/indexer_compressor.py | build_tensor_specs quantizes wq_b to INT8 per output channel and adds wq_b_scale; both the indexer and compressor test harnesses switch KV cache comparison to bf16_allclose_or_ulp() for ULP-tolerant validation (sketched below). |
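
The bf16_allclose_or_ulp comparator referenced above is a project-specific test helper whose implementation is not shown in this PR. As a rough, assumed sketch only, a ULP-tolerant BF16 comparison can be built by reinterpreting the BF16 bit patterns as ordered integers; the tolerances below are guesses, not the project's values:

```python
import torch

def bf16_ulp_distance(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Assumed sketch: ULP distance between BF16 values via their 16-bit patterns.
    a_bits = a.to(torch.bfloat16).view(torch.int16).to(torch.int32)
    b_bits = b.to(torch.bfloat16).view(torch.int16).to(torch.int32)
    # Map the sign-magnitude bit patterns onto a monotonically ordered integer line.
    a_ord = torch.where(a_bits < 0, -(a_bits & 0x7FFF), a_bits)
    b_ord = torch.where(b_bits < 0, -(b_bits & 0x7FFF), b_bits)
    return (a_ord - b_ord).abs()

def allclose_or_ulp(a, b, rtol=1e-2, atol=1e-3, max_ulp=2):
    # Pass when values are close in float terms OR within a few BF16 ULPs.
    close = torch.isclose(a.float(), b.float(), rtol=rtol, atol=atol)
    return bool((close | (bf16_ulp_distance(a, b) <= max_ulp)).all())
```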

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

  • hw-native-sys/pypto-lib#260: Both PRs modify the DeepSeek-V4 decode indexer implementation and scoring logic, with this PR extending that work to INT8 quantized matmuls and dequantization throughout the score path.
  • hw-native-sys/pypto-lib#264: This PR directly precedes the main PR's adoption of the same bf16_allclose_or_ulp comparator for KV cache validation in test harnesses.

Poem

🐰 With bits compressed to eight, the indexer flies—
Per-row quantized queries meet weights in disguise.
Hadamard outputs glow in INT8's tight embrace,
While golden references keep perfect pace.
Scales and dequants dance as scores align,
A faster path to answers, number by number fine.

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 37.50%, which is below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title clearly and specifically describes the main change: updating the DeepSeek V4 indexer to use W8A8C16 quantization, which is the primary focus of the changeset. |
| Description check | ✅ Passed | The description directly relates to the changeset, outlining the three main objectives: quantizing the runtime path, updating test signatures, and aligning the golden reference implementation with the quantization semantics. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |





@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request implements W8A8C16 quantization for the indexer module, transitioning weight and activation tensors to INT8 with INT32 accumulation. Key changes include per-row and per-channel quantization logic for query projections, Hadamard-transformed queries, and KV cache tiles, followed by the appropriate dequantization. Review feedback identifies a redundant allocation of the qr_scale_dq tensor and recommends either populating the pre-allocated tensor with pl.assemble or removing the allocation, rather than reassigning the name, to avoid unnecessary memory overhead.

qr_proj = pl.create_tensor([T, IDX_N_HEADS * IDX_HEAD_DIM], dtype=pl.BF16)
qr_flat = pl.reshape(qr, [T, Q_LORA])
qr_i8 = pl.create_tensor([T, Q_LORA], dtype=pl.INT8)
qr_scale_dq = pl.create_tensor([T, 1], dtype=pl.FP32)

Severity: medium

The tensor qr_scale_dq is pre-allocated here using pl.create_tensor, but it is later reassigned at line 103 using pl.reshape, which makes this allocation redundant. If the intention is to use the pre-allocated tensor, fill it with pl.assemble (or direct slice assignment, if the compiler supports that for full replacement); otherwise, remove this line and let line 103 handle the allocation.

qr_a_max = pl.reshape(pl.row_max(qr_a_abs), [1, T])
qr_amax = pl.maximum(qr_amax, qr_a_max)
qr_scale_quant_row = pl.div(pl.full([1, T], dtype=pl.FP32, value=INT8_SCALE_MAX), qr_amax)
qr_scale_dq = pl.reshape(pl.recip(qr_scale_quant_row), [T, 1])

Severity: medium

Reassigning qr_scale_dq here with the result of pl.reshape makes the initial allocation at line 93 redundant. To maintain consistency with how other tensors like qr_hadamard_scale_dq are handled in this file, consider using pl.assemble to fill the pre-allocated tensor, or simply remove the redundant allocation at line 93.

Suggested change:
- qr_scale_dq = pl.reshape(pl.recip(qr_scale_quant_row), [T, 1])
+ qr_scale_dq = pl.assemble(qr_scale_dq, pl.reshape(pl.recip(qr_scale_quant_row), [T, 1]), [0, 0])


@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1


ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 37d1289e-5ff9-4422-b0c1-b0b57e91c3dd

📥 Commits

Reviewing files that changed from the base of the PR and between 7e510c3 and 1c58355.

📒 Files selected for processing (2)
  • models/deepseek/v4/indexer.py
  • models/deepseek/v4/indexer_compressor.py

Comment on lines 262 to +312
score_logits = pl.create_tensor([B * MAX_CACHE_BLOCKS * CACHE_TILE, IDX_N_HEADS], dtype=pl.FP32)
score_kv_scale = pl.create_tensor([B * MAX_CACHE_BLOCKS * CACHE_TILE, 1], dtype=pl.FP32)
weighted_score_tiles = pl.create_tensor([B * MAX_CACHE_BLOCKS * S, CACHE_TILE], dtype=pl.FP32)
score_flat = pl.reshape(score, [T, SCORE_LEN])

for b in pl.parallel(B):
    with pl.at(level=pl.Level.CORE_GROUP, name_hint="score"):
        q0 = b * S * IDX_N_HEADS
        kv0 = b * IDX_KV_LEN
        for cb in pl.range(cache_blocks):
            qr_hadamard_tile = qr_hadamard[q0 : q0 + S * IDX_N_HEADS, :]
            kv_cache_tile = kv_cache_flat[kv0 + cb * CACHE_TILE : kv0 + (cb + 1) * CACHE_TILE, :]
            score_logits_tile = pl.matmul(kv_cache_tile, qr_hadamard_tile, out_dtype=pl.FP32, b_trans=True)
            score_logits = pl.assemble(
                score_logits,
                score_logits_tile,
                [(b * MAX_CACHE_BLOCKS + cb) * CACHE_TILE, 0],
            )
    t0 = b * S
    q0 = b * S * IDX_N_HEADS
    kv0 = b * IDX_KV_LEN

    with pl.at(level=pl.Level.CORE_GROUP, name_hint="score_weighted_reduce"):
        t0 = b * S
    with pl.at(level=pl.Level.CORE_GROUP, name_hint="score_init"):
        neg_inf_score = pl.full([S, SCORE_LEN], dtype=pl.FP32, value=FP32_NEG_INF)
        score_flat = pl.assemble(score_flat, neg_inf_score, [t0, 0])
    for cb in pl.range(cache_blocks):
        cache0 = cb * CACHE_TILE
        valid_len = pl.min(CACHE_TILE, cache_len - cache0)

    for cb in pl.range(cache_blocks):
        cache0 = cb * CACHE_TILE
        valid_len = pl.min(CACHE_TILE, cache_len - cache0)
        score_row0 = (b * MAX_CACHE_BLOCKS + cb) * CACHE_TILE

        with pl.at(level=pl.Level.CORE_GROUP, name_hint="score_quant"):
            kv_cache_tile_i8 = pl.create_tensor([CACHE_TILE, IDX_HEAD_DIM], dtype=pl.INT8)
            kv_amax = pl.full([1, CACHE_TILE], dtype=pl.FP32, value=INT8_AMAX_EPS)
            for h0 in pl.range(0, IDX_HEAD_DIM, HEAD_DIM_CHUCK):
                kv_a_tile = kv_cache_flat[kv0 + cache0 : kv0 + cache0 + CACHE_TILE, h0 : h0 + HEAD_DIM_CHUCK]
                kv_a_f32 = pl.cast(kv_a_tile, target_type=pl.FP32)
                kv_a_abs = pl.maximum(kv_a_f32, pl.neg(kv_a_f32))
                kv_a_max = pl.reshape(pl.row_max(kv_a_abs), [1, CACHE_TILE])
                kv_amax = pl.maximum(kv_amax, kv_a_max)
            kv_scale_quant_row = pl.div(pl.full([1, CACHE_TILE], dtype=pl.FP32, value=INT8_SCALE_MAX), kv_amax)
            kv_cache_scale_dq = pl.reshape(pl.recip(kv_scale_quant_row), [CACHE_TILE, 1])
            kv_scale_quant = pl.reshape(kv_scale_quant_row, [CACHE_TILE, 1])
            for h1 in pl.range(0, IDX_HEAD_DIM, HEAD_DIM_CHUCK):
                kv_q_tile = kv_cache_flat[kv0 + cache0 : kv0 + cache0 + CACHE_TILE, h1 : h1 + HEAD_DIM_CHUCK]
                kv_q_f32 = pl.cast(kv_q_tile, target_type=pl.FP32)
                kv_q_scaled = pl.row_expand_mul(kv_q_f32, kv_scale_quant)
                kv_q_i32 = pl.cast(kv_q_scaled, target_type=pl.INT32, mode="round")
                kv_q_half = pl.cast(kv_q_i32, target_type=pl.FP16, mode="round")
                kv_q_i8 = pl.cast(kv_q_half, target_type=pl.INT8, mode="trunc")
                kv_cache_tile_i8 = pl.assemble(kv_cache_tile_i8, kv_q_i8, [0, h1])
            score_kv_scale = pl.assemble(score_kv_scale, kv_cache_scale_dq, [score_row0, 0])

        with pl.at(level=pl.Level.CORE_GROUP, name_hint="score_accum"):
            qr_hadamard_tile = qr_hadamard_i8[q0 : q0 + S * IDX_N_HEADS, :]
            score_acc = pl.matmul(kv_cache_tile_i8, qr_hadamard_tile, out_dtype=pl.INT32, b_trans=True)

        with pl.at(level=pl.Level.CORE_GROUP, name_hint="score_store"):
            kv_cache_scale_dq = score_kv_scale[score_row0 : score_row0 + CACHE_TILE, :]
            qh_scale = pl.reshape(qr_hadamard_scale_dq[q0 : q0 + S * IDX_N_HEADS, :], [1, S * IDX_N_HEADS])
            score_tile = pl.cast(score_acc, target_type=pl.FP32, mode="none")
            score_tile = pl.col_expand_mul(pl.row_expand_mul(score_tile, kv_cache_scale_dq), qh_scale)
            score_logits = pl.assemble(score_logits, score_tile, [score_row0, 0])


⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Encode the single-token decode assumption in the INT8 score path.

score_acc and qh_scale are computed for S * IDX_N_HEADS, but score_logits is still allocated with only IDX_N_HEADS columns. This only works when DECODE_SEQ == 1; any multi-token decode config will mis-shape at score_store / score_weighted_reduce. Please either guard this kernel explicitly or carry S * IDX_N_HEADS through the score buffers.

Minimal guard if multi-token decode is intentionally unsupported
 @pl.jit.inline
 def indexer(
     x: pl.Tensor[[B, S, D], pl.BF16],
@@
     inner_rotate: pl.Scalar[pl.BOOL],
 ):
+    assert S == 1, "DeepSeek-V4 decode indexer currently assumes DECODE_SEQ == 1"
     # TODO: kernel implementation
     cache_len = (start_pos + S) // COMPRESS_RATIO
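
The other option the comment describes, carrying S * IDX_N_HEADS through the score buffers instead of guarding, would start with a change along these lines (hypothetical; only the score_logits allocation is shown, and the downstream buffers and reshapes that consume it would need the same treatment):

```python
# Hypothetical alternative: size score_logits for multi-token decode as well.
score_logits = pl.create_tensor([B * MAX_CACHE_BLOCKS * CACHE_TILE, S * IDX_N_HEADS], dtype=pl.FP32)
```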

@zhangqi-chen zhangqi-chen merged commit 64668f8 into hw-native-sys:main May 13, 2026
6 checks passed
