Sample the DSpark draft proposal from q, not argmax, for B2#1
Open
audreyt wants to merge 1 commit into
Open
Conversation
B2 rejection sampling (Chen et al. 2023 / Leviathan et al. 2023) accepts the draft token with probability min(1, p(x)/q(x)), which is only a valid lossless sampler when x is actually drawn from q, the drafter's temperature-scaled distribution. metal_graph_eval_dspark_draft_block always chose the draft token via sample_argmax regardless of DS4_SPEC_TEMP, so the accept/reject math computed p(x)/q(x) against a distribution x was never sampled from. - Add dspark_sample_draft_token: genuine temperature-scaled categorical sample over the full vocab (reuses the existing b2_log_softmax / b2_sample_from_log_probs helpers and sample_rng_f32, so no new RNG or translation unit). - Thread draft_temperature/uint64_t *draft_rng into metal_graph_eval_dspark_draft_block; use the new sampler instead of sample_argmax when B2 is active. Greedy path (temperature <= 0) is byte-for-byte unchanged (falls back to sample_argmax exactly as before). - Resolve DS4_SPEC_TEMP exactly once, at session creation, into a new s->dspark_b2_temp field, and use that cached value at both draft time and accept/reject time instead of querying getenv separately at each site. RNG seeding moves into the same session-creation block, since drafting now happens before the old lazy-seed point ever ran. - Drop `static` from dspark_sample_draft_token and b2_rejection_sample so a white-box statistical test can call them directly. Not part of the public ds4.h surface -- tests/ds4_test.c forward-declares the minimal shape it needs, same pattern as other internal-helper tests in this file. - Add a CPU-only synthetic test (test_dspark_b2_rejection_sampling_unbiased, --dspark-b2-unbiased) proving the sampler's single-token marginal output distribution matches a known target distribution within tolerance when the proposal is genuinely sampled (max_dev_correct=0.0018, N=50000), and diverges sharply when the proposal is naively argmax'd instead (max_dev_biased=0.1426) -- proof the test discriminates the fix from the bug it replaces, not just a passing assertion. Verified: make/make ds4_test build with zero warnings; make test group (dspark-b2-unbiased) passes with the numbers above; live --dspark-speculative-block (real model, greedy path) is unaffected, worst_argmax_gap=0.000; an A/B run with two DS4_SPEC_RNG_SEED values at DS4_SPEC_TEMP=0.8 shows different generated text and different accept/correction patterns even in cycles that were full-accept-no- correction in both seeds, proving the draft proposal itself is now seed-dependent (it was not, before this fix). Scope note: this fixes the B2 proposal-sampling precondition specifically. It does not address the separate batch-verify-vs-single-decode floating point divergence (metal_graph_verify_suffix_tops vs sequential decode) that makes the *greedy* DSpark path not byte-identical to non-speculative decode -- that is a different, pre-existing issue, independent of B2.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
B2 rejection sampling (Chen et al. 2023 / Leviathan et al. 2023) accepts the draft token with probability
min(1, p(x)/q(x)), which is only a valid lossless sampler whenxis actually drawn fromq, the drafter's temperature-scaled distribution.metal_graph_eval_dspark_draft_blockalways chose the draft token viasample_argmaxregardless ofDS4_SPEC_TEMP, so the accept/reject math computedp(x)/q(x)against a distributionxwas never sampled from. This is the issue @lobanov's report and the codex review both flagged (§7.2 in his write-up on antirez#482).What's in this PR
dspark_sample_draft_token: genuine temperature-scaled categorical sample over the full vocab, reusing the existingb2_log_softmax/b2_sample_from_log_probshelpers andsample_rng_f32— no new RNG, no new translation unit, minimal diff against your branch.draft_temperature/uint64_t *draft_rngintometal_graph_eval_dspark_draft_block; use the new sampler instead ofsample_argmaxwhen B2 is active. The greedy path (temperature <= 0) is byte-for-byte unchanged — falls back tosample_argmaxexactly as before.DS4_SPEC_TEMPexactly once, at session creation, into a news->dspark_b2_tempfield, used at both draft time and accept/reject time instead of separategetenvcalls at each site. RNG seeding moves into the same session-creation block, since drafting now happens before the old lazy-seed point ever ran.dspark_sample_draft_tokenandb2_rejection_sampledropstaticso a white-box statistical test can call them directly. Not part of the publicds4.hsurface —tests/ds4_test.cforward-declares the minimal shape it needs.test_dspark_b2_rejection_sampling_unbiased(--dspark-b2-unbiased, no model/GPU needed): draws a real proposal sample from a synthetic drafter distribution and feeds it throughb2_rejection_sample, then checks the resulting single-token marginal matches the target distribution within tolerance — this is the actual losslessness proof, not just "sampling changed."Test results (run this yourself:
make ds4_test && ./ds4_test --dspark-b2-unbiased)max_dev_correct=0.0018: when the proposal is genuinely sampled fromq(viadspark_sample_draft_token), the sampler's output distribution matches the true target distribution to well within statistical noise (binomial std error at N=50000 is ~0.002–0.003).max_dev_biased=0.1426: the same accept/reject code, fed an argmax'd proposal instead (the old bug), diverges by ~70x that noise floor — proving the test discriminates the exact bug this PR fixes, not just asserting something arbitrary.Further end-to-end verification
make/make ds4_test: zero warnings.--dspark-speculative-block(real model, greedy path): unaffected,worst_argmax_gap=0.000.DS4_SPEC_RNG_SEEDvalues atDS4_SPEC_TEMP=0.8, same prompt: generated text and accept/correction patterns diverge starting in cycles that were full-accept-no-correction in both seeds — a no-correction full accept commits the drafted tokens verbatim with zero residual randomness, so this divergence can only happen if the draft proposal itself is seed-dependent. It was not, before this fix (pure deterministic argmax).Scope note
This fixes the B2 proposal-sampling precondition specifically (§7.2). It does not address the separate batch-verify-vs-single-decode floating point divergence (
metal_graph_verify_suffix_topsvs sequential decode) that makes the greedy DSpark path not byte-identical to non-speculative decode — that's a different, pre-existing issue (§7.1 in lobanov's report), independent of B2, and still open.Speed: this fix is about correctness, not throughput. On my M5 Max 128GB with the Q4K Markov drafter, B2 is still slower than baseline on most workloads after this fix — partial/correction-cycle replay remains the dominant cost, as your PR's own "Key findings" section notes. Repetitive structured output remains the one clear net win.