Align DeepSeek V4 INT8/BF16 cast to round-to-nearest-even #279
Conversation
Switches all kernel narrow casts in v4 to round-to-nearest-even (RTNE) to match the upstream reference cast strategy and the golden side's torch defaults:

* FP32 -> INT32 in INT8 quant: mode="round" (half-away-from-zero) -> mode="rint" (RTNE) in qkv_proj_rope, sparse_attn, indexer (x3), moe_expert (x4).
* FP32 -> BF16: explicit mode="rint" added at every call site (44 spots across hc_pre, hc_post, qkv_proj_rope, sparse_attn, indexer, indexer_compressor, compressor_ratio4/128, moe_expert, attention_swa, attention_hca). Previously these inherited pl.cast's default "round", which is half-away-from-zero; the new explicit mode matches torch's default RTNE used in the goldens via .to(torch.bfloat16).

Golden helpers replaced:

* _round_half_away_from_zero / round_half_away_from_zero now use torch.round (RTNE by default) directly. Removed the local helpers in qkv_proj_rope, sparse_attn, indexer, moe_expert, attention_swa, attention_hca.

The INT32 -> FP16 step in the INT8 quant chain (mode="round") and the FP16 -> INT8 trunc step are unchanged: they already match the upstream reference's CAST_ROUND + CAST_TRUNC semantics. The reference also sets satmode=ON on the final INT8 cast; pl.cast does not expose a satmode parameter today, but INT8_SCALE_MAX = 127.0 mathematically rules out overflow, so the gap is benign.
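For context, a minimal torch-only sketch contrasting the two tie-breaking rules; the tensor values are illustrative:

```python
import torch

x = torch.tensor([0.5, 1.5, 2.5, -0.5, -1.5], dtype=torch.float32)

# Half-away-from-zero ("round"): ties move away from zero.
half_away = torch.sign(x) * torch.floor(torch.abs(x) + 0.5)
print(half_away)        # tensor([ 1.,  2.,  3., -1., -2.])

# Round-to-nearest-even ("rint"): ties go to the even neighbor.
# torch.round uses this rule, and .to(torch.bfloat16) rounds the same way.
print(torch.round(x))   # tensor([ 0.,  2.,  2., -0., -2.])
```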
Code Review
This pull request standardizes rounding behavior across multiple model components, including attention mechanisms (HCA, SWA, Sparse), MoE experts, and indexers. The changes replace custom 'round half away from zero' implementations with standard torch.round and explicitly set the rounding mode to 'rint' (round to nearest, ties to even) for pl.cast operations involving BF16 and INT32 conversions. These updates ensure consistent numerical behavior during quantization and type casting across the codebase. No review comments were provided for this pull request, and I have no additional feedback.
Summary
Aligns all kernel narrow casts in v4 to round-to-nearest-even (RTNE), matching the upstream reference cast strategy and the golden side's torch defaults.
* FP32 → INT32 in the INT8 quant chain: `mode="round"` (half-away-from-zero) → `mode="rint"` (RTNE). Affected: `qkv_proj_rope` (1), `sparse_attn` (1), `indexer` (3), `moe_expert` (4).
* FP32 → BF16: explicit `mode="rint"` at every call site (44 spots across `hc_pre`, `hc_post`, `qkv_proj_rope`, `sparse_attn`, `indexer`, `indexer_compressor`, `compressor_ratio4/128`, `moe_expert`, `attention_swa`, `attention_hca`). Previously these inherited `pl.cast`'s default `mode="round"` (half-away-from-zero); the new explicit `mode="rint"` matches torch's default RTNE used in the goldens via `.to(torch.bfloat16)`.
* Golden helpers: replaced `_round_half_away_from_zero` / `round_half_away_from_zero` with `torch.round` (RTNE by default) and removed the local helpers in `qkv_proj_rope`, `sparse_attn`, `indexer`, `moe_expert`, `attention_swa`, `attention_hca` (a before/after sketch follows this list).
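A minimal before/after sketch of the golden-helper change. The removed helper's exact body is not shown in this PR, so the "before" implementation and the `golden_round` wrapper name below are assumptions:

```python
import torch

# Before (assumed implementation of the removed local helper):
def _round_half_away_from_zero(x: torch.Tensor) -> torch.Tensor:
    return torch.sign(x) * torch.floor(torch.abs(x) + 0.5)

# After: the goldens call torch.round directly, which rounds ties to even (RTNE)
# and therefore matches the kernel-side mode="rint" casts.
def golden_round(x: torch.Tensor) -> torch.Tensor:  # hypothetical wrapper name
    return torch.round(x)
```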
The INT32 → FP16 step (still `mode="round"`) and the FP16 → INT8 trunc step are unchanged: they already match the reference's `CAST_ROUND` + `CAST_TRUNC` semantics.
Validation (device 8, a2a3)

All single-stage kernels still pass under their existing tolerances. Notable improvement:
* `compressor_ratio128` previously failed strict allclose on `kv` (32/8192 ≈ 0.39% bad points) and `kv_cache` (1 BF16 point at 5-ULP); both now pass with the RTNE alignment.
* `compressor_ratio4` remains borderline (`kv` 11/8192 ≈ 0.13% bad under strict allclose, well within the 0.5% outlier escape used by attention-class outputs; the check is sketched after this list).
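For reference, a sketch of the kind of tolerance check described above. The function name, default rtol/atol values, and return convention are assumptions; only the 0.5% outlier escape comes from the PR text:

```python
import torch

def allclose_with_outlier_escape(actual: torch.Tensor, golden: torch.Tensor,
                                 rtol: float = 1e-2, atol: float = 1e-3,
                                 max_outlier_frac: float = 0.005) -> bool:
    """Pass if strict allclose holds, or if the fraction of out-of-tolerance
    points stays under the outlier escape (0.5% for attention-class outputs).
    Names and default tolerances here are illustrative assumptions."""
    bad = ~torch.isclose(actual.float(), golden.float(), rtol=rtol, atol=atol)
    bad_frac = bad.float().mean().item()
    return bad_frac <= max_outlier_frac
```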
End-to-end `attention_swa` and `attention_hca` are out of scope for this PR; their fused-kernel tolerance is a separate design question (cumulative single-stage error, no upstream end-to-end counterpart).
Known caveat: `pl.cast` satmode
The reference also sets `satmode=ON` on the final INT8 cast; `pl.cast` does not expose a `satmode` parameter in the current pypto framework. `INT8_SCALE_MAX = 127.0` mathematically rules out overflow, so the gap is benign for v4 (see the sketch below). Tracking framework-side as a follow-up.
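A torch-only sketch of why the missing saturation is harmless under a 127.0 scale. The per-tensor scaling and the omitted FP16 intermediate are simplifications, so read this as an assumption about the kernel chain rather than its exact implementation:

```python
import torch

INT8_SCALE_MAX = 127.0

def int8_quantize(x: torch.Tensor):
    """Illustrative INT8 quant chain (simplified): scale so max|x| maps to 127,
    round with RTNE, then cast. Because |x * scale| <= 127 by construction,
    the rounded values stay in [-127, 127] and saturation can never trigger."""
    scale = INT8_SCALE_MAX / x.abs().max().clamp_min(1e-12)
    q = torch.round(x * scale)      # RTNE, as in the mode="rint" kernel cast
    return q.to(torch.int8), scale  # values are already integral and in range
```

Under this scheme the worst case is exactly ±127, inside INT8's [-128, 127] range, which is the overflow argument made above.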
Unrelated pre-existing issue (not introduced by this PR)

`moe_expert.py` fails to compile on both this branch and `upstream/main` with `pl.col_expand_mul`: second argument requires ≥2 dimensions, got 1. This is a pypto framework regression in `col_expand_mul`'s API and is independent of the cast alignment work.

Related Issues