diff --git a/docs/recipes/speculative_decoding.md b/docs/recipes/speculative_decoding.md new file mode 100644 index 000000000..03d82b31a --- /dev/null +++ b/docs/recipes/speculative_decoding.md @@ -0,0 +1,80 @@ +# Speculative Decoding Recipes + +The commands below are templates. Validate exact model IDs, checkpoint formats, +and backend choices against the build you deploy. + +## Llama 3.1 8B + +```bash +tokenspeed serve nreHieW/Llama-3.1-8B-Instruct \ + --speculative-algorithm EAGLE3 \ + --speculative-draft-model-path lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B \ + --speculative-num-steps 7 \ + --host 0.0.0.0 \ + --dtype bfloat16 \ + --kvstore-size 16 \ + --port 8000 +``` + +## GPT-OSS 20B / 120B + +```bash +tokenspeed serve openai/gpt-oss-20b \ + --served-model-name gpt-oss-20b \ + --speculative-algorithm EAGLE3 \ + --speculative-draft-model-path Dogacel/specdrift-gpt-oss-20b-eagle3 \ + --speculative-num-steps 3 \ + --tensor-parallel-size 1 \ + --max-model-len 131072 \ + --chunked-prefill-size 8192 \ + --reasoning-parser base \ + --host 0.0.0.0 \ + --port 8000 +``` + +```bash +tokenspeed serve openai/gpt-oss-120b \ + --served-model-name gpt-oss-120b \ + --speculative-algorithm EAGLE3 \ + --speculative-draft-model-path nvidia/gpt-oss-120b-Eagle3-long-context \ + --speculative-num-steps 3 \ + --tensor-parallel-size 1 \ + --max-model-len 8192 \ + --kv-cache-dtype fp8 \ + --chunked-prefill-size 8192 \ + --max-num-seqs 4 \ + --reasoning-parser base \ + --host 0.0.0.0 \ + --port 8000 +``` + +## Benchmarking + +Against the GPT-OSS 120B server above: + +```bash +tokenspeed bench serve \ + --backend openai-chat \ + --endpoint /v1/chat/completions \ + --host 127.0.0.1 --port 8000 \ + --dataset-name mtbench \ + --input-len 1024 \ + --output-len 1024 \ + --num-prompts 80 \ + --max-concurrency 16 \ + --save-result --save-detailed --result-dir results/ \ + --extra-body '{"temperature": 0}' +``` + +Make sure you choose `openai-chat` over `openai`, otherwise the missing chat 
def sample_mtbench_requests(
    dataset_path: str | None,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: int | None = None,
    max_model_len: int | None = None,
    apply_chat_template: bool = False,
    skip_min_tokens_check: bool = False,
) -> list[SampleRequest]:
    """Build benchmark requests from the first turn of each MT-Bench question.

    Args:
        dataset_path: Path to a JSON-lines file whose records carry a
            ``"turns"`` list. When falsy, the file behind ``MTBENCH_URL`` is
            fetched through ``download_and_cache_file``.
        num_requests: Number of requests to return. When it exceeds the number
            of usable prompts, prompts are reused in round-robin order.
        tokenizer: Tokenizer used to measure prompt length and, optionally, to
            render the chat template.
        fixed_output_len: Expected output length for every request; 256 when
            ``None``.
        max_model_len: Forwarded to ``is_valid_sequence``.
        apply_chat_template: When true, wrap each question as a single user
            message, render it with the tokenizer's chat template, and remove
            the tokenizer's BOS string from the rendered text.
        skip_min_tokens_check: Forwarded to ``is_valid_sequence``.

    Returns:
        A list of at most ``num_requests`` ``SampleRequest`` objects; empty
        when no prompt passes ``is_valid_sequence``.
    """
    source_path = dataset_path or download_and_cache_file(MTBENCH_URL)

    # Lazy pipeline: each line is stripped, parsed and inspected before the
    # next line is read, so records are handled strictly in file order.
    with open(source_path, encoding="utf-8") as handle:
        stripped_lines = (raw.strip() for raw in handle)
        records = (json.loads(text) for text in stripped_lines if text)
        turn_lists = (record.get("turns", []) for record in records)
        first_turns: list[str] = [turns[0] for turns in turn_lists if turns]
    random.shuffle(first_turns)

    # The requested output length does not depend on the prompt.
    target_len = 256 if fixed_output_len is None else fixed_output_len

    candidates: list[SampleRequest] = []
    for question in first_turns:
        text = question
        if apply_chat_template:
            text = tokenizer.apply_chat_template(
                [{"role": "user", "content": text}],
                add_generation_prompt=True,
                tokenize=False,
            )
            if tokenizer.bos_token:
                text = text.replace(tokenizer.bos_token, "")
        token_count = len(tokenizer.encode(text))
        if is_valid_sequence(
            token_count, target_len, max_model_len, skip_min_tokens_check
        ):
            candidates.append(SampleRequest(text, token_count, target_len))

    # Cycle through the usable prompts until the requested count is reached.
    if candidates:
        pool_size = len(candidates)
        chosen: list[SampleRequest] = [
            candidates[index % pool_size] for index in range(num_requests)
        ]
    else:
        chosen = []

    print(f"#Input tokens: {sum(x.prompt_len for x in chosen)}")
    print(f"#Output tokens: {sum(x.expected_output_len for x in chosen)}")
    return chosen
len(output.itl) completed += 1 else: actual_output_lens.append(0) @@ -1386,6 +1455,11 @@ def calculate_metrics( ], max_output_tokens_per_s=max_output_tokens_per_s, max_concurrent_requests=max_concurrent_requests, + mean_accept_length=( + total_output_tokens_for_accept / total_chunks_for_accept + if total_chunks_for_accept > 0 + else 1.0 + ), ) return metrics, actual_output_lens @@ -1619,6 +1693,9 @@ async def limited_request_func(request_func_input, session, pbar): metrics.total_token_throughput, precision=2, ) + _print_metric_row( + "Mean accept length (tok/step):", metrics.mean_accept_length, precision=2 + ) result: dict[str, Any] = { "duration": benchmark_duration, @@ -1630,6 +1707,7 @@ async def limited_request_func(request_func_input, session, pbar): "request_goodput": metrics.request_goodput if goodput_config_dict else None, "output_throughput": metrics.output_throughput, "total_token_throughput": metrics.total_token_throughput, + "mean_accept_length": metrics.mean_accept_length, "input_lens": [output.prompt_len for output in outputs], "output_lens": actual_output_lens, "ttfts": [output.ttft for output in outputs], @@ -1749,7 +1827,7 @@ def add_dataset_parser(parser: argparse.ArgumentParser) -> None: "--dataset-name", type=str, default="random", - choices=["sharegpt", "random"], + choices=["sharegpt", "random", "mtbench"], help="Name of the dataset to benchmark on.", ) parser.add_argument( @@ -1761,6 +1839,7 @@ def add_dataset_parser(parser: argparse.ArgumentParser) -> None: parser.add_argument("--max-model-len", type=int, default=None) parser.add_argument("--skip-min-tokens-check", action="store_true") parser.add_argument("--sharegpt-output-len", type=int, default=None) + parser.add_argument("--mtbench-output-len", type=int, default=None) parser.add_argument("--random-input-len", type=int, default=1024) parser.add_argument("--random-output-len", type=int, default=128) parser.add_argument("--random-range-ratio", type=float, default=0.0) @@ -1840,6 +1919,7 @@ 
async def main_async(args: argparse.Namespace) -> dict[str, Any]: if args.output_len is not None: args.random_output_len = args.output_len args.sharegpt_output_len = args.output_len + args.mtbench_output_len = args.output_len if args.ramp_up_strategy is not None: if args.request_rate != float("inf"): diff --git a/python/tokenspeed/runtime/execution/model_executor.py b/python/tokenspeed/runtime/execution/model_executor.py index 2f1c61d12..fb8dd086a 100644 --- a/python/tokenspeed/runtime/execution/model_executor.py +++ b/python/tokenspeed/runtime/execution/model_executor.py @@ -228,7 +228,15 @@ def __init__( if config.spec_algo in ("EAGLE3",) and hasattr( self.model_runner.model, "set_eagle3_layers_to_capture" ): - self.model_runner.model.set_eagle3_layers_to_capture() + draft_layer_ids = None + draft_model = getattr(draft_model_runner, "model", None) + if ( + draft_model is not None + and hasattr(draft_model, "model") + and hasattr(draft_model.model, "get_aux_hidden_state_layer_ids") + ): + draft_layer_ids = draft_model.model.get_aux_hidden_state_layer_ids() + self.model_runner.model.set_eagle3_layers_to_capture(draft_layer_ids) else: self.drafter = None diff --git a/python/tokenspeed/runtime/models/llama_eagle3.py b/python/tokenspeed/runtime/models/llama_eagle3.py index 8e1e9ab13..76fdf72d7 100644 --- a/python/tokenspeed/runtime/models/llama_eagle3.py +++ b/python/tokenspeed/runtime/models/llama_eagle3.py @@ -471,16 +471,40 @@ def __init__( self.midlayer = self.layers[0] del self.layers - self.num_fc_input_dim = ( - len(config.eagle_aux_hidden_state_layer_ids) - if hasattr(config, "eagle_aux_hidden_state_layer_ids") - else 3 - ) + # Target-model hidden width; falls back to draft hidden_size when the + # draft and target share dimensions. 
def get_aux_hidden_state_layer_ids(self):
    """Return the auxiliary hidden-state layer ids declared by the draft config.

    Lookup order:

    1. ``self.config.eagle_config`` entry ``eagle_aux_hidden_state_layer_ids``
       (``eagle_config`` may be a mapping or an attribute-style object).
    2. ``self.config.eagle_aux_hidden_state_layer_ids`` (flat layout), so
       configs that declare the ids at the top level are still honoured.

    Returns:
        The configured layer ids exactly as stored in the config, or ``None``
        when neither location defines them.
    """
    key = "eagle_aux_hidden_state_layer_ids"
    config = self.config
    eagle_config = getattr(config, "eagle_config", None)

    layer_ids = None
    if eagle_config:
        if hasattr(eagle_config, "get"):
            layer_ids = eagle_config.get(key, None)
        else:
            # Attribute-style sub-config instead of a plain dict.
            layer_ids = getattr(eagle_config, key, None)

    if layer_ids is None:
        # Fall back to the flat layout rather than reporting "no ids".
        layer_ids = getattr(config, key, None)
    return layer_ids
num_tokens = input_ids.shape[0] - hidden_size = self.config.hidden_size + hidden_size_in = self.model.hidden_size_in num_fc = self.model.num_fc_input_dim model_kwargs["hidden_states"] = torch.zeros( num_tokens, - hidden_size * num_fc, + hidden_size_in * num_fc, dtype=torch.bfloat16, device=input_ids.device, ) @@ -613,7 +644,21 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None: (".gate_up_proj", ".up_proj", 1), ] - for name, loaded_weight in weights: + # Map incoming checkpoint names to current module attribute names. + legacy_name_map = { + "layers.0": "midlayer", + } + + loaded_param_names: set[str] = set() + unmatched_checkpoint_keys: list[str] = [] + self.hot_token_id = None + + for original_name, loaded_weight in weights: + name = original_name + for legacy, new in legacy_name_map.items(): + if legacy in name: + name = name.replace(legacy, new) + if "d2t" in name: self.hot_token_id = loaded_weight + torch.arange(loaded_weight.shape[0]) continue @@ -621,26 +666,64 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None: if "t2d" in name: continue + matched = False for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) - param_name = f"model.{name}" if name not in params_dict else name - if param_name in params_dict: - param = params_dict[param_name] + resolved = f"model.{name}" if name not in params_dict else name + if resolved in params_dict: + param = params_dict[resolved] weight_loader = getattr( param, "weight_loader", default_weight_loader ) weight_loader(param, loaded_weight, shard_id) + loaded_param_names.add(resolved) + matched = True break else: - param_name = name if name in params_dict else f"model.{name}" - if param_name in params_dict: - param = params_dict[param_name] + resolved = name if name in params_dict else f"model.{name}" + if resolved in params_dict: + param = params_dict[resolved] weight_loader = 
getattr( param, "weight_loader", default_weight_loader ) weight_loader(param, loaded_weight) + loaded_param_names.add(resolved) + matched = True + + if not matched: + unmatched_checkpoint_keys.append(original_name) + + # ``embed_tokens`` and ``lm_head`` are overwritten later by + # ``set_embed_and_head`` when draft/target hidden sizes match, so + # they're expected to be missing from the EAGLE checkpoint. + expected_missing = {"model.embed_tokens.weight", "lm_head.weight"} + missing_param_names = sorted( + set(params_dict) - loaded_param_names - expected_missing + ) + + if unmatched_checkpoint_keys: + logger.warning( + "EAGLE3 load_weights: %d checkpoint key(s) did not match any " + "module parameter and were dropped: %s", + len(unmatched_checkpoint_keys), + unmatched_checkpoint_keys, + ) + if missing_param_names: + logger.warning( + "EAGLE3 load_weights: %d module parameter(s) were not " + "populated from the checkpoint and remain at initial values: " + "%s", + len(missing_param_names), + missing_param_names, + ) + if self.hot_token_id is None: + logger.warning( + "EAGLE3 load_weights: no 'd2t' tensor found in the checkpoint; " + "draft-to-target vocab mapping will be missing and acceptance " + "will be incorrect." + ) def get_hot_token_id(self): return self.hot_token_id diff --git a/tokenspeed-kernel/python/setup.py b/tokenspeed-kernel/python/setup.py index 66c533134..75256d38f 100644 --- a/tokenspeed-kernel/python/setup.py +++ b/tokenspeed-kernel/python/setup.py @@ -453,6 +453,30 @@ def _detect_cuda_archs(self): archs.add(self._normalize_cuda_arch(direct)) return archs + # Try detecting from the NVIDIA driver. 
+ if not archs: + try: + caps = ( + subprocess.check_output( + [ + "nvidia-smi", + "--query-gpu=compute_cap", + "--format=csv,noheader", + ], + text=True, + stderr=subprocess.DEVNULL, + ) + .strip() + .splitlines() + ) + for cap in caps: + cap = cap.strip() + if cap: + archs.add(self._normalize_cuda_arch(cap + "a")) + except (OSError, subprocess.CalledProcessError): + pass + + # Fallback: Blackwell if not archs: archs.add("100a") return archs