6 changes: 4 additions & 2 deletions examples/specdec_bench/run.py
@@ -49,12 +49,14 @@ async def process_single_request(request, i):
     if request.system_prompt is not None:
         messages.append({"role": "system", "content": request.system_prompt})

-    for question in request.turns:
+    for turn_id, question in enumerate(request.turns):
         messages.append({"role": "user", "content": question})
         entry_encoded = encode_chat(tokenizer, messages)

         # Run the async runner.run directly
-        output_tokens = await runner.run(entry_encoded, max_length, end_id, i)
+        output_tokens = await runner.run(
+            entry_encoded, max_length, end_id, request_id=i, turn_id=turn_id
+        )
         output_text = decode_chat(tokenizer, output_tokens["output_ids"][0])
         output_text = postprocess(output_text)
         messages.append({"role": "assistant", "content": output_text})
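The hunk above threads a per-conversation turn index through the driver loop so downstream metrics can be keyed by (request, turn). A minimal runnable sketch of the pattern, with a stub standing in for runner.run (the stub and its payload are invented for illustration, not benchmark code):

    import asyncio

    async def fake_runner_run(prompt_ids, max_length, end_id, request_id, turn_id):
        # Stand-in for the benchmark's runner.run; a real runner would decode here.
        return {"output_ids": [[prompt_ids]]}

    async def process_single_request(turns, request_id):
        for turn_id, question in enumerate(turns):
            out = await fake_runner_run([1, 2, 3], 16, 0, request_id, turn_id)
            print(request_id, turn_id, out["output_ids"])

    asyncio.run(process_single_request(["first question", "follow-up"], request_id=0))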
@@ -34,7 +34,7 @@ def __init__(self, base_tokenizer):
         self.base_tokenizer = base_tokenizer
         self.total_tokens = []

-    def process_step(self, step_outputs, new_turn=True):
+    def process_step(self, step_outputs, request_id, turn_id):
         self.timing.append(step_outputs["token_times"])
         target_tokens = [
             t for tok_list in step_outputs["output_ids"] for tok in tok_list for t in tok
31 changes: 18 additions & 13 deletions examples/specdec_bench/specdec_bench/metrics/acceptance_rate.py
@@ -22,15 +22,17 @@
 class AcceptanceRate(Metric):
     def __init__(self):
         super().__init__()
-        self.prompt_ar = []
+        self.prompt_ar = {}
         self.name = "acceptance_rate"

-    def process_step(self, step_outputs, new_turn=True):
-        if new_turn:
-            self.prompt_ar.append([])
+    def process_step(self, step_outputs, request_id, turn_id):
+        if request_id not in self.prompt_ar:
+            self.prompt_ar[request_id] = {}
+        if turn_id not in self.prompt_ar[request_id]:
+            self.prompt_ar[request_id][turn_id] = []
         for i, beam_output in enumerate(step_outputs["output_ids"]):
             for output_id_iter in beam_output:
-                self.prompt_ar[-1].append(len(output_id_iter))
+                self.prompt_ar[request_id][turn_id].append(len(output_id_iter))

     def _get_lengths(self, turn, lengths):
         for j in turn:
@@ -55,16 +57,19 @@ def _process_lengths(self, lengths):
             running_len -= v

     def process_final(self, text_outputs):
-        i = 0
+        all_ar = []
         lengths = {}
         self.out["Request_AR"] = {}
-        while i < len(self.prompt_ar):
-            turn_1 = self.prompt_ar[i]
-            self.out["Request_AR"][i] = sum(turn_1) / len(turn_1)
-            self._get_lengths(turn_1, lengths)
-            print(i, self.out["Request_AR"][i])
-            i += 1
-        average_ar = sum(self.out["Request_AR"].values()) / len(self.out["Request_AR"])
+        self.prompt_ar = dict(sorted(self.prompt_ar.items(), key=lambda x: x[0]))
+        for request_id, turns in self.prompt_ar.items():
+            self.out["Request_AR"][request_id] = {}
+            for turn_id, turn in turns.items():
+                ar = sum(turn) / len(turn)
+                self.out["Request_AR"][request_id][turn_id] = ar
+                all_ar.append(ar)
+                self._get_lengths(turn, lengths)
+                print(request_id, turn_id, self.out["Request_AR"][request_id][turn_id])
+        average_ar = sum(all_ar) / len(all_ar)
         print("Average AR:", average_ar)
         self.out["Average_AR"] = average_ar
         self._process_lengths(lengths)
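To make the new bookkeeping concrete, here is a self-contained sketch of the nested {request_id: {turn_id: [...]}} structure this file now builds; the step payloads below are invented:

    prompt_ar = {}

    def process_step(step_outputs, request_id, turn_id):
        # Accepted-length samples are now keyed per request and per turn.
        turn = prompt_ar.setdefault(request_id, {}).setdefault(turn_id, [])
        for beam_output in step_outputs["output_ids"]:
            for output_id_iter in beam_output:
                turn.append(len(output_id_iter))

    # Two decode steps for request 0, turn 0: 3 tokens accepted, then 1.
    process_step({"output_ids": [[[11, 12, 13]]]}, request_id=0, turn_id=0)
    process_step({"output_ids": [[[14]]]}, request_id=0, turn_id=0)

    for request_id, turns in sorted(prompt_ar.items()):
        for turn_id, turn in turns.items():
            print(request_id, turn_id, sum(turn) / len(turn))  # prints: 0 0 2.0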
2 changes: 1 addition & 1 deletion examples/specdec_bench/specdec_bench/metrics/base.py
@@ -24,7 +24,7 @@ def __init__(self):
         self.out = {}
         self.name = "metric"

-    def process_step(self, step_outputs, new_turn=True):
+    def process_step(self, step_outputs, request_id, turn_id):
         raise NotImplementedError

     def process_final(self, text_outputs):
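For subclass authors, a hypothetical metric written against the updated signature; the class below is not part of this PR, and the Metric stub merely mirrors the base class shown above:

    class Metric:  # stub of the base class above
        def __init__(self):
            self.out = {}
            self.name = "metric"

    class StepCounter(Metric):
        def __init__(self):
            super().__init__()
            self.steps = {}
            self.name = "step_counter"

        def process_step(self, step_outputs, request_id, turn_id):
            # Count decode steps per (request, turn) pair.
            key = (request_id, turn_id)
            self.steps[key] = self.steps.get(key, 0) + 1

        def process_final(self, text_outputs):
            self.out["Steps"] = {f"{r}.{t}": n for (r, t), n in self.steps.items()}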
12 changes: 6 additions & 6 deletions examples/specdec_bench/specdec_bench/metrics/mtbench.py
@@ -35,16 +35,16 @@ def process_final(self, text_outputs):
         i = 0
         lengths = {}
         self.out["Request_AR"] = {}
-        while i < len(self.prompt_ar):
-            turn_1 = self.prompt_ar[i]
-            turn_2 = self.prompt_ar[i + 1]
-            q_id = i // 2
+        self.prompt_ar = dict(sorted(self.prompt_ar.items(), key=lambda x: x[0]))
+        for request_id, turns in self.prompt_ar.items():
+            turn_1 = turns[0]
+            turn_2 = turns[1]
+            q_id = request_id
             mtbench_topic = MTBENCH_TOPICS[q_id // 10]
-            self.out["Request_AR"][q_id] = sum(turn_1 + turn_2) / len(turn_1 + turn_2)
+            self.out["Request_AR"][request_id] = sum(turn_1 + turn_2) / len(turn_1 + turn_2)
             self._get_lengths(turn_1, lengths)
             self._get_lengths(turn_2, lengths)
             print(mtbench_topic, sum(turn_1 + turn_2) / len(turn_1 + turn_2))
-            i += 2
         per_category = [[] for _ in range(len(MTBENCH_TOPICS))]
         for q_id, ar in self.out["Request_AR"].items():
             per_category[q_id // 10].append(ar)
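The per-category roll-up at the end of this hunk relies on MT-Bench's layout of ten questions per topic, so q_id // 10 indexes the topic. A sketch with invented values (the real topic list is MTBENCH_TOPICS):

    topics = ["writing", "roleplay", "reasoning"]  # illustrative subset
    request_ar = {0: 2.1, 5: 2.3, 12: 1.8, 21: 2.6}  # fake per-question ARs

    per_category = [[] for _ in range(len(topics))]
    for q_id, ar in request_ar.items():
        per_category[q_id // 10].append(ar)

    for topic, ars in zip(topics, per_category):
        if ars:
            print(topic, sum(ars) / len(ars))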
7 changes: 4 additions & 3 deletions examples/specdec_bench/specdec_bench/metrics/timing.py
@@ -26,7 +26,7 @@ def __init__(self, tp_size):
         self.total_tokens = []
         self.tp_size = tp_size

-    def process_step(self, step_outputs, new_turn=True):
+    def process_step(self, step_outputs, request_id, turn_id):
         self.timing.append(step_outputs["token_times"])
         self.total_tokens.append(
             sum([sum([len(j) for j in i]) for i in step_outputs["output_ids"]])
@@ -42,8 +42,9 @@ def process_final(self, text_outputs):
         self.out["Output TPS"] = sum(self.total_tokens) / (end_time - start_time)
         self.out["Output TPS/gpu"] = self.out["Output TPS"] / self.tp_size
         for tokens, times in zip(self.total_tokens, self.timing):
-            e2e_time.append(times[-1] - times[0])
-            ttft_time.append(times[1] - times[0])
+            if len(times) > 1:
+                e2e_time.append(times[-1] - times[0])
+                ttft_time.append(times[1] - times[0])
             if len(times) > 2:
                 gen_tp_time.append((tokens - 1) / (times[-1] - times[1]))
                 tpot_time.extend([a - b for a, b in zip(times[1:], times[:-1])])
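The new length guard protects against requests that record fewer than two timestamps (e.g. generation that stops immediately), where times[1] would raise an IndexError. A sketch of the four derived quantities, with invented perf_counter readings:

    times = [100.00, 100.08, 100.11, 100.15]  # arrival plus one entry per decode step
    tokens = 7

    if len(times) > 1:
        e2e = times[-1] - times[0]    # end-to-end latency
        ttft = times[1] - times[0]    # time to first token
    if len(times) > 2:
        gen_tps = (tokens - 1) / (times[-1] - times[1])        # decode-phase TPS
        tpot = [a - b for a, b in zip(times[1:], times[:-1])]  # per-step gaps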
2 changes: 1 addition & 1 deletion examples/specdec_bench/specdec_bench/models/base.py
@@ -18,7 +18,7 @@ class Model:
     def __init__(self, model_dir, tokenizer, max_draft_length):
         raise NotImplementedError

-    async def run(self, prompt_ids, max_length, end_id, request_id):
+    async def run(self, prompt_ids, max_length, end_id, request_id, turn_id):
         """
         prompt_ids is list of tokens
         output is list of list of tokens
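A toy backend against the updated interface, just to show what the extra turn_id parameter asks of implementations; the echo behavior is invented:

    class EchoModel(Model):  # Model is the base class shown in the hunk above
        def __init__(self, model_dir, tokenizer, max_draft_length):
            pass  # a real backend would construct its engine here

        async def run(self, prompt_ids, max_length, end_id, request_id, turn_id):
            # One beam, one "step" that returns the prompt unchanged.
            return {"output_ids": [[list(prompt_ids)]], "token_times": []}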
3 changes: 2 additions & 1 deletion examples/specdec_bench/specdec_bench/models/sglang.py
@@ -50,6 +50,7 @@ def __init__(
             speculative_num_steps=kwargs.get("speculative_num_steps", 3),
             speculative_eagle_topk=kwargs.get("speculative_eagle_topk", 1),
             speculative_num_draft_tokens=kwargs.get("speculative_num_draft_tokens", 4),
+            speculative_draft_model_path=kwargs.get("draft_model_dir"),
             torch_compile_max_bs=max_concurrent_requests,
             attention_backend=kwargs.get("attention_backend"),
             enable_torch_compile=kwargs.get("enable_torch_compile", False),
@@ -70,7 +71,7 @@

         self.sampling_config = sampling_kwargs

-    async def run(self, prompt_ids, max_length, end_id, request_id):
+    async def run(self, prompt_ids, max_length, end_id, request_id, turn_id):
         timing = []
         output_dict = {}
         self.sampling_config["max_new_tokens"] = max_length
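The added line wires the draft model into SGLang's speculative-decoding setup. A hedged sketch of the kwargs a caller might pass (keys follow the kwargs.get calls above; the path and values are invented):

    kwargs = {
        "draft_model_dir": "/models/eagle-draft",  # consumed by the new line
        "speculative_num_steps": 3,
        "speculative_eagle_topk": 1,
        "speculative_num_draft_tokens": 4,
        "attention_backend": None,
        "enable_torch_compile": False,
    }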
@@ -43,7 +43,7 @@ def __init__(
         self.model = create_executor(model_path, max_concurrent_requests, kwargs)
         self.sampling_kwargs = sampling_kwargs

-    async def run(self, prompt_ids, max_length, end_id, request_id):
+    async def run(self, prompt_ids, max_length, end_id, request_id, turn_id):
         output_dict = {}
         sampling_config = check_sampling_config(self.sampling_kwargs, max_length, end_id)
         outputs = []
8 changes: 4 additions & 4 deletions examples/specdec_bench/specdec_bench/models/vllm.py
@@ -84,12 +84,12 @@ def __init__(self, model_dir, max_concurrent_requests, sampling_kwargs, **kwargs
         self.loop = asyncio.new_event_loop()
         asyncio.set_event_loop(self.loop)

-    async def run(self, prompt_ids, max_length, end_id, request_id):
+    async def run(self, prompt_ids, max_length, end_id, request_id, turn_id):
         output_dict = {}
         self.sampling_config.max_tokens = max_length
         self.sampling_config.stop_token_ids = [end_id]

-        outputs, timing, full_tokens = await self.generate(prompt_ids, request_id)
+        outputs, timing, full_tokens = await self.generate(prompt_ids, request_id, turn_id)

         reformatted_output_ids = [[] for _ in range(self.sampling_kwargs.get("beam_width", 1))]
         start = 0
@@ -114,13 +114,13 @@ async def run(self, prompt_ids, max_length, end_id, request_id):
             ]
         return output_dict

-    async def generate(self, prompt_ids, request_id):
+    async def generate(self, prompt_ids, request_id, turn_id):
         timing = []
         timing.append(time.perf_counter())
         outputs = []
         full_tokens = []
         async for output in self.model.generate(
-            request_id=str(request_id),
+            request_id=f"{request_id}.{turn_id}",
             prompt=TokensPrompt(prompt_token_ids=prompt_ids),
             sampling_params=self.sampling_config,
         ):
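The f"{request_id}.{turn_id}" change gives each turn its own engine-side request id; reusing the bare request id for a second turn of the same conversation could collide with an id the async engine already tracks. The helper below is purely illustrative:

    def engine_request_id(request_id: int, turn_id: int) -> str:
        # Unique per (request, turn) pair.
        return f"{request_id}.{turn_id}"

    assert engine_request_id(3, 0) != engine_request_id(3, 1)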
6 changes: 3 additions & 3 deletions examples/specdec_bench/specdec_bench/runners/base.py
@@ -21,14 +21,14 @@ def __init__(self, model, metrics):
         self.metrics = metrics
         self.prompt_ar = []

-    async def run(self, prompt_ids, max_length, end_id, sampling_kwargs):
+    async def run(self, prompt_ids, max_length, end_id, request_id, turn_id):
         raise NotImplementedError()

     def process_metrics_final(self, text_outputs):
         [metric.process_final(text_outputs) for metric in self.metrics]

-    def process_metrics_step(self, step_outputs, new_turn=True):
-        [metric.process_step(step_outputs, new_turn) for metric in self.metrics]
+    def process_metrics_step(self, step_outputs, request_id, turn_id):
+        [metric.process_step(step_outputs, request_id, turn_id) for metric in self.metrics]

     def clear_metrics(self):
         [metric.clear() for metric in self.metrics]
6 changes: 3 additions & 3 deletions examples/specdec_bench/specdec_bench/runners/simple.py
@@ -23,9 +23,9 @@ def __init__(self, model, metrics):
         self.metrics = metrics
         self.prompt_ar = []

-    async def run(self, prompt_ids, max_length, end_id, request_id):
-        model_output = await self.model.run(prompt_ids, max_length, end_id, request_id)
-        self.process_metrics_step(model_output)
+    async def run(self, prompt_ids, max_length, end_id, request_id, turn_id):
+        model_output = await self.model.run(prompt_ids, max_length, end_id, request_id, turn_id)
+        self.process_metrics_step(model_output, request_id, turn_id)
         output_ids = model_output["output_ids"]
         flattened_output_ids = [[] for _ in range(len(output_ids))]
         for i, beam_output in enumerate(output_ids):
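Putting the pieces together, a minimal end-to-end sketch of how the runner now fans each step's output out to metrics keyed by (request_id, turn_id); every class here is a stand-in, not benchmark code:

    import asyncio

    class CountMetric:
        def __init__(self):
            self.seen = []

        def process_step(self, step_outputs, request_id, turn_id):
            self.seen.append((request_id, turn_id))

    class FakeModel:
        async def run(self, prompt_ids, max_length, end_id, request_id, turn_id):
            return {"output_ids": [[prompt_ids]]}

    class SimpleRunner:
        def __init__(self, model, metrics):
            self.model, self.metrics = model, metrics

        async def run(self, prompt_ids, max_length, end_id, request_id, turn_id):
            out = await self.model.run(prompt_ids, max_length, end_id, request_id, turn_id)
            for metric in self.metrics:
                metric.process_step(out, request_id, turn_id)
            return out

    async def main():
        metric = CountMetric()
        runner = SimpleRunner(FakeModel(), [metric])
        for turn_id in range(2):
            await runner.run([1, 2], 16, 0, request_id=0, turn_id=turn_id)
        print(metric.seen)  # [(0, 0), (0, 1)]

    asyncio.run(main())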