diff --git a/examples/specdec_bench/run.py b/examples/specdec_bench/run.py
index f6c4e33e0..0a2599386 100644
--- a/examples/specdec_bench/run.py
+++ b/examples/specdec_bench/run.py
@@ -49,12 +49,14 @@ async def process_single_request(request, i):
         if request.system_prompt is not None:
             messages.append({"role": "system", "content": request.system_prompt})
 
-        for question in request.turns:
+        for turn_id, question in enumerate(request.turns):
             messages.append({"role": "user", "content": question})
             entry_encoded = encode_chat(tokenizer, messages)
 
             # Run the async runner.run directly
-            output_tokens = await runner.run(entry_encoded, max_length, end_id, i)
+            output_tokens = await runner.run(
+                entry_encoded, max_length, end_id, request_id=i, turn_id=turn_id
+            )
             output_text = decode_chat(tokenizer, output_tokens["output_ids"][0])
             output_text = postprocess(output_text)
             messages.append({"role": "assistant", "content": output_text})
diff --git a/examples/specdec_bench/specdec_bench/metrics/aa_timing.py b/examples/specdec_bench/specdec_bench/metrics/aa_timing.py
index a4084a9c3..21af35112 100644
--- a/examples/specdec_bench/specdec_bench/metrics/aa_timing.py
+++ b/examples/specdec_bench/specdec_bench/metrics/aa_timing.py
@@ -34,7 +34,7 @@ def __init__(self, base_tokenizer):
         self.base_tokenizer = base_tokenizer
         self.total_tokens = []
 
-    def process_step(self, step_outputs, new_turn=True):
+    def process_step(self, step_outputs, request_id, turn_id):
         self.timing.append(step_outputs["token_times"])
         target_tokens = [
             t for tok_list in step_outputs["output_ids"] for tok in tok_list for t in tok
diff --git a/examples/specdec_bench/specdec_bench/metrics/acceptance_rate.py b/examples/specdec_bench/specdec_bench/metrics/acceptance_rate.py
index 22f10091a..fffd9ebf7 100644
--- a/examples/specdec_bench/specdec_bench/metrics/acceptance_rate.py
+++ b/examples/specdec_bench/specdec_bench/metrics/acceptance_rate.py
@@ -22,15 +22,17 @@ class AcceptanceRate(Metric):
     def __init__(self):
         super().__init__()
-        self.prompt_ar = []
+        self.prompt_ar = {}
         self.name = "acceptance_rate"
 
-    def process_step(self, step_outputs, new_turn=True):
-        if new_turn:
-            self.prompt_ar.append([])
+    def process_step(self, step_outputs, request_id, turn_id):
+        if request_id not in self.prompt_ar:
+            self.prompt_ar[request_id] = {}
+        if turn_id not in self.prompt_ar[request_id]:
+            self.prompt_ar[request_id][turn_id] = []
         for i, beam_output in enumerate(step_outputs["output_ids"]):
             for output_id_iter in beam_output:
-                self.prompt_ar[-1].append(len(output_id_iter))
+                self.prompt_ar[request_id][turn_id].append(len(output_id_iter))
 
     def _get_lengths(self, turn, lengths):
         for j in turn:
@@ -55,16 +57,19 @@ def _process_lengths(self, lengths):
             running_len -= v
 
     def process_final(self, text_outputs):
-        i = 0
+        all_ar = []
         lengths = {}
         self.out["Request_AR"] = {}
-        while i < len(self.prompt_ar):
-            turn_1 = self.prompt_ar[i]
-            self.out["Request_AR"][i] = sum(turn_1) / len(turn_1)
-            self._get_lengths(turn_1, lengths)
-            print(i, self.out["Request_AR"][i])
-            i += 1
-        average_ar = sum(self.out["Request_AR"].values()) / len(self.out["Request_AR"])
+        self.prompt_ar = dict(sorted(self.prompt_ar.items(), key=lambda x: x[0]))
+        for request_id, turns in self.prompt_ar.items():
+            self.out["Request_AR"][request_id] = {}
+            for turn_id, turn in turns.items():
+                ar = sum(turn) / len(turn)
+                self.out["Request_AR"][request_id][turn_id] = ar
+                all_ar.append(ar)
+                self._get_lengths(turn, lengths)
+                print(request_id, turn_id, self.out["Request_AR"][request_id][turn_id])
+        average_ar = sum(all_ar) / len(all_ar)
         print("Average AR:", average_ar)
         self.out["Average_AR"] = average_ar
         self._process_lengths(lengths)
diff --git a/examples/specdec_bench/specdec_bench/metrics/base.py b/examples/specdec_bench/specdec_bench/metrics/base.py
index 3092aa8d3..62a2fbf2c 100644
--- a/examples/specdec_bench/specdec_bench/metrics/base.py
+++ b/examples/specdec_bench/specdec_bench/metrics/base.py
@@ -24,7 +24,7 @@ def __init__(self):
         self.out = {}
         self.name = "metric"
 
-    def process_step(self, step_outputs, new_turn=True):
+    def process_step(self, step_outputs, request_id, turn_id):
         raise NotImplementedError
 
     def process_final(self, text_outputs):
diff --git a/examples/specdec_bench/specdec_bench/metrics/mtbench.py b/examples/specdec_bench/specdec_bench/metrics/mtbench.py
index 2b6d8727b..0a7dd8b48 100644
--- a/examples/specdec_bench/specdec_bench/metrics/mtbench.py
+++ b/examples/specdec_bench/specdec_bench/metrics/mtbench.py
@@ -35,16 +35,16 @@ def process_final(self, text_outputs):
         i = 0
         lengths = {}
         self.out["Request_AR"] = {}
-        while i < len(self.prompt_ar):
-            turn_1 = self.prompt_ar[i]
-            turn_2 = self.prompt_ar[i + 1]
-            q_id = i // 2
+        self.prompt_ar = dict(sorted(self.prompt_ar.items(), key=lambda x: x[0]))
+        for request_id, turns in self.prompt_ar.items():
+            turn_1 = turns[0]
+            turn_2 = turns[1]
+            q_id = request_id
             mtbench_topic = MTBENCH_TOPICS[q_id // 10]
-            self.out["Request_AR"][q_id] = sum(turn_1 + turn_2) / len(turn_1 + turn_2)
+            self.out["Request_AR"][request_id] = sum(turn_1 + turn_2) / len(turn_1 + turn_2)
             self._get_lengths(turn_1, lengths)
             self._get_lengths(turn_2, lengths)
             print(mtbench_topic, sum(turn_1 + turn_2) / len(turn_1 + turn_2))
-            i += 2
         per_category = [[] for _ in range(len(MTBENCH_TOPICS))]
         for q_id, ar in self.out["Request_AR"].items():
             per_category[q_id // 10].append(ar)
diff --git a/examples/specdec_bench/specdec_bench/metrics/timing.py b/examples/specdec_bench/specdec_bench/metrics/timing.py
index 270ea697c..023aaf785 100644
--- a/examples/specdec_bench/specdec_bench/metrics/timing.py
+++ b/examples/specdec_bench/specdec_bench/metrics/timing.py
@@ -26,7 +26,7 @@ def __init__(self, tp_size):
         self.total_tokens = []
         self.tp_size = tp_size
 
-    def process_step(self, step_outputs, new_turn=True):
+    def process_step(self, step_outputs, request_id, turn_id):
         self.timing.append(step_outputs["token_times"])
         self.total_tokens.append(
             sum([sum([len(j) for j in i]) for i in step_outputs["output_ids"]])
@@ -42,8 +42,9 @@ def process_final(self, text_outputs):
         self.out["Output TPS"] = sum(self.total_tokens) / (end_time - start_time)
         self.out["Output TPS/gpu"] = self.out["Output TPS"] / self.tp_size
         for tokens, times in zip(self.total_tokens, self.timing):
-            e2e_time.append(times[-1] - times[0])
-            ttft_time.append(times[1] - times[0])
+            if len(times) > 1:
+                e2e_time.append(times[-1] - times[0])
+                ttft_time.append(times[1] - times[0])
             if len(times) > 2:
                 gen_tp_time.append((tokens - 1) / (times[-1] - times[1]))
                 tpot_time.extend([a - b for a, b in zip(times[1:], times[:-1])])
diff --git a/examples/specdec_bench/specdec_bench/models/base.py b/examples/specdec_bench/specdec_bench/models/base.py
index 5f3a9616a..42186fef0 100644
--- a/examples/specdec_bench/specdec_bench/models/base.py
+++ b/examples/specdec_bench/specdec_bench/models/base.py
@@ -18,7 +18,7 @@ class Model:
     def __init__(self, model_dir, tokenizer, max_draft_length):
         raise NotImplementedError
 
-    async def run(self, prompt_ids, max_length, end_id, request_id):
+    async def run(self, prompt_ids, max_length, end_id, request_id, turn_id):
         """
         prompt_ids is list of tokens
         output is list of list of tokens
diff --git a/examples/specdec_bench/specdec_bench/models/sglang.py b/examples/specdec_bench/specdec_bench/models/sglang.py
index 534303569..4840a0eda 100644
--- a/examples/specdec_bench/specdec_bench/models/sglang.py
+++ b/examples/specdec_bench/specdec_bench/models/sglang.py
@@ -50,6 +50,7 @@ def __init__(
             speculative_num_steps=kwargs.get("speculative_num_steps", 3),
             speculative_eagle_topk=kwargs.get("speculative_eagle_topk", 1),
             speculative_num_draft_tokens=kwargs.get("speculative_num_draft_tokens", 4),
+            speculative_draft_model_path=kwargs.get("draft_model_dir"),
             torch_compile_max_bs=max_concurrent_requests,
             attention_backend=kwargs.get("attention_backend"),
             enable_torch_compile=kwargs.get("enable_torch_compile", False),
@@ -70,7 +71,7 @@ def __init__(
 
         self.sampling_config = sampling_kwargs
 
-    async def run(self, prompt_ids, max_length, end_id, request_id):
+    async def run(self, prompt_ids, max_length, end_id, request_id, turn_id):
         timing = []
         output_dict = {}
         self.sampling_config["max_new_tokens"] = max_length
diff --git a/examples/specdec_bench/specdec_bench/models/trtllm_torch_api.py b/examples/specdec_bench/specdec_bench/models/trtllm_torch_api.py
index 4d4e6c92c..11ceeb207 100644
--- a/examples/specdec_bench/specdec_bench/models/trtllm_torch_api.py
+++ b/examples/specdec_bench/specdec_bench/models/trtllm_torch_api.py
@@ -43,7 +43,7 @@ def __init__(
         self.model = create_executor(model_path, max_concurrent_requests, kwargs)
         self.sampling_kwargs = sampling_kwargs
 
-    async def run(self, prompt_ids, max_length, end_id, request_id):
+    async def run(self, prompt_ids, max_length, end_id, request_id, turn_id):
         output_dict = {}
         sampling_config = check_sampling_config(self.sampling_kwargs, max_length, end_id)
         outputs = []
diff --git a/examples/specdec_bench/specdec_bench/models/vllm.py b/examples/specdec_bench/specdec_bench/models/vllm.py
index b0d4642d2..deb79ed89 100644
--- a/examples/specdec_bench/specdec_bench/models/vllm.py
+++ b/examples/specdec_bench/specdec_bench/models/vllm.py
@@ -84,12 +84,12 @@ def __init__(self, model_dir, max_concurrent_requests, sampling_kwargs, **kwargs
         self.loop = asyncio.new_event_loop()
         asyncio.set_event_loop(self.loop)
 
-    async def run(self, prompt_ids, max_length, end_id, request_id):
+    async def run(self, prompt_ids, max_length, end_id, request_id, turn_id):
         output_dict = {}
         self.sampling_config.max_tokens = max_length
         self.sampling_config.stop_token_ids = [end_id]
 
-        outputs, timing, full_tokens = await self.generate(prompt_ids, request_id)
+        outputs, timing, full_tokens = await self.generate(prompt_ids, request_id, turn_id)
 
         reformatted_output_ids = [[] for _ in range(self.sampling_kwargs.get("beam_width", 1))]
         start = 0
@@ -114,13 +114,13 @@ async def run(self, prompt_ids, max_length, end_id, request_id):
         ]
         return output_dict
 
-    async def generate(self, prompt_ids, request_id):
+    async def generate(self, prompt_ids, request_id, turn_id):
         timing = []
         timing.append(time.perf_counter())
         outputs = []
         full_tokens = []
         async for output in self.model.generate(
-            request_id=str(request_id),
+            request_id=f"{request_id}.{turn_id}",
             prompt=TokensPrompt(prompt_token_ids=prompt_ids),
             sampling_params=self.sampling_config,
         ):
diff --git a/examples/specdec_bench/specdec_bench/runners/base.py b/examples/specdec_bench/specdec_bench/runners/base.py
index 794180ef5..c481a0fd0 100644
--- a/examples/specdec_bench/specdec_bench/runners/base.py
+++ b/examples/specdec_bench/specdec_bench/runners/base.py
@@ -21,14 +21,14 @@ def __init__(self, model, metrics):
         self.metrics = metrics
         self.prompt_ar = []
 
-    async def run(self, prompt_ids, max_length, end_id, sampling_kwargs):
+    async def run(self, prompt_ids, max_length, end_id, request_id, turn_id):
         raise NotImplementedError()
 
     def process_metrics_final(self, text_outputs):
         [metric.process_final(text_outputs) for metric in self.metrics]
 
-    def process_metrics_step(self, step_outputs, new_turn=True):
-        [metric.process_step(step_outputs, new_turn) for metric in self.metrics]
+    def process_metrics_step(self, step_outputs, request_id, turn_id):
+        [metric.process_step(step_outputs, request_id, turn_id) for metric in self.metrics]
 
     def clear_metrics(self):
         [metric.clear() for metric in self.metrics]
diff --git a/examples/specdec_bench/specdec_bench/runners/simple.py b/examples/specdec_bench/specdec_bench/runners/simple.py
index 3b1458648..9d4b1a78b 100644
--- a/examples/specdec_bench/specdec_bench/runners/simple.py
+++ b/examples/specdec_bench/specdec_bench/runners/simple.py
@@ -23,9 +23,9 @@ def __init__(self, model, metrics):
         self.metrics = metrics
         self.prompt_ar = []
 
-    async def run(self, prompt_ids, max_length, end_id, request_id):
-        model_output = await self.model.run(prompt_ids, max_length, end_id, request_id)
-        self.process_metrics_step(model_output)
+    async def run(self, prompt_ids, max_length, end_id, request_id, turn_id):
+        model_output = await self.model.run(prompt_ids, max_length, end_id, request_id, turn_id)
+        self.process_metrics_step(model_output, request_id, turn_id)
         output_ids = model_output["output_ids"]
         flattened_output_ids = [[] for _ in range(len(output_ids))]
         for i, beam_output in enumerate(output_ids):
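
For reference, the per-(request_id, turn_id) bookkeeping that replaces the old new_turn flag can be sketched in isolation. This is a minimal, self-contained Python sketch with made-up step outputs, not benchmark code; it uses dict.setdefault as an equivalent idiom for the membership checks added in acceptance_rate.py. The point is that requests completing out of order land in their own buckets, which the old shared list keyed only by arrival order could not guarantee.

    # Sketch of the keyed accumulation in AcceptanceRate.process_step.
    # Step outputs below are synthetic; in the benchmark they come from
    # Model.run via SimpleRunner.process_metrics_step.
    prompt_ar = {}  # request_id -> turn_id -> per-step accepted lengths

    def process_step(step_outputs, request_id, turn_id):
        # Keyed buckets: concurrent completions cannot interleave entries.
        accepted = prompt_ar.setdefault(request_id, {}).setdefault(turn_id, [])
        for beam_output in step_outputs["output_ids"]:
            for output_id_iter in beam_output:
                # One entry per decode step: number of tokens emitted in that
                # step (target token plus accepted draft tokens).
                accepted.append(len(output_id_iter))

    # Request 1 finishes before request 0; each call carries one beam with
    # that turn's per-step token lists (values are synthetic token ids).
    process_step({"output_ids": [[[11, 12, 13], [14, 15]]]}, request_id=1, turn_id=0)
    process_step({"output_ids": [[[21]]]}, request_id=0, turn_id=0)

    for request_id in sorted(prompt_ar):
        for turn_id, turn in prompt_ar[request_id].items():
            print(request_id, turn_id, sum(turn) / len(turn))  # per-turn AR
    # prints: 0 0 1.0
    #         1 0 2.5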
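
On the vLLM side, the engine tracks in-flight generations by string request id, so reusing str(request_id) for every turn of one conversation risks an id collision once run is invoked per turn. A tiny sketch of the id scheme the diff adopts; the helper name is illustrative, not from the codebase.

    def engine_request_id(request_id: int, turn_id: int) -> str:
        # "<request>.<turn>" is unique per turn yet still maps back to the
        # originating benchmark entry.
        return f"{request_id}.{turn_id}"

    assert engine_request_id(3, 0) != engine_request_id(3, 1)
    # The old scheme, str(request_id), produced "3" for both turns.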