Changes from 187 commits
Commits
192 commits
b133e2a
Cleaned up evals to have same names as training data.
finbarrtimbers Jul 29, 2025
1265789
Refactored evals to use a batch.
finbarrtimbers Jul 29, 2025
aa76132
Now, we accumulate eval results.
finbarrtimbers Jul 29, 2025
4e30841
Merge branch 'main' into fix-eval
finbarrtimbers Jul 29, 2025
66af972
Updated scripts so they run.
finbarrtimbers Jul 29, 2025
cfa55c9
More refactoring.
finbarrtimbers Jul 29, 2025
433242a
Now, use the minimum of the number of requested samples and the actua…
finbarrtimbers Jul 29, 2025
0836fca
Ran linter, and fixed extra arg issue.
finbarrtimbers Jul 29, 2025
8028a31
Always insert into pending_queries_map.
finbarrtimbers Jul 29, 2025
9816f34
Update signature in eval.
finbarrtimbers Jul 29, 2025
97b8de9
Merge branch 'main' into fix-eval
finbarrtimbers Jul 29, 2025
e862a14
Another attempted fix.
finbarrtimbers Jul 29, 2025
9676db3
Ran linter.
finbarrtimbers Jul 29, 2025
d044278
Now, eval requests use the eval params, and normal ones use the norma…
finbarrtimbers Jul 29, 2025
6a694bf
Now, tests should pass.
finbarrtimbers Jul 29, 2025
f45b951
Merge branch 'main' into fix-eval
finbarrtimbers Jul 29, 2025
96df985
Remove simple config and pass generation_config through.
finbarrtimbers Jul 30, 2025
b931a35
Now, generation config is passed through.
finbarrtimbers Jul 30, 2025
aa0facb
Ran linter.
finbarrtimbers Jul 30, 2025
9dd0711
Ran linter.
finbarrtimbers Jul 30, 2025
cbf7aa7
Added a while loop.
finbarrtimbers Jul 30, 2025
84b9a4c
Added a while loop with retries.
finbarrtimbers Jul 30, 2025
93c0a97
Merge branch 'main' into fix-eval
finbarrtimbers Jul 30, 2025
87aa0fa
Added logs.
finbarrtimbers Jul 30, 2025
b636127
Fix queue issue.
finbarrtimbers Jul 30, 2025
d0f8870
Add progress bars to all ray.get calls.
finbarrtimbers Jul 30, 2025
9f9e644
Merge branch 'main' into fix-eval
finbarrtimbers Jul 30, 2025
08de6ea
Cleaned up some of the logging.
finbarrtimbers Jul 30, 2025
634e1fb
Changed how we handle full queues.
finbarrtimbers Jul 30, 2025
ada6556
Ran linter.
finbarrtimbers Jul 30, 2025
c29d1d0
Clean up for PR.
finbarrtimbers Jul 30, 2025
95960bc
Switched LLMRayActor to use LLMEngine.
finbarrtimbers Jul 29, 2025
4be2693
Fixes expected output.
finbarrtimbers Jul 29, 2025
d2c1db7
Keep backwards compatibility for tool use.
finbarrtimbers Jul 29, 2025
c1fdd90
Remove manual reorganization.
finbarrtimbers Jul 29, 2025
45791df
Cleaned up implementation.
finbarrtimbers Jul 29, 2025
ad45c6a
Now, we use a generation loop.
finbarrtimbers Jul 29, 2025
5e1c2a6
Uses an ActorManager to manage weight updates.
finbarrtimbers Jul 30, 2025
c13f951
Cleaned up code to use actor manager.
finbarrtimbers Jul 30, 2025
c4dde78
Now, tests pass.
finbarrtimbers Jul 30, 2025
b0fd4f2
Fixed error when calling process_from_queue.
finbarrtimbers Jul 31, 2025
d93d4a8
Added ActorManager.
finbarrtimbers Jul 31, 2025
a444508
Added a test for the actor manager.
finbarrtimbers Jul 31, 2025
10bc07a
Tests pass. Fixed another issue.
finbarrtimbers Jul 31, 2025
4805d46
Ran linter.
finbarrtimbers Jul 31, 2025
30bdce2
Added better error handling.
finbarrtimbers Jul 31, 2025
28ceca9
Potential fix to hanging forever issue.
finbarrtimbers Jul 31, 2025
b1053b8
Another attempt to fix the deadlock.
finbarrtimbers Jul 31, 2025
a7be9bf
Fix code so that it no longer expects process_from_queue to return a …
finbarrtimbers Jul 31, 2025
8343f92
Fixed cleanup code.
finbarrtimbers Jul 31, 2025
54357cd
Fixed issue; now should exit.
finbarrtimbers Jul 31, 2025
27198ef
Added test scripts.
finbarrtimbers Aug 1, 2025
ed467aa
Break out requests into N separate ones.
finbarrtimbers Aug 1, 2025
1276c0f
Found why LLMEngine behaviour differs from LLM. Fixed issue.
finbarrtimbers Aug 1, 2025
4811ca4
Code runs now.
finbarrtimbers Aug 1, 2025
4d7eb5f
First implementation of streaming individual prompts.
finbarrtimbers Aug 1, 2025
a38bfcc
Now, the default for inference_batch_size is to evenly divide them ac…
finbarrtimbers Aug 1, 2025
5428399
Ran linter.
finbarrtimbers Aug 1, 2025
7f2297f
Tests pass.
finbarrtimbers Aug 6, 2025
8590c8a
Merge branch 'main' into continual-processing
finbarrtimbers Aug 7, 2025
d3dae24
Now, tests pass.
finbarrtimbers Aug 7, 2025
dc8c2a7
Removed debugging code.
finbarrtimbers Aug 7, 2025
b34eb61
Fixed function signature calls.
finbarrtimbers Aug 7, 2025
2ff2cca
Tests pass. Cleaned up loop.
finbarrtimbers Aug 7, 2025
3eb9a3f
Now, use inference_batch_size.
finbarrtimbers Aug 7, 2025
7352c61
Updated script.
finbarrtimbers Aug 8, 2025
419abf9
Merge branch 'main' into continual-processing
finbarrtimbers Aug 18, 2025
ab5266f
Linter passes.
finbarrtimbers Aug 18, 2025
2005abc
Clean up for PR.
finbarrtimbers Aug 18, 2025
9897ce0
Implementing individual prompt processing.
finbarrtimbers Aug 18, 2025
49c9963
Ran linter.
finbarrtimbers Aug 18, 2025
1e0e1dd
Updated tests so they pass.
finbarrtimbers Aug 19, 2025
442c389
Merge branch 'main' into continual-processing
finbarrtimbers Aug 19, 2025
43efddd
Update Dockerfile
finbarrtimbers Aug 19, 2025
6aafbfd
Merge branch 'main' into continual-processing
finbarrtimbers Aug 19, 2025
5d15367
Merge branch 'main' into continual-processing
finbarrtimbers Aug 19, 2025
041eed8
Fixed signature.
finbarrtimbers Aug 19, 2025
aee2985
Added test script.
finbarrtimbers Aug 19, 2025
3ae5983
Cleaned up local thread.
finbarrtimbers Aug 19, 2025
002ee30
Now, test runs.
finbarrtimbers Aug 19, 2025
8aca59a
Let's see the code
finbarrtimbers Aug 19, 2025
5d7bf7c
Fixes.
finbarrtimbers Aug 19, 2025
508bb63
Added logging
finbarrtimbers Aug 19, 2025
a9c28f1
Fixed the way we accumulate.
finbarrtimbers Aug 19, 2025
bca5574
Update request prefix.
finbarrtimbers Aug 20, 2025
5ee04b2
Added verbose flag to test scripts.
finbarrtimbers Aug 20, 2025
87d9cbf
Merge branch 'main' into continual-processing
finbarrtimbers Aug 20, 2025
1c3e013
Updated code
finbarrtimbers Aug 20, 2025
546d4d6
Added single gpu test script.
finbarrtimbers Aug 20, 2025
2a3619a
Removed logging.
finbarrtimbers Aug 20, 2025
b831ef9
clean up.
finbarrtimbers Aug 20, 2025
de279d5
Remove debugging.
finbarrtimbers Aug 20, 2025
f45fb2f
Removed debugging logs.
finbarrtimbers Aug 20, 2025
364f14e
Removed debugging code.
finbarrtimbers Aug 20, 2025
165ab20
Merge branch 'main' into continual-processing
finbarrtimbers Aug 20, 2025
5e020e2
Updated code.
finbarrtimbers Aug 20, 2025
4db225d
Added inflight updates.
finbarrtimbers Aug 20, 2025
f4b3144
Update code
finbarrtimbers Aug 20, 2025
fa57ffe
Updated code
finbarrtimbers Aug 20, 2025
43ea247
Update code.
finbarrtimbers Aug 20, 2025
1e74539
Updated code.
finbarrtimbers Aug 20, 2025
fb8e286
Updated code.
finbarrtimbers Aug 20, 2025
cd3277a
Clean up
finbarrtimbers Aug 20, 2025
d2e2f57
More logging
finbarrtimbers Aug 20, 2025
466eacc
Ran linter. Added more logging.
finbarrtimbers Aug 20, 2025
47e34c2
Update
finbarrtimbers Aug 20, 2025
a682b50
Updated code.
finbarrtimbers Aug 21, 2025
c7934ee
Updated logging.
finbarrtimbers Aug 21, 2025
28b5c76
More logging.
finbarrtimbers Aug 21, 2025
68c26a9
Updated logging.
finbarrtimbers Aug 21, 2025
19cbbc2
Merge branch 'main' into continual-processing
finbarrtimbers Aug 21, 2025
f5d3ca6
Longer sleep.
finbarrtimbers Aug 21, 2025
35920ca
Update logging.
finbarrtimbers Aug 21, 2025
9097214
Fixes eval mismatch.
finbarrtimbers Aug 21, 2025
c177840
Ran linter.
finbarrtimbers Aug 21, 2025
9c3448a
Updated code.
finbarrtimbers Aug 21, 2025
af784d0
Merge branch 'main' into continual-processing
finbarrtimbers Aug 21, 2025
9eba1d0
Added comment.
finbarrtimbers Aug 21, 2025
1f381c5
Handled the tool use case.
finbarrtimbers Aug 21, 2025
55b5aa0
Added local test.
finbarrtimbers Aug 21, 2025
974b3ee
Updated logging code.
finbarrtimbers Aug 21, 2025
d4bf6b7
Removed tracking from script.
finbarrtimbers Aug 21, 2025
79ac2a0
Now, benchmark runs.
finbarrtimbers Aug 21, 2025
4829ff8
Updated benchmark code.
finbarrtimbers Aug 21, 2025
0187d01
Updated time.
finbarrtimbers Aug 21, 2025
e2a0d67
Updated tests.
finbarrtimbers Aug 21, 2025
bb8ae89
Modified script.
finbarrtimbers Aug 21, 2025
62336e4
Update script.
finbarrtimbers Aug 21, 2025
ed99644
Fixed code.
finbarrtimbers Aug 21, 2025
3362560
Refactored code.
finbarrtimbers Aug 21, 2025
95b99b7
Updated code.
finbarrtimbers Aug 21, 2025
1feee6b
Added logging.
finbarrtimbers Aug 21, 2025
d6cf4a5
Added test.
finbarrtimbers Aug 21, 2025
a8a5890
Merge branch 'main' into continual-processing
finbarrtimbers Aug 21, 2025
48a8b47
Updated code.
finbarrtimbers Aug 22, 2025
6c2d748
Fixed logger issue.
finbarrtimbers Aug 22, 2025
d4849bc
Fixed bug.
finbarrtimbers Aug 22, 2025
aa754bf
Cleaned up code.
finbarrtimbers Aug 22, 2025
de9f1e2
Merge branch 'main' into continual-processing
finbarrtimbers Aug 22, 2025
8e15206
Changed cluster
finbarrtimbers Aug 22, 2025
0fdf116
Remove weka from gantry
finbarrtimbers Aug 22, 2025
28a53a8
Updated code
finbarrtimbers Aug 22, 2025
b4ac5f8
Moved logging statement.
finbarrtimbers Aug 22, 2025
e192a0b
Cleaned up code.
finbarrtimbers Aug 22, 2025
bbde95f
Update code
finbarrtimbers Aug 22, 2025
a383a7b
Update code
finbarrtimbers Aug 22, 2025
6900ea5
Clean up code
finbarrtimbers Aug 22, 2025
b621db9
Update code
finbarrtimbers Aug 22, 2025
3ccf83d
Updated code
finbarrtimbers Aug 22, 2025
bd2018c
Fixed script.
finbarrtimbers Aug 22, 2025
ade8ccd
Updated logging.
finbarrtimbers Aug 22, 2025
55018e8
Clean up benchmark script changes.
finbarrtimbers Aug 22, 2025
425d6aa
Merge branch 'main' into continual-processing
saurabh111233212 Aug 22, 2025
9431815
Merge branch 'main' into continual-processing
finbarrtimbers Aug 22, 2025
a36ea2e
Added logging
finbarrtimbers Aug 22, 2025
197dc3d
Updated var.
finbarrtimbers Aug 22, 2025
e64df5b
Update size
finbarrtimbers Aug 23, 2025
a97b01b
Updated scripts to have inflight_updates set.
finbarrtimbers Aug 25, 2025
10f99ae
Removed cache reset.
finbarrtimbers Aug 26, 2025
3ce242c
Cleaned up logging
finbarrtimbers Aug 26, 2025
43ef2d8
Reproed speed issue
finbarrtimbers Aug 26, 2025
5ba2dde
Added is_beaker-job back
finbarrtimbers Aug 26, 2025
9dfb26c
Cleaned up code
finbarrtimbers Aug 22, 2025
a409919
Removed tokenizer
finbarrtimbers Aug 26, 2025
cc88121
Updated script.
finbarrtimbers Aug 26, 2025
89f2020
New large test script
finbarrtimbers Aug 26, 2025
fdbac24
Added script
finbarrtimbers Aug 26, 2025
4a0b864
Merge branch 'main' into continual-processing
finbarrtimbers Aug 27, 2025
6045914
Benchmark fixes.
finbarrtimbers Aug 27, 2025
0baee68
Fixed merge issues.
finbarrtimbers Aug 27, 2025
0ef100f
Fixed benchmark.
finbarrtimbers Aug 27, 2025
99848ff
Added metrics.
finbarrtimbers Aug 27, 2025
ba7a0d6
Merge branch 'main' into continual-processing
finbarrtimbers Aug 27, 2025
cc09d4f
Removed debug logging.
finbarrtimbers Aug 27, 2025
4c99dc8
Cleaned up PR.
finbarrtimbers Aug 27, 2025
c5ce3ef
Cleaned up PR.
finbarrtimbers Aug 27, 2025
4c13c00
Cleaned up benchmark.
finbarrtimbers Aug 27, 2025
d96bc28
Cleaned up PR.
finbarrtimbers Aug 27, 2025
e9d8484
Merge branch 'main' into continual-processing
finbarrtimbers Aug 27, 2025
0b9a1fc
Enable inflight weight updates (#955)
finbarrtimbers Aug 27, 2025
1c59533
Ran linter.
finbarrtimbers Aug 27, 2025
f4e8e9c
Removed warning.
finbarrtimbers Aug 27, 2025
1ef2948
Removed debugging code.
finbarrtimbers Aug 27, 2025
6107ad0
Cleaned up benchmark code.
finbarrtimbers Aug 27, 2025
7bb2548
Added logging
finbarrtimbers Aug 27, 2025
b6acc05
Ran linter.
finbarrtimbers Aug 27, 2025
a960545
Changes
finbarrtimbers Aug 27, 2025
2cfbc0f
modified loop to not start until we have prompts in the queue
finbarrtimbers Aug 29, 2025
d5d9013
Merge branch 'main' into continual-processing
finbarrtimbers Aug 29, 2025
d3d3a3d
Merge branch 'main' into continual-processing
finbarrtimbers Aug 29, 2025
e26b45c
Merge branch 'main' into continual-processing
finbarrtimbers Aug 30, 2025
a28585b
Changed batch size back
finbarrtimbers Sep 2, 2025
208 changes: 133 additions & 75 deletions open_instruct/benchmark_generators.py
@@ -125,13 +125,27 @@ def get_git_commit() -> str:


def save_benchmark_results_to_csv(
results: list[dict[str, Any]], total_time: float, args: grpo_fast.Args, model_config: model_utils.ModelConfig
results: list[dict[str, Any]],
total_time: float,
overall_tokens_per_second: float,
args: grpo_fast.Args,
model_config: model_utils.ModelConfig,
) -> None:
"""Save benchmark results to CSV file."""
git_commit = get_git_commit()
agg_results = aggregate_results(results)
csv_path: pathlib.Path = DATA_DIR / "generator_benchmark_results.csv"

device_name = get_device_name(torch.cuda.get_device_name(0))
device_flops = GPU_SPECS[device_name]["flops"]
model_dims = load_model_dims(model_config.model_name_or_path)
all_prompt_lengths = [result["prompt_lengths"] for result in results]
all_response_lengths = [result["response_lengths"] for result in results]

total_flops = model_dims.flops(
all_prompt_lengths, all_response_lengths, samples_per_prompt=args.num_samples_per_prompt_rollout
)

row_data = {
"git_commit": git_commit,
"model": model_config.model_name_or_path,
@@ -149,6 +163,7 @@ def save_benchmark_results_to_csv(
"avg_generation_time_per_batch": agg_results["avg_generation_time"],
"avg_new_tokens_per_sample": agg_results["total_num_new_tokens"]
/ (len(results) * args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout),
"overall_mfu": 100 * (total_flops / total_time) / device_flops,
}

csv_path: pathlib.Path = DATA_DIR / "generator_benchmark_results.csv"
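As a rough sanity check on the new overall_mfu column, here is a minimal sketch of the same formula in isolation. The variable names mirror the ones above (total_flops, total_time, device_flops), but the numeric values are invented placeholders, not measurements from this benchmark:

```python
# Hypothetical inputs purely for illustration; in the script these come from
# model_dims.flops(...), the measured wall-clock time, and GPU_SPECS[device_name]["flops"].
total_flops = 3.2e15      # estimated FLOPs spent producing all responses
total_time = 120.0        # seconds of wall-clock generation time
device_flops = 1.0e15     # advertised peak FLOPs/s of one device

overall_mfu = 100 * (total_flops / total_time) / device_flops
print(f"overall_mfu = {overall_mfu:.2f}%")  # -> overall_mfu = 2.67%
```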
@@ -379,7 +394,7 @@ def setup_dataset(args: grpo_fast.Args, tokenizer_config: dataset_transformation


def setup_vllm_engines(
args: grpo_fast.Args, model_config: model_utils.ModelConfig, max_model_len: int = 20480
args: grpo_fast.Args, model_config: model_utils.ModelConfig, max_model_len: int = 20480, num_batches: int = 5
) -> tuple[list[ray.actor.ActorHandle], ray_queue.Queue, ray_queue.Queue]:
"""Set up vLLM engines and queues."""
logger.info("Setting up vLLM engines...")
@@ -393,10 +408,14 @@ def setup_vllm_engines(
pg = ray.util.placement_group(bundles, strategy="PACK")
ray.get(pg.ready())

param_prompt_Q = ray_queue.Queue(maxsize=10)
inference_results_Q = ray_queue.Queue(maxsize=10)
# Queue size needs to accommodate all individual prompts across all batches.
# Each batch has num_unique_prompts_rollout prompts, and we submit them individually.
# Total individual prompts = num_unique_prompts_rollout * num_batches
queue_size = args.num_unique_prompts_rollout * num_batches
param_prompt_Q = ray_queue.Queue(maxsize=queue_size)
inference_results_Q = ray_queue.Queue(maxsize=queue_size)
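For a concrete sense of the sizing rule in the comment above, a toy calculation (the two input values are assumptions for illustration, not the script's defaults):

```python
num_unique_prompts_rollout = 32  # assumed prompts per batch
num_batches = 5                  # assumed number of batches
queue_size = num_unique_prompts_rollout * num_batches  # 160 slots, one per individual prompt
```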

actor_manager = vllm_utils3.ActorManager.remote()
actor_manager = ray.remote(vllm_utils3.ActorManager).remote()

vllm_engines = vllm_utils3.create_vllm_engines(
num_engines=args.vllm_num_engines,
Expand All @@ -410,6 +429,7 @@ def setup_vllm_engines(
max_model_len=max_model_len,
vllm_gpu_memory_utilization=args.vllm_gpu_memory_utilization,
single_gpu_mode=False,
inference_batch_size=args.inference_batch_size,
pg=pg,
tools={},
max_tool_calls=[0],
@@ -427,7 +447,11 @@ def generate_thread(vllm_engines: list[ray.actor.ActorHandle], stop_event: threa
"""Thread that repeatedly calls process_from_queue on vllm engines."""
logger.info("[Generate Thread] Starting generation thread")
while not stop_event.is_set():
processed_results = ray.get([engine.process_from_queue.remote(timeout=20) for engine in vllm_engines])
processed_results = utils.ray_get_with_progress(
[engine.process_from_queue.remote(timeout=20) for engine in vllm_engines],
"[Generate Thread] Waiting for engines to process prompts",
enable=False,
)
num_processed = sum(int(result) for result in processed_results)
if num_processed == 0:
time.sleep(1)
@@ -444,31 +468,30 @@ def submission_thread(
start_batch_idx: int,
num_batches: int,
) -> None:
"""Thread that submits prompts to the queue."""
"""Thread that submits individual prompts to the queue."""
logger.info("[Submission Thread] Starting prompt submission")
for batch_idx in range(start_batch_idx, start_batch_idx + num_batches):
if stop_event.is_set():
logger.info("[Submission Thread] Stopped due to stop event")
break

# Get batch data from dataset
start_idx = batch_idx * batch_size
end_idx = min(start_idx + batch_size, len(dataset))
batch_data = dataset[start_idx:end_idx]
prompts = batch_data[dataset_transformation.INPUT_IDS_PROMPT_KEY]

# Create list of dataset indices for this batch
dataset_indices = list(range(start_idx, end_idx))

param_prompt_Q.put(
PromptRequest(
prompts=prompts,
dataset_index=dataset_indices,
generation_config=generation_config,
start_time=time.perf_counter(),
total_prompts_submitted = 0
for batch_idx in range(num_batches):
for prompt_idx in range(batch_size):
dataset_idx = start_batch_idx * batch_size + batch_idx * batch_size + prompt_idx
if dataset_idx >= len(dataset):
break

prompt = dataset[dataset_idx][dataset_transformation.INPUT_IDS_PROMPT_KEY]
actual_batch_idx = start_batch_idx + batch_idx
actual_dataset_index = dataset_idx
param_prompt_Q.put(
PromptRequest(
prompt=prompt,
generation_config=generation_config,
training_step=actual_batch_idx,
dataset_index=actual_dataset_index,
is_eval=False,
start_time=time.perf_counter(),
)
)
)
logger.info(f"[Submission Thread] All {num_batches} batches submitted")
total_prompts_submitted += 1
logger.info(f"[Submission Thread] All {total_prompts_submitted} individual prompts submitted")


def run_benchmark(
Expand Down Expand Up @@ -500,36 +523,36 @@ def run_benchmark(
)

stop_event = threading.Event()
executor = futures.ThreadPoolExecutor(max_workers=2, thread_name_prefix="benchmark")
executor = futures.ThreadPoolExecutor(max_workers=len(vllm_engines) + 1, thread_name_prefix="benchmark")

generation_future = executor.submit(generate_thread, vllm_engines, stop_event)
submission_future = None # Initialize to None for access in finally block

results = []
device_name = get_device_name(torch.cuda.get_device_name(0))
device_flops = GPU_SPECS[device_name]["flops"]

# Submit warmup batch first
# Load model dims for FLOPS calculations
model_dims = load_model_dims(model_config.model_name_or_path)

logger.info("Submitting warmup batch...")
warmup_start_idx = 0
warmup_end_idx = min(args.num_unique_prompts_rollout, len(dataset))
warmup_data = dataset[warmup_start_idx:warmup_end_idx]
warmup_prompts = warmup_data[dataset_transformation.INPUT_IDS_PROMPT_KEY]
warmup_dataset_indices = list(range(warmup_start_idx, warmup_end_idx))
param_prompt_Q.put(
PromptRequest(
prompts=warmup_prompts,
dataset_index=warmup_dataset_indices,
generation_config=generation_config,
start_time=time.perf_counter(),
for prompt_idx in range(args.num_unique_prompts_rollout):
param_prompt_Q.put(
PromptRequest(
prompt=dataset[prompt_idx][dataset_transformation.INPUT_IDS_PROMPT_KEY],
generation_config=generation_config,
training_step=0, # warmup is training step 0
dataset_index=prompt_idx,
is_eval=False,
)
)
)
model_dims = load_model_dims(model_config.model_name_or_path)

try:
logger.info("Running warmup batch...")

warmup_result = inference_results_Q.get()
logger.info(f"Warmup batch completed with {len(warmup_result.responses)} responses")
for _ in range(args.num_unique_prompts_rollout):
inference_results_Q.get()
logger.info("Warmup batch completed")
logger.info(f"Submitting {num_batches - 1} batches for main benchmark...")
submission_future = executor.submit(
submission_thread,
@@ -541,41 +564,51 @@ def run_benchmark(
1,
num_batches - 1,
)
# Process remaining batches with timing

main_benchmark_start_time = time.perf_counter()

for batch_idx in range(1, num_batches):
# Quick health check!
[future.result() for future in [submission_future, generation_future] if future.done()]
result = inference_results_Q.get()

batch_results = []
batch_start_time = time.perf_counter()
for _ in range(args.num_unique_prompts_rollout):
batch_results.append(inference_results_Q.get())

completion_time = time.perf_counter()
# Calculate generation time from when the request was enqueued
batch_generation_time = completion_time - result.start_time if result.start_time else 0
batch_generation_time = completion_time - batch_start_time

total_new_tokens = 0
all_response_lengths = []
all_finish_reasons = []

new_tokens = sum(len(response) for response in result.responses)
tokens_per_second = new_tokens / batch_generation_time if batch_generation_time > 0 else 0
for i, result in enumerate(batch_results):
result_tokens = sum(len(response) for response in result.responses)
total_new_tokens += result_tokens
all_response_lengths.extend([len(response) for response in result.responses])
all_finish_reasons.extend(result.finish_reasons)

tokens_per_second = total_new_tokens / batch_generation_time if batch_generation_time > 0 else 0
result_dict = {
"tokens_per_second": tokens_per_second,
"generation_time": batch_generation_time,
"num_new_tokens": new_tokens,
"finish_reasons": collections.Counter(result.finish_reasons),
"response_lengths": [len(response) for response in result.responses],
"dataset_indices": result.dataset_index,
"num_new_tokens": total_new_tokens,
"finish_reasons": collections.Counter(all_finish_reasons),
"response_lengths": all_response_lengths,
"dataset_indices": [r.dataset_index for r in batch_results],
}
# Get prompt lengths using dataset indices from the result
prompt_data = dataset[result.dataset_index]
prompts = prompt_data[dataset_transformation.INPUT_IDS_PROMPT_KEY]
prompt_lengths = [len(prompt) for prompt in prompts]
response_lengths = [len(response) for response in result.responses]

# Calculate total FLOPs for all prompts and responses in the batch
# No need to expand prompt_lengths - the flops method now handles samples_per_prompt
prompt_lengths = []
response_lengths = all_response_lengths
for r in batch_results:
prompt_data = dataset[r.dataset_index]
prompt = prompt_data[dataset_transformation.INPUT_IDS_PROMPT_KEY]
prompt_lengths.append(len(prompt))

result_dict["prompt_lengths"] = prompt_lengths
model_flops = model_dims.flops(
prompt_lengths, response_lengths, samples_per_prompt=args.num_samples_per_prompt_rollout
)

# MFU = (FLOPs / time) / peak_FLOPS * 100
model_flops_per_second = model_flops / batch_generation_time if batch_generation_time > 0 else 0
model_flops_per_second = model_flops / batch_generation_time
result_dict["mfu"] = 100 * model_flops_per_second / device_flops

save_completion_lengths([result_dict], timestamp, batch_idx)
@@ -585,18 +618,36 @@ def run_benchmark(
f"{result_dict['tokens_per_second']:.2f} new tokens/sec, "
f"MFU: {result_dict['mfu']:.2f}%, "
f"generation time: {batch_generation_time:.2f}s, "
f"total new tokens: {new_tokens}"
f"total new tokens: {total_new_tokens}"
)
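Because each queue entry now holds a single prompt's responses, the per-batch statistics above are rebuilt by draining one result per unique prompt and summing across them. A self-contained toy version of that accumulation (fake lengths and finish reasons, not real benchmark output):

```python
import collections

# Stand-ins for objects returned by inference_results_Q.get(); each carries the
# responses generated for a single prompt.
batch_results = [
    {"response_lengths": [12, 30], "finish_reasons": ["stop", "length"]},
    {"response_lengths": [7, 25], "finish_reasons": ["stop", "stop"]},
]

total_new_tokens = sum(sum(r["response_lengths"]) for r in batch_results)
all_finish_reasons = collections.Counter(
    reason for r in batch_results for reason in r["finish_reasons"]
)
print(total_new_tokens)    # 74
print(all_finish_reasons)  # Counter({'stop': 3, 'length': 1})
```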

# Calculate total time for main benchmark only
main_benchmark_time = sum(r["generation_time"] for r in results)
main_benchmark_end_time = time.time()
main_benchmark_total_time = main_benchmark_end_time - main_benchmark_start_time

total_main_tokens = sum(r["num_new_tokens"] for r in results)
overall_tokens_per_second = (
total_main_tokens / main_benchmark_total_time if main_benchmark_total_time > 0 else 0
)

print_summary(results, main_benchmark_time, args, model_config)
save_benchmark_results_to_csv(results, main_benchmark_time, args, model_config)
logger.info("\nOverall main benchmark performance:")
logger.info(f" Total wall-clock time: {main_benchmark_total_time:.2f}s")
logger.info(f" Total tokens generated: {total_main_tokens}")
logger.info(f" Overall tokens/second: {overall_tokens_per_second:.2f}")

print_summary(results, main_benchmark_total_time, overall_tokens_per_second, args, model_config)
save_benchmark_results_to_csv(
results, main_benchmark_total_time, overall_tokens_per_second, args, model_config
)
finally:
logger.info("Starting cleanup...")
stop_event.set()
executor.shutdown(wait=True)

logger.info("Waiting for threads to complete...")
try:
submission_future.result(timeout=5)
generation_future.result(timeout=10)
finally:
executor.shutdown(wait=True)
logger.info("Threads cleaned up")


Expand Down Expand Up @@ -635,7 +686,11 @@ def aggregate_results(results: list[dict[str, Any]]) -> dict[str, Any]:


def print_summary(
results: list[dict[str, Any]], total_time: float, args: grpo_fast.Args, model_config: model_utils.ModelConfig
results: list[dict[str, Any]],
total_time: float,
overall_tokens_per_second: float,
args: grpo_fast.Args,
model_config: model_utils.ModelConfig,
) -> None:
"""Print benchmark summary statistics."""

@@ -653,8 +708,11 @@ def print_summary(
print(f"Num rollouts: {args.num_samples_per_prompt_rollout}")
print(f"Max tokens: {args.response_length}")
print("-" * 60)
print(f"Total wall-clock time (main benchmark): {total_time:.2f}s")
print(f"Total new tokens generated: {agg_results['total_num_new_tokens']}")
print(f"Total time (main benchmark): {agg_results['total_generation_time']:.2f}s")
print(f"Total new tokens generated: {agg_results['total_num_new_tokens']}")
print(f"Overall tokens/second: {overall_tokens_per_second:.2f}")
print("-" * 60)
print(f"Average results over {len(results)} main benchmark batches:")
print(f"Average tokens/second: {agg_results['avg_tokens_per_second']:.2f}")