Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions docs/recipes/speculative_decoding.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Speculative Decoding Recipes

The commands below are templates. Validate exact model IDs, checkpoint formats,
and backend choices against the build you deploy.

## Llama 3.1 8B

```bash
tokenspeed serve nreHieW/Llama-3.1-8B-Instruct \
--speculative-algorithm EAGLE3 \
--speculative-draft-model-path lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B \
--speculative-num-steps 7 \
--host 0.0.0.0 \
--dtype bfloat16 \
--kvstore-size 16 \
--port 8000
```

## GPT-OSS 20B / 120B

```bash
tokenspeed serve openai/gpt-oss-20b \
--served-model-name gpt-oss-20b \
--speculative-algorithm EAGLE3 \
--speculative-draft-model-path Dogacel/specdrift-gpt-oss-20b-eagle3 \
--speculative-num-steps 3 \
--tensor-parallel-size 1 \
--max-model-len 131072 \
--chunked-prefill-size 8192 \
--reasoning-parser base \
--host 0.0.0.0 \
--port 8000
```

```bash
tokenspeed serve openai/gpt-oss-120b \
--served-model-name gpt-oss-120b \
--speculative-algorithm EAGLE3 \
--speculative-draft-model-path nvidia/gpt-oss-120b-Eagle3-long-context \
--speculative-num-steps 3 \
--tensor-parallel-size 1 \
--max-model-len 8192 \
--kv-cache-dtype fp8 \
--chunked-prefill-size 8192 \
--max-num-seqs 4 \
--reasoning-parser base \
--host 0.0.0.0 \
--port 8000
```

## Benchmarking

Against the GPT-OSS 120B server above:

```bash
tokenspeed bench serve \
--backend openai-chat \
--endpoint /v1/chat/completions \
--host 127.0.0.1 --port 8000 \
--dataset-name mtbench \
--input-len 1024 \
--output-len 1024 \
--num-prompts 80 \
--max-concurrency 16 \
--save-result --save-detailed --result-dir results/ \
--extra-body '{"temperature": 0}'
```

Make sure you choose `openai-chat` over `openai`, otherwise the missing chat template causes acceptance
rates to go down abnormally. You can inspect the incoming requests & chat template via passing `--enable-log-requests --log-requests-level 2`
to the server.

The result block ends with something like:

```
Total token throughput (tok/s): 1781.94
Mean accept length (tok/step): 3.25
```

A value near `1.0` means almost no draft tokens are being accepted, check the draft model config.
82 changes: 81 additions & 1 deletion python/tokenspeed/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@
DEFAULT_NUM_PROMPTS = 1000
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
MTBENCH_URL = "https://raw.githubusercontent.com/lm-sys/FastChat/main/fastchat/llm_judge/data/mt_bench/question.jsonl"
OPENAI_COMPATIBLE_BACKENDS = frozenset({"openai", "tokenspeed"})
logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -262,6 +263,7 @@ class BenchmarkMetrics:
percentiles_e2el_ms: list[tuple[float, float]]
max_output_tokens_per_s: float
max_concurrent_requests: int
mean_accept_length: float


def set_ulimit(target_soft_limit: int = 65535) -> None:
Expand Down Expand Up @@ -1049,6 +1051,58 @@ def sample_sharegpt_requests(
return samples


def sample_mtbench_requests(
dataset_path: str | None,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
fixed_output_len: int | None = None,
max_model_len: int | None = None,
apply_chat_template: bool = False,
skip_min_tokens_check: bool = False,
) -> list[SampleRequest]:
if not dataset_path:
dataset_path = download_and_cache_file(MTBENCH_URL)

questions: list[str] = []
with open(dataset_path, encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
entry = json.loads(line)
turns = entry.get("turns", [])
if turns:
questions.append(turns[0])
random.shuffle(questions)

samples: list[SampleRequest] = []
valid: list[SampleRequest] = []
for prompt in questions:
if apply_chat_template:
prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
add_generation_prompt=True,
tokenize=False,
)
if tokenizer.bos_token:
prompt = prompt.replace(tokenizer.bos_token, "")
prompt_len = len(tokenizer.encode(prompt))
output_len = fixed_output_len if fixed_output_len is not None else 256
if not is_valid_sequence(
prompt_len, output_len, max_model_len, skip_min_tokens_check
):
continue
valid.append(SampleRequest(prompt, prompt_len, output_len))
for i in range(num_requests):
if not valid:
break
samples.append(valid[i % len(valid)])

print(f"#Input tokens: {sum(x.prompt_len for x in samples)}")
print(f"#Output tokens: {sum(x.expected_output_len for x in samples)}")
return samples


def sample_random_requests(
input_len: int,
output_len: int,
Expand Down Expand Up @@ -1091,6 +1145,16 @@ def get_samples(
apply_chat_template=args.apply_chat_template,
skip_min_tokens_check=args.skip_min_tokens_check,
)
if args.dataset_name == "mtbench":
return sample_mtbench_requests(
dataset_path=args.dataset_path,
num_requests=args.num_prompts,
tokenizer=tokenizer,
fixed_output_len=args.mtbench_output_len,
max_model_len=args.max_model_len,
apply_chat_template=args.apply_chat_template,
skip_min_tokens_check=args.skip_min_tokens_check,
)
if args.dataset_name == "random":
return sample_random_requests(
input_len=args.random_input_len,
Expand Down Expand Up @@ -1253,6 +1317,8 @@ def calculate_metrics(
all_tpots: list[float] = []
ttfts: list[float] = []
e2els: list[float] = []
total_output_tokens_for_accept = 0
total_chunks_for_accept = 0

for output in outputs:
if output.success:
Expand All @@ -1277,6 +1343,9 @@ def calculate_metrics(
itls.extend(output.itl)
ttfts.append(output.ttft)
e2els.append(output.latency)
if output_len > 0 and output.itl:
total_output_tokens_for_accept += output_len
total_chunks_for_accept += len(output.itl)
completed += 1
else:
actual_output_lens.append(0)
Expand Down Expand Up @@ -1386,6 +1455,11 @@ def calculate_metrics(
],
max_output_tokens_per_s=max_output_tokens_per_s,
max_concurrent_requests=max_concurrent_requests,
mean_accept_length=(
total_output_tokens_for_accept / total_chunks_for_accept
if total_chunks_for_accept > 0
else 1.0
),
)
return metrics, actual_output_lens

Expand Down Expand Up @@ -1619,6 +1693,9 @@ async def limited_request_func(request_func_input, session, pbar):
metrics.total_token_throughput,
precision=2,
)
_print_metric_row(
"Mean accept length (tok/step):", metrics.mean_accept_length, precision=2
)

result: dict[str, Any] = {
"duration": benchmark_duration,
Expand All @@ -1630,6 +1707,7 @@ async def limited_request_func(request_func_input, session, pbar):
"request_goodput": metrics.request_goodput if goodput_config_dict else None,
"output_throughput": metrics.output_throughput,
"total_token_throughput": metrics.total_token_throughput,
"mean_accept_length": metrics.mean_accept_length,
"input_lens": [output.prompt_len for output in outputs],
"output_lens": actual_output_lens,
"ttfts": [output.ttft for output in outputs],
Expand Down Expand Up @@ -1749,7 +1827,7 @@ def add_dataset_parser(parser: argparse.ArgumentParser) -> None:
"--dataset-name",
type=str,
default="random",
choices=["sharegpt", "random"],
choices=["sharegpt", "random", "mtbench"],
help="Name of the dataset to benchmark on.",
)
parser.add_argument(
Expand All @@ -1761,6 +1839,7 @@ def add_dataset_parser(parser: argparse.ArgumentParser) -> None:
parser.add_argument("--max-model-len", type=int, default=None)
parser.add_argument("--skip-min-tokens-check", action="store_true")
parser.add_argument("--sharegpt-output-len", type=int, default=None)
parser.add_argument("--mtbench-output-len", type=int, default=None)
parser.add_argument("--random-input-len", type=int, default=1024)
parser.add_argument("--random-output-len", type=int, default=128)
parser.add_argument("--random-range-ratio", type=float, default=0.0)
Expand Down Expand Up @@ -1840,6 +1919,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
if args.output_len is not None:
args.random_output_len = args.output_len
args.sharegpt_output_len = args.output_len
args.mtbench_output_len = args.output_len

if args.ramp_up_strategy is not None:
if args.request_rate != float("inf"):
Expand Down
10 changes: 9 additions & 1 deletion python/tokenspeed/runtime/execution/model_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,15 @@ def __init__(
if config.spec_algo in ("EAGLE3",) and hasattr(
self.model_runner.model, "set_eagle3_layers_to_capture"
):
self.model_runner.model.set_eagle3_layers_to_capture()
draft_layer_ids = None
draft_model = getattr(draft_model_runner, "model", None)
if (
draft_model is not None
and hasattr(draft_model, "model")
and hasattr(draft_model.model, "get_aux_hidden_state_layer_ids")
):
draft_layer_ids = draft_model.model.get_aux_hidden_state_layer_ids()
self.model_runner.model.set_eagle3_layers_to_capture(draft_layer_ids)
else:
self.drafter = None

Expand Down
Loading