
Commit 5301917

Adds an MBU calculation to our benchmark script (#957)
* Added MBU.
* Updated MBU calculation.
* Fixed issue.
* Added logging to support test.
* Fixed MBU calculation.
1 parent f1d7223 commit 5301917

File tree

1 file changed (+238, -6 lines)


open_instruct/benchmark_generators.py

Lines changed: 238 additions & 6 deletions
@@ -31,12 +31,13 @@
 from open_instruct.queue_types import PromptRequest

 # For FLOPS, we assume bf16 and ignore sparsity.
+# Memory bandwidth values are peak theoretical bandwidth.
 GPU_SPECS = {
-    "a100": {"flops": 312e12, "memory_size": 80e9},
-    "b200": {"flops": 2250e12, "memory_size": 192e9},
-    "h100": {"flops": 990e12, "memory_size": 80e9},
-    "a6000": {"flops": 155e12, "memory_size": 48e9},
-    "l40s": {"flops": 362e12, "memory_size": 48e9},
+    "a100": {"flops": 312e12, "memory_size": 80e9, "memory_bandwidth": 1.6e12},  # 1.6 TB/s HBM2e
+    "b200": {"flops": 2250e12, "memory_size": 192e9, "memory_bandwidth": 8e12},  # 8 TB/s HBM3e
+    "h100": {"flops": 990e12, "memory_size": 80e9, "memory_bandwidth": 3.35e12},  # 3.35 TB/s HBM3
+    "a6000": {"flops": 155e12, "memory_size": 48e9, "memory_bandwidth": 768e9},  # 768 GB/s GDDR6
+    "l40s": {"flops": 362e12, "memory_size": 48e9, "memory_bandwidth": 864e9},  # 864 GB/s GDDR6
 }

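Note on how these peak-bandwidth figures are used: MBU (Model Bandwidth Utilization) is achieved bytes per second divided by the device's peak bandwidth. A minimal sketch with made-up numbers (the helper below is illustrative only and is not part of this commit):

def mbu_percent(bytes_moved: float, seconds: float, peak_bandwidth: float) -> float:
    # Achieved memory throughput as a percentage of the device's peak bandwidth.
    if seconds <= 0:
        return 0.0
    return 100 * (bytes_moved / seconds) / peak_bandwidth

# Example: 2e12 bytes moved in 1.5 s on an H100 (peak 3.35e12 B/s) -> ~39.8% MBU.
print(mbu_percent(2e12, 1.5, 3.35e12))
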
@@ -146,6 +147,7 @@ def save_benchmark_results_to_csv(
         "total_tokens": agg_results["total_num_new_tokens"],
         "avg_tokens_per_second": agg_results["avg_tokens_per_second"],
         "avg_mfu": agg_results["avg_mfu"],
+        "avg_mbu": agg_results["avg_mbu"],
         "avg_generation_time_per_batch": agg_results["avg_generation_time"],
         "avg_new_tokens_per_sample": agg_results["total_num_new_tokens"]
         / (len(results) * args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout),
@@ -321,6 +323,215 @@ def flops(
         total += self.decode_flops(prompt_lengths, response_lengths, samples_per_prompt)
         return total

+    def weight_memory_bytes(self, num_tokens: int, dtype_bytes: int = 2) -> int:
+        """Memory bytes for reading model weights for a given number of tokens.
+
+        Args:
+            num_tokens: Number of tokens to process
+            dtype_bytes: Bytes per element (2 for FP16/BF16)
+
+        Returns:
+            Total bytes for weight reads across all layers
+        """
+        num_kv = self.num_kv_heads if self.num_kv_heads is not None else self.num_attn_heads
+        head_dim = self.hidden_size // self.num_attn_heads
+        hidden_kv = num_kv * head_dim
+
+        # Per-layer weight params (Q, K, V, O, MLP up, MLP down)
+        w_q = self.hidden_size * self.hidden_size
+        w_k = self.hidden_size * hidden_kv
+        w_v = self.hidden_size * hidden_kv
+        w_o = self.hidden_size * self.hidden_size
+        w_up = self.hidden_size * self.intermediate_size
+        w_dn = self.intermediate_size * self.hidden_size
+
+        per_layer_weight_bytes = (w_q + w_k + w_v + w_o + w_up + w_dn) * dtype_bytes
+        return self.num_layers * num_tokens * per_layer_weight_bytes
+
+    def kv_cache_write_bytes(self, num_tokens: int, dtype_bytes: int = 2) -> int:
+        """Memory bytes for writing KV cache for a given number of tokens.
+
+        Args:
+            num_tokens: Number of tokens being cached
+            dtype_bytes: Bytes per element (2 for FP16/BF16)
+
+        Returns:
+            Total bytes for KV cache writes across all layers
+        """
+        num_kv = self.num_kv_heads if self.num_kv_heads is not None else self.num_attn_heads
+        head_dim = self.hidden_size // self.num_attn_heads
+
+        # 2x for K and V
+        kv_write_bytes_per_token = 2 * num_kv * head_dim * dtype_bytes
+        return self.num_layers * num_tokens * kv_write_bytes_per_token
+
+    def kv_cache_read_bytes(
+        self,
+        prompt_lengths: Sequence[int],
+        response_lengths: Sequence[int],
+        samples_per_prompt: int = 1,
+        dtype_bytes: int = 2,
+    ) -> int:
+        """Memory bytes for reading KV cache during decode.
+
+        For each new token generated, we read all previous tokens' KV cache.
+        When generating multiple samples per prompt, the prompt KV cache is shared.
+
+        Args:
+            prompt_lengths: List of prompt lengths (one per unique prompt)
+            response_lengths: List of response lengths (samples_per_prompt * len(prompt_lengths) total)
+            samples_per_prompt: Number of samples generated per prompt
+            dtype_bytes: Bytes per element (2 for FP16/BF16)
+
+        Returns:
+            Total bytes for KV cache reads during decode
+        """
+        assert len(response_lengths) == len(prompt_lengths) * samples_per_prompt, (
+            f"Expected {len(prompt_lengths) * samples_per_prompt} response lengths, got {len(response_lengths)}"
+        )
+
+        num_kv = self.num_kv_heads if self.num_kv_heads is not None else self.num_attn_heads
+        head_dim = self.hidden_size // self.num_attn_heads
+
+        # For batched sampling with shared prompt KV cache:
+        # - Prompt KV is read once per new token position across ALL samples (not per sample)
+        # - Each sample has its own KV for generated tokens
+        kv_read_terms = 0
+        response_idx = 0
+
+        for P in prompt_lengths:
+            # For this prompt, collect all response lengths
+            prompt_responses = []
+            for _ in range(samples_per_prompt):
+                prompt_responses.append(response_lengths[response_idx])
+                response_idx += 1
+
+            # Prompt KV reads: In synchronized batch generation with vLLM n>1,
+            # the prompt KV cache is stored once but each sample reads it independently.
+            # At each decoding position, each sample reads the prompt KV cache.
+            # Number of positions = max response length (all generate synchronously)
+            max_response_length = max(prompt_responses) if prompt_responses else 0
+            # Each of the samples_per_prompt samples reads prompt KV at each position
+            kv_read_terms += max_response_length * samples_per_prompt * P
+
+            # Per-sample generated KV reads: Each sample reads its own previously generated tokens
+            for R in prompt_responses:
+                # Each token in this sample reads its previously generated tokens
+                # sum_{i=0}^{R-1} i = R*(R-1)/2
+                kv_read_terms += R * (R - 1) // 2
+
+        # 2x for K and V
+        kv_bytes_per_token = 2 * num_kv * head_dim * dtype_bytes
+        return self.num_layers * kv_bytes_per_token * kv_read_terms
+
+    def prefill_memory_bytes(self, prompt_lengths: Sequence[int], dtype_bytes: int = 2) -> int:
+        """Memory bytes for prefill phase.
+
+        During prefill:
+        - Read weights once for the entire batch (batched matmul)
+        - Write KV cache for each token
+
+        Args:
+            prompt_lengths: List of prompt lengths
+            dtype_bytes: Bytes per element (2 for FP16/BF16)
+
+        Returns:
+            Total memory bytes for prefill
+        """
+        # In batched prefill, weights are read once for the entire operation,
+        # not once per token. We process all prompts in a single batch.
+        num_prefill_batches = len(prompt_lengths)  # Each prompt is a "batch"
+        weight_bytes = self.weight_memory_bytes(num_prefill_batches, dtype_bytes)
+
+        # KV cache is written for every token
+        total_prefill_tokens = sum(prompt_lengths)
+        kv_write_bytes = self.kv_cache_write_bytes(total_prefill_tokens, dtype_bytes)
+        return weight_bytes + kv_write_bytes
+
+    def decode_memory_bytes(
+        self,
+        prompt_lengths: Sequence[int],
+        response_lengths: Sequence[int],
+        samples_per_prompt: int = 1,
+        dtype_bytes: int = 2,
+    ) -> int:
+        """Memory bytes for decode/generation phase.
+
+        During decode:
+        - Read weights for each new token position (shared across samples in batch)
+        - Write KV cache for each new token
+        - Read all previous KV cache for attention
+
+        Args:
+            prompt_lengths: List of prompt lengths (one per unique prompt)
+            response_lengths: List of response lengths (samples_per_prompt * len(prompt_lengths) total)
+            samples_per_prompt: Number of samples generated per prompt
+            dtype_bytes: Bytes per element (2 for FP16/BF16)
+
+        Returns:
+            Total memory bytes for decode
+        """
+        # In synchronized batch generation, weights are read once per position,
+        # not once per token. With multiple samples per prompt generating in parallel,
+        # we only need to read weights for the number of unique positions.
+        unique_positions = 0
+        response_idx = 0
+        for _ in prompt_lengths:
+            # Get response lengths for this prompt's samples
+            prompt_responses = response_lengths[response_idx : response_idx + samples_per_prompt]
+            response_idx += samples_per_prompt
+            # In synchronized generation, all samples generate the same number of positions
+            # (up to the max length among them)
+            unique_positions += max(prompt_responses) if prompt_responses else 0
+
+        weight_bytes = self.weight_memory_bytes(unique_positions, dtype_bytes)
+
+        # KV writes happen for all tokens (each sample writes its own KV)
+        total_decode_tokens = sum(response_lengths)
+        kv_write_bytes = self.kv_cache_write_bytes(total_decode_tokens, dtype_bytes)
+
+        kv_read_bytes = self.kv_cache_read_bytes(prompt_lengths, response_lengths, samples_per_prompt, dtype_bytes)
+        return weight_bytes + kv_write_bytes + kv_read_bytes
+
+    def memory_bytes(
+        self,
+        prompt_lengths: Sequence[int],
+        response_lengths: Optional[Sequence[int]] = None,
+        samples_per_prompt: int = 1,
+        dtype_bytes: int = 2,
+    ) -> int:
+        """Approximate total HBM bytes moved for prefill + decode.
+
+        Returns an integer number of bytes. Divide by elapsed seconds to get B/s;
+        compare against peak bandwidth to get utilization.
+
+        Args:
+            prompt_lengths: List of prompt lengths (one per unique prompt)
+            response_lengths: List of response lengths (samples_per_prompt * len(prompt_lengths) total)
+            samples_per_prompt: Number of samples generated per prompt
+            dtype_bytes: Bytes per element (2 for FP16/BF16)
+
+        Returns:
+            Total memory bytes moved
+
+        Assumptions:
+            - Weights are read once per token per layer (Q,K,V,O + MLP up/down)
+            - KV cache: write K/V for every token; during decode, read all past K/V per new token
+            - When batching samples, prompt KV cache is shared across samples
+            - Embedding and LM head reads are ignored (usually dominated by matmul weight traffic)
+        """
+        total = self.prefill_memory_bytes(prompt_lengths, dtype_bytes)
+
+        if response_lengths is not None:
+            assert len(response_lengths) == len(prompt_lengths) * samples_per_prompt, (
+                f"Expected {len(prompt_lengths) * samples_per_prompt} response lengths, got {len(response_lengths)}"
+            )
+
+            # Pass original prompt_lengths with samples_per_prompt to correctly handle shared KV cache
+            total += self.decode_memory_bytes(prompt_lengths, response_lengths, samples_per_prompt, dtype_bytes)
+
+        return total
+

 def load_model_dims(model_name: str) -> ModelDims:
     cfg = transformers.AutoConfig.from_pretrained(model_name, trust_remote_code=True)
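
To make the KV-cache read accounting in kv_cache_read_bytes concrete, here is a small standalone sketch with toy values (the numbers are invented for illustration and are not part of the commit):

# One prompt of length P = 4, two samples with response lengths [3, 5].
P = 4
responses = [3, 5]
samples_per_prompt = len(responses)

# Shared prompt KV: each sample re-reads the P prompt tokens at every decode
# position, and positions run up to the longest response (here 5).
prompt_reads = max(responses) * samples_per_prompt * P  # 5 * 2 * 4 = 40 token-reads

# Per-sample generated KV: token i of a sample reads the i tokens generated
# before it, so each sample contributes R * (R - 1) / 2 reads.
generated_reads = sum(r * (r - 1) // 2 for r in responses)  # 3 + 10 = 13 token-reads

kv_read_terms = prompt_reads + generated_reads  # 53 token-reads
# kv_cache_read_bytes then scales this count by
# num_layers * 2 * num_kv_heads * head_dim * dtype_bytes to get bytes.
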
@@ -507,6 +718,7 @@ def run_benchmark(
     results = []
     device_name = get_device_name(torch.cuda.get_device_name(0))
     device_flops = GPU_SPECS[device_name]["flops"]
+    device_memory_bandwidth = GPU_SPECS[device_name]["memory_bandwidth"]

     # Submit warmup batch first
     logger.info("Submitting warmup batch...")
@@ -578,12 +790,22 @@ def run_benchmark(
         model_flops_per_second = model_flops / batch_generation_time if batch_generation_time > 0 else 0
         result_dict["mfu"] = 100 * model_flops_per_second / device_flops

+        # Calculate total memory bytes for all prompts and responses in the batch
+        model_memory_bytes = model_dims.memory_bytes(
+            prompt_lengths, response_lengths, samples_per_prompt=args.num_samples_per_prompt_rollout
+        )
+
+        # MBU = (Memory bytes / time) / peak_bandwidth * 100
+        model_bytes_per_second = model_memory_bytes / batch_generation_time if batch_generation_time > 0 else 0
+        result_dict["mbu"] = 100 * model_bytes_per_second / device_memory_bandwidth
+
         save_completion_lengths([result_dict], timestamp, batch_idx)
         results.append(result_dict)
         logger.info(
             f"Batch {batch_idx}/{num_batches - 1}: "
             f"{result_dict['tokens_per_second']:.2f} new tokens/sec, "
             f"MFU: {result_dict['mfu']:.2f}%, "
+            f"MBU: {result_dict['mbu']:.2f}%, "
             f"generation time: {batch_generation_time:.2f}s, "
             f"total new tokens: {new_tokens}"
         )
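
For a sense of scale of result_dict["mbu"], a back-of-the-envelope sketch with hypothetical numbers (not taken from this commit or any real run):

# Suppose memory_bytes() reports 4.0e12 bytes moved and the batch took 2.0 s
# on an H100 (peak 3.35e12 B/s).
mbu = 100 * (4.0e12 / 2.0) / 3.35e12  # ~59.7% MBU

# Related rule of thumb (an assumption, not something the script computes): at
# batch size 1, a dense bf16 model streams roughly its full weights for every
# generated token, so decode speed is capped near peak_bandwidth / weight_bytes.
weight_bytes = 8e9 * 2                     # hypothetical 8B-parameter model in bf16
max_tokens_per_s = 3.35e12 / weight_bytes  # ~209 tokens/s per sequence; batching amortizes weight reads
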
@@ -604,6 +826,7 @@ def aggregate_results(results: list[dict[str, Any]]) -> dict[str, Any]:
     """Calculate total and aggregated metrics from results."""
     aggregated_results = {
         "total_mfu": 0.0,
+        "total_mbu": 0.0,
         "total_tokens_per_second": 0.0,
         "total_generation_time": 0.0,
         "total_num_new_tokens": 0,
@@ -615,6 +838,8 @@
         for key, value in result.items():
             if key == "mfu":
                 aggregated_results["total_mfu"] += value
+            elif key == "mbu":
+                aggregated_results["total_mbu"] += value
             elif key == "tokens_per_second":
                 aggregated_results["total_tokens_per_second"] += value
             elif key == "generation_time":
@@ -628,8 +853,13 @@
                 aggregated_results[key].extend(value)

     num_results = len(results)
-    aggregated_results["avg_tokens_per_second"] = aggregated_results["total_tokens_per_second"] / num_results
+    aggregated_results["avg_tokens_per_second"] = (
+        aggregated_results["total_num_new_tokens"] / aggregated_results["total_generation_time"]
+        if aggregated_results["total_generation_time"] > 0
+        else 0
+    )
     aggregated_results["avg_mfu"] = aggregated_results["total_mfu"] / num_results
+    aggregated_results["avg_mbu"] = aggregated_results["total_mbu"] / num_results
     aggregated_results["avg_generation_time"] = aggregated_results["total_generation_time"] / num_results
     return aggregated_results

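The hunk above also changes avg_tokens_per_second from a mean of per-batch rates to total tokens over total time, which weights each batch by its duration. A small illustration with made-up numbers:

# batch 1: 1000 tokens in 1 s -> 1000 tok/s; batch 2: 1000 tokens in 4 s -> 250 tok/s
mean_of_rates = (1000 + 250) / 2        # 625 tok/s (previous behavior)
overall_rate = (1000 + 1000) / (1 + 4)  # 400 tok/s (what is now reported)
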
@@ -659,6 +889,7 @@ def print_summary(
     print(f"Average results over {len(results)} main benchmark batches:")
     print(f"Average tokens/second: {agg_results['avg_tokens_per_second']:.2f}")
     print(f"Average MFU: {agg_results['avg_mfu']:.2f}%")
+    print(f"Average MBU: {agg_results['avg_mbu']:.2f}%")
     print(f"Average generation time per batch: {agg_results['avg_generation_time']:.2f}s")
     print(f"Average new tokens per sample: {avg_new_tokens_per_sample:.2f} tokens")

@@ -673,6 +904,7 @@ def print_summary(
     print(f"GPU device: {torch.cuda.get_device_name(0)}")
     print(f"GPU peak FLOPs: {gpu_specs['flops'] / 1e12:.0f} TFLOPs")
     print(f"GPU memory size: {gpu_specs['memory_size'] / 1e9:.0f} GB")
+    print(f"GPU memory bandwidth: {gpu_specs['memory_bandwidth'] / 1e12:.2f} TB/s")

     print("-" * 60)
     print("COMPLETION LENGTH STATISTICS:")
