From 13f6ca26d07596b81c4164a5a47e9b77a914247c Mon Sep 17 00:00:00 2001 From: yongjunlee Date: Thu, 14 May 2026 05:34:38 +0000 Subject: [PATCH 1/2] feat(spec-decode): add n-gram (prompt-lookup) speculative drafter Adds `SpeculativeAlgorithm.NGRAM`, a draft-model-free prompt-lookup drafter for chain speculative decoding. The drafter keeps a CPU-side token history per request-pool slot, finds the longest suffix-matching n-gram in that history, and proposes the continuation tokens as `[last_verified, d1, ..., dK]`. Reuses the existing target-side chain verify path without adding a draft model, draft KV cache, drafter attention backend, or new verify kernel. Adds startup validation for NGRAM-specific constraints: no draft model path, `topk == 1`, `num_draft_tokens == num_steps + 1`, no PD disaggregation, prefix caching and chunked prefill disabled, eager mode forced. The lookup core is split into a pure-numpy module so the algorithm and batched proposer can be unit-tested without native kernel builds. Tests: KMP suffix lookup, batched proposer row layout / padding / shape validation, NGRAM resolve and reject paths in CLI compat. Signed-off-by: yongjunlee --- docs/configuration/server.md | 24 +- .../runtime/execution/drafter/ngram.py | 273 ++++++++++++++++++ .../runtime/execution/drafter/ngram_lookup.py | 138 +++++++++ .../runtime/execution/model_executor.py | 85 ++++-- .../runtime/spec_decode/algorithm.py | 3 + .../tokenspeed/runtime/utils/server_args.py | 103 ++++++- test/runtime/test_cli_config_compat.py | 140 +++++++++ test/runtime/test_spec_decode_ngram.py | 187 ++++++++++++ 8 files changed, 924 insertions(+), 29 deletions(-) create mode 100644 python/tokenspeed/runtime/execution/drafter/ngram.py create mode 100644 python/tokenspeed/runtime/execution/drafter/ngram_lookup.py create mode 100644 test/runtime/test_spec_decode_ngram.py diff --git a/docs/configuration/server.md b/docs/configuration/server.md index 86b05e955..8b7bcf70a 100644 --- a/docs/configuration/server.md +++ b/docs/configuration/server.md @@ -116,17 +116,39 @@ Common parser values include `kimi_k2` and `gpt-oss`. | Parameter | Purpose | | --- | --- | | `--speculative-config` | JSON speculative decoding configuration. | -| `--speculative-algorithm` | Speculative algorithm, such as `EAGLE3` or `MTP`. | +| `--speculative-algorithm` | Speculative algorithm: `EAGLE3`, `MTP`, or `NGRAM`. | | `--speculative-draft-model-path` | Draft model path or repo ID. | | `--speculative-draft-model-quantization` | Draft model quantization. Defaults to `unquant`. | | `--speculative-num-steps` | Number of draft model steps. Defaults to `3`. | | `--speculative-num-draft-tokens` | Number of draft tokens. Defaults to `--speculative-num-steps + 1`. | | `--speculative-eagle-topk` | EAGLE top-k. Defaults to `1`. | +| `--speculative-ngram-min` | Minimum n-gram length (NGRAM only). Defaults to `1`. | +| `--speculative-ngram-max` | Maximum n-gram length (NGRAM only). Defaults to `3`. | | `--eagle3-layers-to-capture` | EAGLE3 layers to capture. | Prefer `--speculative-config` for recipe-style launches because it keeps method, draft model, and token count together. +### N-gram (prompt-lookup) Speculative Decoding + +`--speculative-algorithm NGRAM` runs a draft-model-free proposer that matches +the longest suffix-ngram in each request's running token history (capped by +`--speculative-ngram-max`) and speculates the tokens that follow the rightmost +match. It reuses the chain verify path, so no extra verify kernel is needed. + +The first release is intentionally narrow: + +- Single-rank only; running under PD disaggregation is rejected at startup. +- The drafter runs outside the captured CUDA graph, so `--enforce-eager` is + auto-enabled with a warning. +- `enable_prefix_caching`, `enable_kvstore`, and chunked prefill are + auto-disabled because the proposer keeps its own per-request token history + and prefix-cache hits would skip the prefill path it relies on. +- `--speculative-eagle-topk` must be `1` (chain only), and + `--speculative-num-draft-tokens` must equal `--speculative-num-steps + 1`. + +JSON form: `--speculative-config '{"method":"ngram","num_speculative_tokens":3}'`. + ## Observability | Parameter | Purpose | diff --git a/python/tokenspeed/runtime/execution/drafter/ngram.py b/python/tokenspeed/runtime/execution/drafter/ngram.py new file mode 100644 index 000000000..d030d5466 --- /dev/null +++ b/python/tokenspeed/runtime/execution/drafter/ngram.py @@ -0,0 +1,273 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""N-gram (prompt-lookup) speculative drafter. + +Draft tokens are proposed by matching the suffix of each request's running +token history against earlier positions in the same history (KMP-style +longest prefix-as-suffix search, capped by ``max_ngram``). Tokens that +follow the matched window become the speculative draft for the next round. + +This drafter is CPU-only: no draft model, no draft KV cache, no draft +attention backend. The chain greedy / chain stochastic verify kernels on +the target side are unchanged. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +import torch +from typing_extensions import override + +from tokenspeed.runtime.execution.drafter.base import BaseDrafter +from tokenspeed.runtime.execution.drafter.ngram_lookup import propose_batch_into +from tokenspeed.runtime.execution.forward_batch_info import ForwardMode +from tokenspeed.runtime.utils.nvtx import nvtx_range + +if TYPE_CHECKING: + from tokenspeed.runtime.execution.context import ForwardContext + from tokenspeed.runtime.execution.input_buffer import InputBuffers + from tokenspeed.runtime.execution.runtime_states import RuntimeStates + from tokenspeed.runtime.layers.logits_processor import LogitsProcessorOutput + + +class NgramDrafter(BaseDrafter): + """Prompt-lookup speculative drafter (no draft model). + + Maintains a CPU-side per-request token history keyed by + request-pool index, mirroring the slot semantics used elsewhere in + the executor. On each ``run()`` it appends the freshly accepted + tokens to the matching slot, runs a KMP-based suffix-ngram lookup + per request, and stages the proposed ``[last_verified, d1, ..., + d_K]`` row for the next round's verify input. + """ + + def __init__( + self, + spec_num_tokens: int, + spec_num_steps: int, + runtime_states: RuntimeStates, + input_buffers: InputBuffers, + max_context_len: int, + vocab_size: int | None = None, + min_ngram: int = 1, + max_ngram: int = 3, + ) -> None: + super().__init__( + spec_num_tokens=spec_num_tokens, + spec_num_steps=spec_num_steps, + draft_model_runner=None, + runtime_states=runtime_states, + input_buffers=input_buffers, + page_size=None, + req_to_page=None, + attn_backend=None, + token_to_kv_pool=None, + vocab_size=vocab_size, + ) + + if min_ngram < 1: + raise ValueError(f"min_ngram must be >= 1, got {min_ngram}") + if max_ngram < min_ngram: + raise ValueError( + f"max_ngram ({max_ngram}) must be >= min_ngram ({min_ngram})" + ) + + self.min_ngram = int(min_ngram) + self.max_ngram = int(max_ngram) + self.max_context_len = int(max_context_len) + self.device = runtime_states.device + + pool_capacity = runtime_states.valid_cache_lengths.shape[0] + self.history = np.zeros( + (pool_capacity, self.max_context_len), dtype=np.int32 + ) + self.history_len = np.zeros((pool_capacity,), dtype=np.int32) + + # Staging buffers for batched H2D of the next-round inputs. + self._next_input_np = np.zeros( + (input_buffers.max_bs, spec_num_tokens), dtype=np.int32 + ) + self._next_input_pinned = torch.empty( + (input_buffers.max_bs, spec_num_tokens), + dtype=torch.int32, + pin_memory=(self.device == "cuda"), + ) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _reset_slot(self, pool_idx: int) -> None: + self.history_len[pool_idx] = 0 + + def _append_to_slot(self, pool_idx: int, tokens: np.ndarray) -> None: + if tokens.size == 0: + return + cur = int(self.history_len[pool_idx]) + new_total = cur + tokens.size + cap = self.max_context_len + if new_total <= cap: + self.history[pool_idx, cur:new_total] = tokens + self.history_len[pool_idx] = new_total + else: + # Slide the window: keep the most recent ``cap`` tokens. + combined = np.concatenate( + [self.history[pool_idx, :cur], tokens.astype(np.int32, copy=False)] + ) + tail = combined[-cap:] + self.history[pool_idx, : tail.size] = tail + self.history_len[pool_idx] = tail.size + + # ------------------------------------------------------------------ + # BaseDrafter contract + # ------------------------------------------------------------------ + + @override + def get_candidates( + self, + base_ctx: ForwardContext, + ) -> torch.Tensor | None: + # Identical layout to EAGLE: verify reads the [bs, spec_num_tokens] + # window that was written into input_ids_buf by fill_input_buffers. + if not ( + base_ctx.forward_mode.is_decode() + or base_ctx.forward_mode.is_target_verify() + ): + return None + return self.input_buffers.input_ids_buf[: base_ctx.input_num_tokens].reshape( + base_ctx.bs, self.spec_num_tokens + ) + + @override + def draft(self, *_args, **_kwargs) -> torch.Tensor: + # Drafting in this proposer is part of ``run()`` (history update + + # ngram lookup are tightly coupled). Keep the abstract method + # satisfied without exposing a separately-callable surface. + raise NotImplementedError( + "NgramDrafter does not expose a standalone draft(); use run()." + ) + + @override + @nvtx_range("ngram_drafter", color="purple") + def run( + self, + base_ctx: ForwardContext, + logits_output: LogitsProcessorOutput, + output_tokens: torch.Tensor, + accept_lengths: torch.Tensor, + ) -> torch.Tensor: + del logits_output # unused; ngram drafter ignores hidden states. + + bs = base_ctx.bs + # The drafter intentionally runs outside the CUDA-graph capture + # path (executor forces enforce_eager when NGRAM is active), so + # these D2H syncs are acceptable. + pool_indices = ( + self.input_buffers.req_pool_indices_buf[:bs].to("cpu").numpy() + ) + + self._update_history(base_ctx, output_tokens, accept_lengths, pool_indices) + self._propose(bs, pool_indices) + + staging = self._next_input_pinned[:bs] + staging.copy_(torch.from_numpy(self._next_input_np[:bs])) + return staging.to(self.device, non_blocking=True) + + # ------------------------------------------------------------------ + # History bookkeeping + # ------------------------------------------------------------------ + + def _update_history( + self, + base_ctx: ForwardContext, + output_tokens: torch.Tensor, + accept_lengths: torch.Tensor, + pool_indices: np.ndarray, + ) -> None: + bs = base_ctx.bs + + if base_ctx.forward_mode == ForwardMode.EXTEND: + num_extends = base_ctx.num_extends + total = base_ctx.input_num_tokens + input_ids = self.input_buffers.input_ids_buf[:total].to("cpu").numpy() + input_lengths = ( + self.input_buffers.input_lengths_buf[:bs].to("cpu").numpy() + ) + # extend_prefix_lens is only populated for prefill rows (first + # ``num_extends`` entries) per the C++ scheduler's + # FlatForwardOperation. A zero entry marks the first chunk of + # a fresh prompt; reset the slot before appending. + if num_extends > 0: + extend_prefix_lens = ( + self.input_buffers.extend_prefix_lens_buf[:num_extends] + .to("cpu") + .numpy() + ) + else: + extend_prefix_lens = np.empty((0,), dtype=np.int32) + sampled = output_tokens.to("cpu").numpy().reshape(-1) + append_sampled = not self.input_buffers.all_extends_mid_chunk + + offset = 0 + for i in range(bs): + pool_idx = int(pool_indices[i]) + length = int(input_lengths[i]) + + is_prefill_row = i < num_extends + if is_prefill_row and int(extend_prefix_lens[i]) == 0: + self._reset_slot(pool_idx) + + self._append_to_slot(pool_idx, input_ids[offset : offset + length]) + if append_sampled and i < sampled.size: + self._append_to_slot(pool_idx, sampled[i : i + 1]) + offset += length + return + + # TARGET_VERIFY: output_tokens is laid out as (bs * spec_num_tokens,) + # and accept_lengths tells us how many of those columns were + # accepted per request (1..N). + verified = ( + output_tokens.to("cpu").numpy().reshape(bs, self.spec_num_tokens) + ) + accepted_n = accept_lengths.to("cpu").numpy().astype(np.int32) + for i in range(bs): + pool_idx = int(pool_indices[i]) + n = int(accepted_n[i]) + if n <= 0: + continue + self._append_to_slot(pool_idx, verified[i, :n]) + + # ------------------------------------------------------------------ + # Proposal + # ------------------------------------------------------------------ + + def _propose(self, bs: int, pool_indices: np.ndarray) -> None: + propose_batch_into( + history=self.history, + history_len=self.history_len, + pool_indices=pool_indices[:bs], + out=self._next_input_np[:bs], + min_ngram=self.min_ngram, + max_ngram=self.max_ngram, + spec_num_steps=self.spec_num_steps, + ) diff --git a/python/tokenspeed/runtime/execution/drafter/ngram_lookup.py b/python/tokenspeed/runtime/execution/drafter/ngram_lookup.py new file mode 100644 index 000000000..0fd802685 --- /dev/null +++ b/python/tokenspeed/runtime/execution/drafter/ngram_lookup.py @@ -0,0 +1,138 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Pure-numpy core for the n-gram (prompt-lookup) speculative drafter. + +Split out from ``ngram.py`` so the algorithm can be exercised without +pulling in the runtime/torch/tokenspeed_kernel import chain. +""" + +from __future__ import annotations + +import numpy as np + + +def find_longest_matched_ngram_and_propose_tokens( + origin_tokens: np.ndarray, + min_ngram: int, + max_ngram: int, + k: int, +) -> np.ndarray: + """Find the longest ngram suffix-match in ``origin_tokens`` of length + within ``[min_ngram, max_ngram]`` and return up to ``k`` tokens that + follow the rightmost match. + + Returns an empty array when ``len(origin_tokens) < min_ngram``, ``k <= + 0``, or no ngram of length >= ``min_ngram`` is found. Matches are + ranked by the longest suffix length found within ``max_ngram``; ties + follow the scan order of the reversed KMP pass. + """ + total = origin_tokens.shape[0] + if total < min_ngram or k <= 0: + return np.empty((0,), dtype=origin_tokens.dtype) + + # Work on the reversed token sequence so that "rightmost match in + # original" becomes "match closest to the front of the reversed + # sequence". Track the longest prefix-as-suffix in the reversed view, + # capped at max_ngram. + tokens = origin_tokens[::-1] + lps = np.zeros(max_ngram, dtype=np.int32) + + longest_ngram = 0 + position = 0 + prev_lps = 0 + i = 1 + while i < total: + if tokens[prev_lps] == tokens[i]: + prev_lps += 1 + if prev_lps >= longest_ngram: + longest_ngram = prev_lps + position = i + if i < max_ngram: + lps[i] = prev_lps + if prev_lps == max_ngram: + prev_lps = lps[max_ngram - 1] + i += 1 + elif prev_lps != 0: + prev_lps = lps[prev_lps - 1] + else: + i += 1 + + if longest_ngram < min_ngram: + return np.empty((0,), dtype=origin_tokens.dtype) + + # In origin_tokens, the matched ngram lives at indices + # [total-1-position : total-1-position+longest_ngram]; drafts start + # right after it. + start = total - 1 - position + longest_ngram + take = min(k, total - start) + if take <= 0: + return np.empty((0,), dtype=origin_tokens.dtype) + return origin_tokens[start : start + take] + + +def propose_batch_into( + history: np.ndarray, + history_len: np.ndarray, + pool_indices: np.ndarray, + out: np.ndarray, + min_ngram: int, + max_ngram: int, + spec_num_steps: int, +) -> None: + """Fill ``out[i]`` with ``[last_verified, d1, ..., d_K]`` for each + request in the batch by running the KMP suffix-ngram lookup against + that request's history slot. + + ``out`` must have shape ``(bs, spec_num_steps + 1)`` and is written in + place. Slots whose ``history_len`` is zero are zeroed defensively; a + well-formed run never proposes before prefill). Slots with no match + are padded with ``last_verified`` to preserve the fixed verify width + without adding a dedicated no-match mask. + """ + bs = pool_indices.shape[0] + spec_num_tokens = spec_num_steps + 1 + if out.shape != (bs, spec_num_tokens): + raise ValueError( + f"out must be (bs, {spec_num_tokens})-shaped, got {out.shape!r}" + ) + + for i in range(bs): + pool_idx = int(pool_indices[i]) + length = int(history_len[pool_idx]) + if length == 0: + out[i, :] = 0 + continue + + ctx = history[pool_idx, :length] + last_verified = int(ctx[-1]) + out[i, 0] = last_verified + + drafts = find_longest_matched_ngram_and_propose_tokens( + ctx, + min_ngram=min_ngram, + max_ngram=max_ngram, + k=spec_num_steps, + ) + n = drafts.size + if n > 0: + out[i, 1 : 1 + n] = drafts + if 1 + n < spec_num_tokens: + out[i, 1 + n :] = last_verified diff --git a/python/tokenspeed/runtime/execution/model_executor.py b/python/tokenspeed/runtime/execution/model_executor.py index b224124fc..1773ee97f 100644 --- a/python/tokenspeed/runtime/execution/model_executor.py +++ b/python/tokenspeed/runtime/execution/model_executor.py @@ -36,6 +36,7 @@ from tokenspeed.runtime.execution.context import ForwardContext from tokenspeed.runtime.execution.cuda_graph_wrapper import CudaGraphWrapper from tokenspeed.runtime.execution.drafter.eagle import Eagle +from tokenspeed.runtime.execution.drafter.ngram import NgramDrafter from tokenspeed.runtime.execution.forward_batch_info import ( CaptureHiddenMode, ForwardMode, @@ -59,7 +60,7 @@ logger = get_colorful_logger(__name__) -_DRAFTER_MAPPING = {"EAGLE3": Eagle, "MTP": Eagle} +_DRAFTER_MAPPING = {"EAGLE3": Eagle, "MTP": Eagle, "NGRAM": NgramDrafter} @dataclass @@ -75,7 +76,7 @@ class ModelExecutorConfig: enforce_eager: bool block_size: int max_num_seqs: int - chunked_prefill_size: int + max_num_input_tokens: int vocab_size: int context_len: int device: str @@ -97,6 +98,9 @@ class ModelExecutorConfig: spec_num_steps: int | None = None # spec_num_tokens == spec_num_steps + 1 for now (without Tree Attention) spec_num_tokens: int | None = None + # NGRAM-only: bounds for KMP suffix-ngram lookup. + spec_ngram_min: int = 1 + spec_ngram_max: int = 3 # ====== GRAMMAR ========= # "none" disables all grammar handling; otherwise the backend name @@ -121,13 +125,18 @@ def from_server_args( if server_args.speculative_algorithm else 1 ) + max_num_input_tokens = ( + server_args.chunked_prefill_size + if server_args.chunked_prefill_size > 0 + else server_args.max_prefill_tokens + server_args.max_model_len + ) return ModelExecutorConfig( max_req_pool_size=max_req_pool_size, output_length=output_length, enforce_eager=server_args.enforce_eager, block_size=server_args.block_size, max_num_seqs=server_args.max_num_seqs, - chunked_prefill_size=server_args.chunked_prefill_size, + max_num_input_tokens=max_num_input_tokens, vocab_size=model_config.vocab_size, context_len=model_config.context_len, device=server_args.device, @@ -144,6 +153,8 @@ def from_server_args( spec_algo=server_args.speculative_algorithm, spec_num_steps=server_args.speculative_num_steps, spec_num_tokens=server_args.speculative_num_draft_tokens, + spec_ngram_min=server_args.speculative_ngram_min, + spec_ngram_max=server_args.speculative_ngram_max, grammar_backend=server_args.grammar_backend, disable_capturable_grammar=server_args.disable_capturable_grammar, mamba_cache_chunk_size=server_args.mamba_cache_chunk_size, @@ -193,7 +204,7 @@ def __init__( spec_num_tokens = config.spec_num_tokens if config.spec_algo is not None else 1 self.input_buffers = InputBuffers( max_bs=config.max_num_seqs // max(config.data_parallel_size, 1), - max_num_tokens=config.chunked_prefill_size, + max_num_tokens=config.max_num_input_tokens, page_size=config.block_size, # token_to_kv_pool allocates size+page_size slots; index `size` is # the reserved dummy slot (see MHATokenToKVPool._create_buffers). @@ -211,24 +222,39 @@ def __init__( ) if self.config.spec_algo is not None: DrafterImpl = _DRAFTER_MAPPING[config.spec_algo] - self.drafter = DrafterImpl( - spec_num_tokens=config.spec_num_tokens, - spec_num_steps=config.spec_num_steps, - draft_model_runner=draft_model_runner, - page_size=config.block_size, - runtime_states=self.runtime_states, - input_buffers=self.input_buffers, - req_to_page=self.req_to_page, - attn_backend=draft_attn_backend, - token_to_kv_pool=draft_token_to_kv_pool, - vocab_size=config.vocab_size, - ) - embed, head = self.model_runner.model.get_embed_and_head() - draft_model_runner.model.set_embed_and_head(embed, head) - if config.spec_algo in ("EAGLE3",) and hasattr( - self.model_runner.model, "set_eagle3_layers_to_capture" - ): - self.model_runner.model.set_eagle3_layers_to_capture() + if config.spec_algo == "NGRAM": + # NGRAM is a prompt-lookup proposer with no draft model, + # no draft attention backend, and no draft KV pool. The + # executor must not have been wired with any of those. + self.drafter = DrafterImpl( + spec_num_tokens=config.spec_num_tokens, + spec_num_steps=config.spec_num_steps, + runtime_states=self.runtime_states, + input_buffers=self.input_buffers, + max_context_len=config.context_len, + vocab_size=config.vocab_size, + min_ngram=config.spec_ngram_min, + max_ngram=config.spec_ngram_max, + ) + else: + self.drafter = DrafterImpl( + spec_num_tokens=config.spec_num_tokens, + spec_num_steps=config.spec_num_steps, + draft_model_runner=draft_model_runner, + page_size=config.block_size, + runtime_states=self.runtime_states, + input_buffers=self.input_buffers, + req_to_page=self.req_to_page, + attn_backend=draft_attn_backend, + token_to_kv_pool=draft_token_to_kv_pool, + vocab_size=config.vocab_size, + ) + embed, head = self.model_runner.model.get_embed_and_head() + draft_model_runner.model.set_embed_and_head(embed, head) + if config.spec_algo in ("EAGLE3",) and hasattr( + self.model_runner.model, "set_eagle3_layers_to_capture" + ): + self.model_runner.model.set_eagle3_layers_to_capture() else: self.drafter = None @@ -694,10 +720,12 @@ def execute_idle_forward( input_lengths=empty, ) - # If a drafter is active, its model also has MoE layers that issue - # NCCL collectives. Idle ranks must match those collectives: - # 1 first-step forward + (spec_num_steps - 1) multi-step decode forwards. - if self.drafter is not None: + # If a drafter with a draft model is active, that model has MoE + # layers that issue NCCL collectives. Idle ranks must match those + # collectives: 1 first-step forward + (spec_num_steps - 1) multi-step + # decode forwards. NGRAM has no draft model and therefore nothing + # extra to match here. + if self.drafter is not None and self.drafter.draft_model_runner is not None: draft_ctx = ForwardContext( attn_backend=self.drafter.attn_backend, token_to_kv_pool=self.drafter.token_to_kv_pool, @@ -872,7 +900,10 @@ def execute_forward_op( forward_mode=forward_mode, capture_hidden_mode=( CaptureHiddenMode.FULL - if self.drafter is not None + if ( + self.drafter is not None + and self.config.spec_algo != "NGRAM" + ) else CaptureHiddenMode.NULL ), padded_static_len=-1, diff --git a/python/tokenspeed/runtime/spec_decode/algorithm.py b/python/tokenspeed/runtime/spec_decode/algorithm.py index 7f7b10ba5..36688e5b9 100755 --- a/python/tokenspeed/runtime/spec_decode/algorithm.py +++ b/python/tokenspeed/runtime/spec_decode/algorithm.py @@ -25,11 +25,13 @@ class SpeculativeAlgorithm(IntEnum): NONE = auto() EAGLE3 = auto() MTP = auto() + NGRAM = auto() def is_none(self) -> bool: return self == SpeculativeAlgorithm.NONE def needs_draft_decode_prealloc(self) -> bool: + # NGRAM has no draft model and therefore no draft KV slots to reserve. return self in (SpeculativeAlgorithm.EAGLE3, SpeculativeAlgorithm.MTP) @staticmethod @@ -37,6 +39,7 @@ def from_string(name: str | None) -> "SpeculativeAlgorithm": name_map = { "EAGLE3": SpeculativeAlgorithm.EAGLE3, "MTP": SpeculativeAlgorithm.MTP, + "NGRAM": SpeculativeAlgorithm.NGRAM, None: SpeculativeAlgorithm.NONE, } if name is not None: diff --git a/python/tokenspeed/runtime/utils/server_args.py b/python/tokenspeed/runtime/utils/server_args.py index 3b8aba51d..634c03adf 100755 --- a/python/tokenspeed/runtime/utils/server_args.py +++ b/python/tokenspeed/runtime/utils/server_args.py @@ -210,6 +210,9 @@ class ServerArgs: speculative_num_steps: int = 3 speculative_eagle_topk: int = 1 speculative_num_draft_tokens: int | None = None + # NGRAM-only: bounds for the KMP suffix-ngram lookup. + speculative_ngram_min: int = 1 + speculative_ngram_max: int = 3 eagle3_layers_to_capture: str | None = None # Logprob support flags — all OFF by default. Enabling extends the # captured CUDA-graph footprint; requests asking for logprobs on a @@ -486,6 +489,10 @@ def resolve_cache(self): self.validate_cache_options() def resolve_speculative_decoding(self): + if self.speculative_algorithm == "NGRAM": + self._resolve_ngram_speculative_decoding() + return + # Keep drafter backend consistent with the main model unless explicitly set. if ( self.speculative_algorithm is not None @@ -513,6 +520,88 @@ def resolve_speculative_decoding(self): int(x) for x in self.eagle3_layers_to_capture.split(",") ] + def _resolve_ngram_speculative_decoding(self): + """NGRAM is a prompt-lookup proposer with no draft model. + + It reuses the chain verify path, so ``speculative_num_steps`` and + ``speculative_num_draft_tokens`` keep the chain contract + (``draft_tokens == steps + 1``). Reject configurations that don't + make sense for a CPU-side lookup-only drafter to surface mistakes + at startup instead of mid-run. + """ + if self.speculative_draft_model_path is not None: + raise ValueError( + "--speculative-algorithm NGRAM does not use a draft model; " + "remove --speculative-draft-model-path." + ) + if self.speculative_eagle_topk != 1: + raise ValueError( + "--speculative-algorithm NGRAM only supports topk=1 (chain), " + f"got --speculative-eagle-topk={self.speculative_eagle_topk}." + ) + expected_draft_tokens = self.speculative_num_steps + 1 + if self.speculative_num_draft_tokens != expected_draft_tokens: + raise ValueError( + "--speculative-algorithm NGRAM requires " + "--speculative-num-draft-tokens == --speculative-num-steps + 1, " + f"got {self.speculative_num_draft_tokens} != " + f"{expected_draft_tokens}." + ) + if self.speculative_ngram_min < 1: + raise ValueError( + "--speculative-ngram-min must be >= 1, got " + f"{self.speculative_ngram_min}." + ) + if self.speculative_ngram_max < self.speculative_ngram_min: + raise ValueError( + "--speculative-ngram-max must be >= --speculative-ngram-min, got " + f"{self.speculative_ngram_max} < {self.speculative_ngram_min}." + ) + if self.disaggregation_mode != "null": + raise ValueError( + "--speculative-algorithm NGRAM is not yet supported under " + "prefill/decode disaggregation." + ) + + # The first NGRAM implementation keeps its own per-request CPU token + # history. Disable features that can make that history diverge from + # the committed request stream until the runtime exposes a reset / + # commit signal for those paths. + if self.enable_prefix_caching: + logger.warning( + "--speculative-algorithm NGRAM disables prefix caching in this " + "release." + ) + self.enable_prefix_caching = False + if not self.enable_prefix_caching and self.enable_kvstore: + logger.warning( + "--speculative-algorithm NGRAM disables KVStore because prefix " + "caching is disabled in this release." + ) + self.enable_kvstore = False + self.disable_kvstore = True + if self.chunked_prefill_size != -1: + logger.warning( + "--speculative-algorithm NGRAM disables chunked prefill in this " + "release." + ) + self.chunked_prefill_size = -1 + + # NGRAM's drafter runs on CPU and is not part of the captured + # graph in this first cut. Force eager so users don't silently + # hit graph-capture failures during warmup. + if not self.enforce_eager: + logger.warning( + "--speculative-algorithm NGRAM forces --enforce-eager in this " + "release; CUDA-graph capture will be enabled in a follow-up PR." + ) + self.enforce_eager = True + + # NGRAM has no drafter attention path. + self.drafter_attention_backend = None + self.speculative_draft_model_quantization = None + self.draft_model_path_use_base = False + def resolve_communication(self): # Auto-enable allreduce fusion on supported single-node TP configurations. platform = current_platform() @@ -1338,7 +1427,7 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--speculative-algorithm", type=str, - choices=["EAGLE3", "MTP"], + choices=["EAGLE3", "MTP", "NGRAM"], help="Speculative algorithm.", ) parser.add_argument( @@ -1371,6 +1460,18 @@ def add_cli_args(parser: argparse.ArgumentParser): help="The number of tokens sampled from the draft model in Speculative Decoding.", default=ServerArgs.speculative_num_draft_tokens, ) + parser.add_argument( + "--speculative-ngram-min", + type=int, + default=ServerArgs.speculative_ngram_min, + help="Minimum n-gram length (NGRAM algorithm only).", + ) + parser.add_argument( + "--speculative-ngram-max", + type=int, + default=ServerArgs.speculative_ngram_max, + help="Maximum n-gram length (NGRAM algorithm only).", + ) parser.add_argument( "--enable-output-logprobs", action="store_true", diff --git a/test/runtime/test_cli_config_compat.py b/test/runtime/test_cli_config_compat.py index 3cc53e480..d3eebf329 100644 --- a/test/runtime/test_cli_config_compat.py +++ b/test/runtime/test_cli_config_compat.py @@ -475,6 +475,146 @@ def test_speculative_draft_tokens_default_to_steps_plus_one(self): self.assertEqual(sa.speculative_num_steps, 1) self.assertEqual(sa.speculative_num_draft_tokens, 2) + # N-gram (prompt-lookup) speculative decoding + + def _from_cli_args_for_ngram(self, argv: list[str]) -> ServerArgs: + """Parse argv and apply the basic + speculative resolvers used at + startup so NGRAM-specific validation runs. + + We deliberately avoid the full ``__post_init__`` so the test + does not pull in GPU/parallelism resolution. + """ + args = self._parse_args(argv) + sa = self._from_cli_args_no_init(args) + sa.resolve_basic_defaults() + sa.resolve_speculative_decoding() + return sa + + def test_ngram_via_speculative_algorithm(self): + sa = self._from_cli_args_for_ngram( + [ + "--model", + "test/model", + "--speculative-algorithm", + "NGRAM", + "--speculative-num-steps", + "5", + "--speculative-ngram-min", + "2", + "--speculative-ngram-max", + "4", + ] + ) + self.assertEqual(sa.speculative_algorithm, "NGRAM") + self.assertEqual(sa.speculative_num_steps, 5) + self.assertEqual(sa.speculative_num_draft_tokens, 6) + self.assertEqual(sa.speculative_ngram_min, 2) + self.assertEqual(sa.speculative_ngram_max, 4) + self.assertIsNone(sa.speculative_draft_model_path) + self.assertIsNone(sa.drafter_attention_backend) + self.assertTrue(sa.enforce_eager) + self.assertFalse(sa.enable_prefix_caching) + self.assertEqual(sa.chunked_prefill_size, -1) + + def test_ngram_via_speculative_config(self): + sa = self._from_cli_args_for_ngram( + [ + "--model", + "test/model", + "--speculative-config", + '{"method":"ngram","num_speculative_tokens":2}', + ] + ) + self.assertEqual(sa.speculative_algorithm, "NGRAM") + self.assertEqual(sa.speculative_num_steps, 2) + self.assertEqual(sa.speculative_num_draft_tokens, 3) + + def test_ngram_rejects_draft_model_path(self): + with self.assertRaisesRegex(ValueError, "NGRAM does not use a draft model"): + self._from_cli_args_for_ngram( + [ + "--model", + "test/model", + "--speculative-algorithm", + "NGRAM", + "--speculative-draft-model-path", + "some/draft", + ] + ) + + def test_ngram_rejects_topk_other_than_one(self): + with self.assertRaisesRegex(ValueError, "NGRAM only supports topk=1"): + self._from_cli_args_for_ngram( + [ + "--model", + "test/model", + "--speculative-algorithm", + "NGRAM", + "--speculative-eagle-topk", + "2", + ] + ) + + def test_ngram_rejects_mismatched_draft_token_width(self): + with self.assertRaisesRegex( + ValueError, "NGRAM requires .*num-draft-tokens" + ): + self._from_cli_args_for_ngram( + [ + "--model", + "test/model", + "--speculative-algorithm", + "NGRAM", + "--speculative-num-steps", + "3", + "--speculative-num-draft-tokens", + "3", + ] + ) + + def test_ngram_rejects_invalid_ngram_bounds(self): + with self.assertRaisesRegex(ValueError, "--speculative-ngram-min must be >= 1"): + self._from_cli_args_for_ngram( + [ + "--model", + "test/model", + "--speculative-algorithm", + "NGRAM", + "--speculative-ngram-min", + "0", + ] + ) + with self.assertRaisesRegex( + ValueError, "--speculative-ngram-max must be >= --speculative-ngram-min" + ): + self._from_cli_args_for_ngram( + [ + "--model", + "test/model", + "--speculative-algorithm", + "NGRAM", + "--speculative-ngram-min", + "4", + "--speculative-ngram-max", + "2", + ] + ) + + def test_ngram_rejects_disaggregation_mode(self): + with self.assertRaisesRegex( + ValueError, "NGRAM is not yet supported under .* disaggregation" + ): + self._from_cli_args_for_ngram( + [ + "--model", + "test/model", + "--speculative-algorithm", + "NGRAM", + "--disaggregation-mode", + "prefill", + ] + ) + # ---- Full server command example ---- def test_full_server_command(self): diff --git a/test/runtime/test_spec_decode_ngram.py b/test/runtime/test_spec_decode_ngram.py new file mode 100644 index 000000000..43760696e --- /dev/null +++ b/test/runtime/test_spec_decode_ngram.py @@ -0,0 +1,187 @@ +"""Unit tests for the n-gram (prompt-lookup) speculative drafter. + +Covers two layers, both written against the dep-free +``runtime/execution/drafter/ngram_lookup`` module so they can run +without torch / tokenspeed_kernel: + +1. The pure-numpy KMP suffix-ngram lookup + ``find_longest_matched_ngram_and_propose_tokens``. +2. ``propose_batch_into``, the batched in-place row builder that the + ``NgramDrafter`` wrapper delegates to. +""" + +import os +import sys +import unittest + +import numpy as np + +# CI registration (AST-parsed, runtime no-op). +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from ci_system.ci_register import register_cuda_ci # noqa: E402 + +register_cuda_ci(est_time=10, suite="runtime-1gpu") + +from tokenspeed.runtime.execution.drafter.ngram_lookup import ( # noqa: E402 + find_longest_matched_ngram_and_propose_tokens, + propose_batch_into, +) + + +class TestFindLongestMatchedNgram(unittest.TestCase): + """Pure-function tests for the KMP suffix-ngram lookup.""" + + def _propose(self, tokens, *, min_n=1, max_n=3, k=2): + return find_longest_matched_ngram_and_propose_tokens( + np.asarray(tokens, dtype=np.int32), + min_ngram=min_n, + max_ngram=max_n, + k=k, + ).tolist() + + def test_returns_empty_when_context_shorter_than_min(self): + self.assertEqual(self._propose([], min_n=1, k=3), []) + self.assertEqual(self._propose([7], min_n=2, k=3), []) + + def test_returns_empty_when_k_is_nonpositive(self): + self.assertEqual(self._propose([1, 2, 1, 2], k=0), []) + + def test_returns_empty_when_no_ngram_repeats(self): + self.assertEqual(self._propose([10, 11, 12, 13], min_n=1, k=2), []) + + def test_picks_continuation_after_rightmost_match(self): + # Sequence repeats "1, 2": the most recent match suffix is the + # final "1, 2"; drafts come from right after the earlier "1, 2". + tokens = [1, 2, 9, 1, 2] + drafts = self._propose(tokens, min_n=2, max_n=2, k=2) + self.assertEqual(drafts, [9, 1]) + + def test_prefers_longest_match_within_bounds(self): + # Two competing matches: "B C" (len 2) and "A B C" (len 3). With + # max_ngram >= 3 the longer match wins. + tokens = [5, 5, 5, 0, 1, 2, 9, 9, 0, 1, 2] + drafts = self._propose(tokens, min_n=1, max_n=3, k=2) + self.assertEqual(drafts, [9, 9]) + + def test_caps_at_max_ngram(self): + # Suffix "0 1 2" *would* match as a len-3 ngram. With max_ngram=2 + # we force the shorter match and the draft start position moves. + tokens = [0, 1, 2, 9, 0, 1, 2] + drafts = self._propose(tokens, min_n=1, max_n=2, k=2) + self.assertEqual(drafts, [9, 0]) + + def test_returns_fewer_than_k_when_context_exhausted(self): + # The earlier "1 2" sits at indices 0..1; tokens 2..4 follow it, + # so with k=10 we still only get those 3 tokens. + tokens = [1, 2, 9, 1, 2] + drafts = self._propose(tokens, min_n=2, max_n=2, k=10) + self.assertEqual(drafts, [9, 1, 2]) + + +class TestProposeBatchInto(unittest.TestCase): + """Verify the row layout written by the batched proposer.""" + + def _new_batch(self, *, max_bs=4, max_context_len=128, spec_num_steps=3): + history = np.zeros((max_bs, max_context_len), dtype=np.int32) + history_len = np.zeros((max_bs,), dtype=np.int32) + out = np.zeros((max_bs, spec_num_steps + 1), dtype=np.int32) + return history, history_len, out + + def _seed(self, history, history_len, slot, tokens): + tokens = np.asarray(tokens, dtype=np.int32) + history[slot, : tokens.size] = tokens + history_len[slot] = tokens.size + + def test_layout_for_matching_history(self): + history, history_len, out = self._new_batch(spec_num_steps=3) + self._seed(history, history_len, slot=0, tokens=[1, 2, 3, 1, 2]) + self._seed(history, history_len, slot=1, tokens=[9, 8, 7]) + + pool_indices = np.array([0, 1], dtype=np.int32) + propose_batch_into( + history=history, + history_len=history_len, + pool_indices=pool_indices, + out=out[: pool_indices.size], + min_ngram=1, + max_ngram=3, + spec_num_steps=3, + ) + + row0 = out[0].tolist() + row1 = out[1].tolist() + # First column is always last_verified (= last history token). + self.assertEqual(row0[0], 2) + self.assertEqual(row1[0], 7) + # Slot 0: KMP picks the rightmost "1 2"; continuation starts with 3. + self.assertEqual(row0[1], 3) + # Slot 1: no repeat. Draft columns fall back to ``last_verified`` + # to preserve the fixed verify width without a no-match mask. + self.assertEqual(row1[1:], [7, 7, 7]) + + def test_pads_remaining_columns_with_last_verified(self): + history, history_len, out = self._new_batch(spec_num_steps=4) + self._seed(history, history_len, slot=0, tokens=[4, 5, 6, 4, 5]) + + pool_indices = np.array([0], dtype=np.int32) + propose_batch_into( + history=history, + history_len=history_len, + pool_indices=pool_indices, + out=out[: pool_indices.size], + min_ngram=1, + max_ngram=3, + spec_num_steps=4, + ) + + # Match suffix "4 5"; continuation = [6, 4, 5]. Trailing column + # padded with last_verified=5. + self.assertEqual(out[0].tolist(), [5, 6, 4, 5, 5]) + + def test_zero_history_row_emits_zeros(self): + history, history_len, out = self._new_batch(spec_num_steps=2) + pool_indices = np.array([0], dtype=np.int32) + propose_batch_into( + history=history, + history_len=history_len, + pool_indices=pool_indices, + out=out[: pool_indices.size], + min_ngram=1, + max_ngram=3, + spec_num_steps=2, + ) + self.assertEqual(out[0].tolist(), [0, 0, 0]) + + def test_rejects_wrong_out_shape(self): + history, history_len, _ = self._new_batch(spec_num_steps=3) + pool_indices = np.array([0], dtype=np.int32) + bad_out = np.zeros((1, 5), dtype=np.int32) # spec_num_steps+1 should be 4 + with self.assertRaises(ValueError): + propose_batch_into( + history=history, + history_len=history_len, + pool_indices=pool_indices, + out=bad_out, + min_ngram=1, + max_ngram=3, + spec_num_steps=3, + ) + + def test_rejects_wrong_out_batch_size(self): + history, history_len, _ = self._new_batch(spec_num_steps=3) + pool_indices = np.array([0, 1], dtype=np.int32) + bad_out = np.zeros((1, 4), dtype=np.int32) + with self.assertRaises(ValueError): + propose_batch_into( + history=history, + history_len=history_len, + pool_indices=pool_indices, + out=bad_out, + min_ngram=1, + max_ngram=3, + spec_num_steps=3, + ) + + +if __name__ == "__main__": + unittest.main() From fb611a5acef06149823487d2be5d7d616b7bf22d Mon Sep 17 00:00:00 2001 From: yongjunlee Date: Thu, 14 May 2026 10:35:11 +0000 Subject: [PATCH 2/2] fix(spec-decode): preserve chunked-prefill budget for NGRAM NGRAM previously forced `chunked_prefill_size = -1` to disable chunked prefill. That value is also used as a scheduler-facing token capacity, so propagating the negative sentinel can prevent SMG startup from completing. Leave the resolved chunked-prefill budget untouched for NGRAM. The drafter's per-pool token history is updated from the actual extend/verify stream, so it does not require chunked prefill to be disabled. Tests: - `python3 -m pytest test/runtime/test_cli_config_compat.py -q` - `python3 -m pytest test/runtime/test_spec_decode_ngram.py -q` Signed-off-by: yongjunlee --- docs/configuration/server.md | 6 +++--- python/tokenspeed/runtime/utils/server_args.py | 7 ------- test/runtime/test_cli_config_compat.py | 1 - 3 files changed, 3 insertions(+), 11 deletions(-) diff --git a/docs/configuration/server.md b/docs/configuration/server.md index 8b7bcf70a..fa77b50ff 100644 --- a/docs/configuration/server.md +++ b/docs/configuration/server.md @@ -141,9 +141,9 @@ The first release is intentionally narrow: - Single-rank only; running under PD disaggregation is rejected at startup. - The drafter runs outside the captured CUDA graph, so `--enforce-eager` is auto-enabled with a warning. -- `enable_prefix_caching`, `enable_kvstore`, and chunked prefill are - auto-disabled because the proposer keeps its own per-request token history - and prefix-cache hits would skip the prefill path it relies on. +- `enable_prefix_caching` and `enable_kvstore` are auto-disabled because the + proposer keeps its own per-request token history and prefix-cache hits would + skip the prefill path it relies on. - `--speculative-eagle-topk` must be `1` (chain only), and `--speculative-num-draft-tokens` must equal `--speculative-num-steps + 1`. diff --git a/python/tokenspeed/runtime/utils/server_args.py b/python/tokenspeed/runtime/utils/server_args.py index 634c03adf..e6f151f3d 100755 --- a/python/tokenspeed/runtime/utils/server_args.py +++ b/python/tokenspeed/runtime/utils/server_args.py @@ -580,13 +580,6 @@ def _resolve_ngram_speculative_decoding(self): ) self.enable_kvstore = False self.disable_kvstore = True - if self.chunked_prefill_size != -1: - logger.warning( - "--speculative-algorithm NGRAM disables chunked prefill in this " - "release." - ) - self.chunked_prefill_size = -1 - # NGRAM's drafter runs on CPU and is not part of the captured # graph in this first cut. Force eager so users don't silently # hit graph-capture failures during warmup. diff --git a/test/runtime/test_cli_config_compat.py b/test/runtime/test_cli_config_compat.py index d3eebf329..f633b5ab5 100644 --- a/test/runtime/test_cli_config_compat.py +++ b/test/runtime/test_cli_config_compat.py @@ -514,7 +514,6 @@ def test_ngram_via_speculative_algorithm(self): self.assertIsNone(sa.drafter_attention_backend) self.assertTrue(sa.enforce_eager) self.assertFalse(sa.enable_prefix_caching) - self.assertEqual(sa.chunked_prefill_size, -1) def test_ngram_via_speculative_config(self): sa = self._from_cli_args_for_ngram(