@@ -31,12 +31,13 @@
 from open_instruct.queue_types import PromptRequest

 # For FLOPS, we assume bf16 and ignore sparsity.
+# Memory bandwidth values are peak theoretical bandwidth.
 GPU_SPECS = {
-    "a100": {"flops": 312e12, "memory_size": 80e9},
-    "b200": {"flops": 2250e12, "memory_size": 192e9},
-    "h100": {"flops": 990e12, "memory_size": 80e9},
-    "a6000": {"flops": 155e12, "memory_size": 48e9},
-    "l40s": {"flops": 362e12, "memory_size": 48e9},
+    "a100": {"flops": 312e12, "memory_size": 80e9, "memory_bandwidth": 1.6e12},  # 1.6 TB/s HBM2e
+    "b200": {"flops": 2250e12, "memory_size": 192e9, "memory_bandwidth": 8e12},  # 8 TB/s HBM3e
+    "h100": {"flops": 990e12, "memory_size": 80e9, "memory_bandwidth": 3.35e12},  # 3.35 TB/s HBM3
+    "a6000": {"flops": 155e12, "memory_size": 48e9, "memory_bandwidth": 768e9},  # 768 GB/s GDDR6
+    "l40s": {"flops": 362e12, "memory_size": 48e9, "memory_bandwidth": 864e9},  # 864 GB/s GDDR6
 }

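A quick roofline check shows why generation throughput should be judged against bandwidth as well as FLOPs. A sketch using the table above (not part of the patch; the arithmetic is standard roofline reasoning):

# Each GPU's ops:byte ratio is the FLOPs it can perform per byte of HBM traffic.
# Batch-1 decode does ~1 FLOP per bf16 weight byte (2 FLOPs per 2-byte weight),
# far below these ratios, so decode is bandwidth-bound and MBU is the right lens.
for name, spec in GPU_SPECS.items():
    print(f"{name}: {spec['flops'] / spec['memory_bandwidth']:.0f} FLOPs/byte")
# e.g. h100: 990e12 / 3.35e12 ~ 296 FLOPs/byte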
@@ -146,6 +147,7 @@ def save_benchmark_results_to_csv(
         "total_tokens": agg_results["total_num_new_tokens"],
         "avg_tokens_per_second": agg_results["avg_tokens_per_second"],
         "avg_mfu": agg_results["avg_mfu"],
+        "avg_mbu": agg_results["avg_mbu"],
         "avg_generation_time_per_batch": agg_results["avg_generation_time"],
         "avg_new_tokens_per_sample": agg_results["total_num_new_tokens"]
         / (len(results) * args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout),
@@ -321,6 +323,215 @@ def flops(
         total += self.decode_flops(prompt_lengths, response_lengths, samples_per_prompt)
         return total

+    def weight_memory_bytes(self, num_tokens: int, dtype_bytes: int = 2) -> int:
+        """Memory bytes for reading model weights for a given number of tokens.
+
+        Args:
+            num_tokens: Number of tokens to process
+            dtype_bytes: Bytes per element (2 for FP16/BF16)
+
+        Returns:
+            Total bytes for weight reads across all layers
+        """
+        num_kv = self.num_kv_heads if self.num_kv_heads is not None else self.num_attn_heads
+        head_dim = self.hidden_size // self.num_attn_heads
+        hidden_kv = num_kv * head_dim
+
+        # Per-layer weight params (Q, K, V, O, MLP up, MLP down)
+        w_q = self.hidden_size * self.hidden_size
+        w_k = self.hidden_size * hidden_kv
+        w_v = self.hidden_size * hidden_kv
+        w_o = self.hidden_size * self.hidden_size
+        w_up = self.hidden_size * self.intermediate_size
+        w_dn = self.intermediate_size * self.hidden_size
+
+        per_layer_weight_bytes = (w_q + w_k + w_v + w_o + w_up + w_dn) * dtype_bytes
+        return self.num_layers * num_tokens * per_layer_weight_bytes
+
+    def kv_cache_write_bytes(self, num_tokens: int, dtype_bytes: int = 2) -> int:
+        """Memory bytes for writing the KV cache for a given number of tokens.
+
+        Args:
+            num_tokens: Number of tokens being cached
+            dtype_bytes: Bytes per element (2 for FP16/BF16)
+
+        Returns:
+            Total bytes for KV cache writes across all layers
+        """
+        num_kv = self.num_kv_heads if self.num_kv_heads is not None else self.num_attn_heads
+        head_dim = self.hidden_size // self.num_attn_heads
+
+        # 2x for K and V
+        kv_write_bytes_per_token = 2 * num_kv * head_dim * dtype_bytes
+        return self.num_layers * num_tokens * kv_write_bytes_per_token
+
+    def kv_cache_read_bytes(
+        self,
+        prompt_lengths: Sequence[int],
+        response_lengths: Sequence[int],
+        samples_per_prompt: int = 1,
+        dtype_bytes: int = 2,
+    ) -> int:
+        """Memory bytes for reading the KV cache during decode.
+
+        For each new token generated, we read all previous tokens' KV cache.
+        When generating multiple samples per prompt, the prompt KV cache is stored once but read by every sample.
+
+        Args:
+            prompt_lengths: List of prompt lengths (one per unique prompt)
+            response_lengths: List of response lengths (samples_per_prompt * len(prompt_lengths) total)
+            samples_per_prompt: Number of samples generated per prompt
+            dtype_bytes: Bytes per element (2 for FP16/BF16)
+
+        Returns:
+            Total bytes for KV cache reads during decode
+        """
+        assert len(response_lengths) == len(prompt_lengths) * samples_per_prompt, (
+            f"Expected {len(prompt_lengths) * samples_per_prompt} response lengths, got {len(response_lengths)}"
+        )
+
+        num_kv = self.num_kv_heads if self.num_kv_heads is not None else self.num_attn_heads
+        head_dim = self.hidden_size // self.num_attn_heads
+
+        # For batched sampling with a shared prompt KV cache:
+        # - The prompt KV cache is stored once, but every sample reads it at each decode position
+        # - Each sample reads its own KV for the tokens it has generated
+        kv_read_terms = 0
+        response_idx = 0
+
+        for P in prompt_lengths:
+            # For this prompt, collect all response lengths
+            prompt_responses = []
+            for _ in range(samples_per_prompt):
+                prompt_responses.append(response_lengths[response_idx])
+                response_idx += 1
+
+            # Prompt KV reads: in synchronized batch generation with vLLM n>1,
+            # the prompt KV cache is stored once but each sample reads it independently.
+            # At each decoding position, each sample reads the prompt KV cache.
+            # Number of positions = max response length (all samples generate synchronously).
+            max_response_length = max(prompt_responses) if prompt_responses else 0
+            # Each of the samples_per_prompt samples reads the prompt KV at each position
+            kv_read_terms += max_response_length * samples_per_prompt * P
+
+            # Per-sample generated KV reads: each sample reads its own previously generated tokens
+            for R in prompt_responses:
+                # Each token in this sample attends to its previously generated tokens:
+                # sum_{i=0}^{R-1} i = R*(R-1)/2
+                kv_read_terms += R * (R - 1) // 2
+
+        # 2x for K and V
+        kv_bytes_per_token = 2 * num_kv * head_dim * dtype_bytes
+        return self.num_layers * kv_bytes_per_token * kv_read_terms
+
+    def prefill_memory_bytes(self, prompt_lengths: Sequence[int], dtype_bytes: int = 2) -> int:
+        """Memory bytes for the prefill phase.
+
+        During prefill:
+        - Weights are read once per prompt (batched matmuls amortize them across tokens)
+        - KV cache is written for each token
+
+        Args:
+            prompt_lengths: List of prompt lengths
+            dtype_bytes: Bytes per element (2 for FP16/BF16)
+
+        Returns:
+            Total memory bytes for prefill
+        """
+        # In batched prefill, weights are read once per prompt's forward pass,
+        # not once per token.
+        num_prefill_batches = len(prompt_lengths)  # Each prompt is one "batch"
+        weight_bytes = self.weight_memory_bytes(num_prefill_batches, dtype_bytes)
+
+        # KV cache is written for every token
+        total_prefill_tokens = sum(prompt_lengths)
+        kv_write_bytes = self.kv_cache_write_bytes(total_prefill_tokens, dtype_bytes)
+        return weight_bytes + kv_write_bytes
+
+    def decode_memory_bytes(
+        self,
+        prompt_lengths: Sequence[int],
+        response_lengths: Sequence[int],
+        samples_per_prompt: int = 1,
+        dtype_bytes: int = 2,
+    ) -> int:
+        """Memory bytes for the decode/generation phase.
+
+        During decode:
+        - Read weights once per new token position (shared across samples in the batch)
+        - Write KV cache for each new token
+        - Read all previous KV cache for attention
+
+        Args:
+            prompt_lengths: List of prompt lengths (one per unique prompt)
+            response_lengths: List of response lengths (samples_per_prompt * len(prompt_lengths) total)
+            samples_per_prompt: Number of samples generated per prompt
+            dtype_bytes: Bytes per element (2 for FP16/BF16)
+
+        Returns:
+            Total memory bytes for decode
+        """
+        # In synchronized batch generation, weights are read once per position,
+        # not once per token. With multiple samples per prompt generating in parallel,
+        # we only need to read weights for the number of unique positions.
+        unique_positions = 0
+        response_idx = 0
+        for _ in prompt_lengths:
+            # Get the response lengths for this prompt's samples
+            prompt_responses = response_lengths[response_idx : response_idx + samples_per_prompt]
+            response_idx += samples_per_prompt
+            # In synchronized generation, all samples step through the same positions
+            # (up to the max length among them)
+            unique_positions += max(prompt_responses) if prompt_responses else 0
+
+        weight_bytes = self.weight_memory_bytes(unique_positions, dtype_bytes)
+
+        # KV writes happen for all tokens (each sample writes its own KV)
+        total_decode_tokens = sum(response_lengths)
+        kv_write_bytes = self.kv_cache_write_bytes(total_decode_tokens, dtype_bytes)
+
+        kv_read_bytes = self.kv_cache_read_bytes(prompt_lengths, response_lengths, samples_per_prompt, dtype_bytes)
+        return weight_bytes + kv_write_bytes + kv_read_bytes
+
+    def memory_bytes(
+        self,
+        prompt_lengths: Sequence[int],
+        response_lengths: Optional[Sequence[int]] = None,
+        samples_per_prompt: int = 1,
+        dtype_bytes: int = 2,
+    ) -> int:
+        """Approximate total HBM bytes moved for prefill + decode.
+
+        Returns an integer number of bytes. Divide by elapsed seconds to get B/s,
+        then compare against peak bandwidth to get utilization.
+
+        Args:
+            prompt_lengths: List of prompt lengths (one per unique prompt)
+            response_lengths: List of response lengths (samples_per_prompt * len(prompt_lengths) total)
+            samples_per_prompt: Number of samples generated per prompt
+            dtype_bytes: Bytes per element (2 for FP16/BF16)
+
+        Returns:
+            Total memory bytes moved
+
+        Assumptions:
+        - Weights are read once per token per layer (Q, K, V, O + MLP up/down)
+        - KV cache: K/V written for every token; during decode, all past K/V read per new token
+        - When batching samples, the prompt KV cache is shared across samples
+        - Embedding and LM head reads are ignored (usually dominated by matmul weight traffic)
+        """
+        total = self.prefill_memory_bytes(prompt_lengths, dtype_bytes)
+
+        if response_lengths is not None:
+            assert len(response_lengths) == len(prompt_lengths) * samples_per_prompt, (
+                f"Expected {len(prompt_lengths) * samples_per_prompt} response lengths, got {len(response_lengths)}"
+            )
+
+            # Pass the original prompt_lengths with samples_per_prompt to correctly handle the shared KV cache
+            total += self.decode_memory_bytes(prompt_lengths, response_lengths, samples_per_prompt, dtype_bytes)
+
+        return total
+

 def load_model_dims(model_name: str) -> ModelDims:
     cfg = transformers.AutoConfig.from_pretrained(model_name, trust_remote_code=True)
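To make the accounting concrete, a usage sketch follows. The ModelDims constructor arguments are assumed from the fields the methods above reference (num_layers, hidden_size, intermediate_size, num_attn_heads, num_kv_heads), and the dimensions are illustrative Llama-7B-like values, not taken from the repo:

dims = ModelDims(
    num_layers=32,
    hidden_size=4096,
    intermediate_size=11008,
    num_attn_heads=32,
    num_kv_heads=32,
)
prompt_lengths = [512, 512]             # two unique prompts
response_lengths = [128, 256, 64, 192]  # samples_per_prompt=2 -> 4 responses
total_bytes = dims.memory_bytes(prompt_lengths, response_lengths, samples_per_prompt=2)
# Hypothetical 3.0 s batch on an H100 (3.35e12 B/s peak):
mbu = 100 * (total_bytes / 3.0) / GPU_SPECS["h100"]["memory_bandwidth"]
print(f"{total_bytes / 1e9:.1f} GB moved, MBU ~ {mbu:.1f}%")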
@@ -507,6 +718,7 @@ def run_benchmark(
     results = []
     device_name = get_device_name(torch.cuda.get_device_name(0))
     device_flops = GPU_SPECS[device_name]["flops"]
+    device_memory_bandwidth = GPU_SPECS[device_name]["memory_bandwidth"]

     # Submit warmup batch first
     logger.info("Submitting warmup batch...")
@@ -578,12 +790,22 @@ def run_benchmark(
         model_flops_per_second = model_flops / batch_generation_time if batch_generation_time > 0 else 0
         result_dict["mfu"] = 100 * model_flops_per_second / device_flops
+        # Calculate total memory bytes moved for all prompts and responses in the batch
+        model_memory_bytes = model_dims.memory_bytes(
+            prompt_lengths, response_lengths, samples_per_prompt=args.num_samples_per_prompt_rollout
+        )
+
+        # MBU = (memory bytes / time) / peak_bandwidth * 100
+        model_bytes_per_second = model_memory_bytes / batch_generation_time if batch_generation_time > 0 else 0
+        result_dict["mbu"] = 100 * model_bytes_per_second / device_memory_bandwidth
+
         save_completion_lengths([result_dict], timestamp, batch_idx)
         results.append(result_dict)
         logger.info(
             f"Batch {batch_idx}/{num_batches - 1}: "
             f"{result_dict['tokens_per_second']:.2f} new tokens/sec, "
             f"MFU: {result_dict['mfu']:.2f}%, "
+            f"MBU: {result_dict['mbu']:.2f}%, "
             f"generation time: {batch_generation_time:.2f}s, "
             f"total new tokens: {new_tokens}"
         )
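The KV-read accounting behind memory_bytes is the subtle part; a minimal recomputation for one prompt shows the two terms (shared prompt reads plus per-sample causal reads) that kv_cache_read_bytes tallies:

# One prompt of P=100 tokens, two samples with responses of 10 and 20 tokens.
P, responses = 100, [10, 20]
# Term 1: every sample reads the shared prompt KV at each of max(R) positions.
prompt_reads = max(responses) * len(responses) * P    # 20 * 2 * 100 = 4000
# Term 2: each sample reads its own generated KV causally: sum_{i=0}^{R-1} i.
own_reads = sum(R * (R - 1) // 2 for R in responses)  # 45 + 190 = 235
kv_read_terms = prompt_reads + own_reads              # 4235 token-reads
# Multiply by num_layers * 2 (K and V) * num_kv_heads * head_dim * dtype_bytes
# to convert token-reads into bytes, exactly as kv_cache_read_bytes does.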
@@ -604,6 +826,7 @@ def aggregate_results(results: list[dict[str, Any]]) -> dict[str, Any]:
     """Calculate total and aggregated metrics from results."""
     aggregated_results = {
         "total_mfu": 0.0,
+        "total_mbu": 0.0,
         "total_tokens_per_second": 0.0,
         "total_generation_time": 0.0,
         "total_num_new_tokens": 0,
@@ -615,6 +838,8 @@ def aggregate_results(results: list[dict[str, Any]]) -> dict[str, Any]:
         for key, value in result.items():
             if key == "mfu":
                 aggregated_results["total_mfu"] += value
+            elif key == "mbu":
+                aggregated_results["total_mbu"] += value
             elif key == "tokens_per_second":
                 aggregated_results["total_tokens_per_second"] += value
             elif key == "generation_time":
@@ -628,8 +853,13 @@ def aggregate_results(results: list[dict[str, Any]]) -> dict[str, Any]:
                 aggregated_results[key].extend(value)

     num_results = len(results)
-    aggregated_results["avg_tokens_per_second"] = aggregated_results["total_tokens_per_second"] / num_results
+    aggregated_results["avg_tokens_per_second"] = (
+        aggregated_results["total_num_new_tokens"] / aggregated_results["total_generation_time"]
+        if aggregated_results["total_generation_time"] > 0
+        else 0
+    )
     aggregated_results["avg_mfu"] = aggregated_results["total_mfu"] / num_results
+    aggregated_results["avg_mbu"] = aggregated_results["total_mbu"] / num_results
     aggregated_results["avg_generation_time"] = aggregated_results["total_generation_time"] / num_results
     return aggregated_results

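The avg_tokens_per_second change replaces a mean of per-batch rates with true aggregate throughput; the two differ whenever batch durations vary. A small worked example (illustrative numbers):

# Batch A: 100 tokens in 1 s (100 tok/s); batch B: 100 tokens in 4 s (25 tok/s).
mean_of_rates = (100 / 1 + 100 / 4) / 2  # 62.5 tok/s (old behavior)
throughput = (100 + 100) / (1 + 4)       # 40.0 tok/s (new behavior)
# The new definition weights each batch by its duration, matching wall-clock reality.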
@@ -659,6 +889,7 @@ def print_summary(
     print(f"Average results over {len(results)} main benchmark batches:")
     print(f"Average tokens/second: {agg_results['avg_tokens_per_second']:.2f}")
     print(f"Average MFU: {agg_results['avg_mfu']:.2f}%")
+    print(f"Average MBU: {agg_results['avg_mbu']:.2f}%")
     print(f"Average generation time per batch: {agg_results['avg_generation_time']:.2f}s")
     print(f"Average new tokens per sample: {avg_new_tokens_per_sample:.2f} tokens")

@@ -673,6 +904,7 @@ def print_summary(
     print(f"GPU device: {torch.cuda.get_device_name(0)}")
     print(f"GPU peak FLOPs: {gpu_specs['flops'] / 1e12:.0f} TFLOPs")
     print(f"GPU memory size: {gpu_specs['memory_size'] / 1e9:.0f} GB")
+    print(f"GPU memory bandwidth: {gpu_specs['memory_bandwidth'] / 1e12:.2f} TB/s")

     print("-" * 60)
     print("COMPLETION LENGTH STATISTICS:")