diff --git a/open_instruct/benchmark_generators.py b/open_instruct/benchmark_generators.py
index c81ec9996..ec94e4cc2 100644
--- a/open_instruct/benchmark_generators.py
+++ b/open_instruct/benchmark_generators.py
@@ -127,13 +127,27 @@ def get_git_commit() -> str:
def save_benchmark_results_to_csv(
- results: list[dict[str, Any]], total_time: float, args: grpo_fast.Args, model_config: model_utils.ModelConfig
+ results: list[dict[str, Any]],
+ total_time: float,
+ overall_tokens_per_second: float,
+ args: grpo_fast.Args,
+ model_config: model_utils.ModelConfig,
) -> None:
"""Save benchmark results to CSV file."""
git_commit = get_git_commit()
agg_results = aggregate_results(results)
csv_path: pathlib.Path = DATA_DIR / "generator_benchmark_results.csv"
+ device_name = get_device_name(torch.cuda.get_device_name(0))
+ device_flops = GPU_SPECS[device_name]["flops"]
+ model_dims = load_model_dims(model_config.model_name_or_path)
+    all_prompt_lengths = [length for result in results for length in result["prompt_lengths"]]
+    all_response_lengths = [length for result in results for length in result["response_lengths"]]
+
+ total_flops = model_dims.flops(
+ all_prompt_lengths, all_response_lengths, samples_per_prompt=args.num_samples_per_prompt_rollout
+ )
+
row_data = {
"git_commit": git_commit,
"model": model_config.model_name_or_path,
@@ -152,6 +166,7 @@ def save_benchmark_results_to_csv(
"avg_generation_time_per_batch": agg_results["avg_generation_time"],
"avg_new_tokens_per_sample": agg_results["total_num_new_tokens"]
/ (len(results) * args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout),
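+        # Overall MFU is measured against total wall-clock time for the main benchmark, not per-batch generation time.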
+ "overall_mfu": 100 * (total_flops / total_time) / device_flops,
}
csv_path: pathlib.Path = DATA_DIR / "generator_benchmark_results.csv"
@@ -591,7 +606,7 @@ def setup_dataset(args: grpo_fast.Args, tokenizer_config: dataset_transformation
def setup_vllm_engines(
- args: grpo_fast.Args, model_config: model_utils.ModelConfig, max_model_len: int = 20480
+ args: grpo_fast.Args, model_config: model_utils.ModelConfig, max_model_len: int = 20480, num_batches: int = 5
) -> tuple[list[ray.actor.ActorHandle], ray_queue.Queue, ray_queue.Queue]:
"""Set up vLLM engines and queues."""
logger.info("Setting up vLLM engines...")
@@ -605,8 +620,12 @@ def setup_vllm_engines(
pg = ray.util.placement_group(bundles, strategy="PACK")
ray.get(pg.ready())
- param_prompt_Q = ray_queue.Queue(maxsize=10)
- inference_results_Q = ray_queue.Queue(maxsize=10)
+ # Queue size needs to accommodate all individual prompts across all batches.
+ # Each batch has num_unique_prompts_rollout prompts, and we submit them individually.
+ # Total individual prompts = num_unique_prompts_rollout * num_batches
+ queue_size = args.num_unique_prompts_rollout * num_batches
+ param_prompt_Q = ray_queue.Queue(maxsize=queue_size)
+ inference_results_Q = ray_queue.Queue(maxsize=queue_size)
queues_to_monitor = {"Param Prompt Queue": param_prompt_Q, "Inference Results Queue": inference_results_Q}
actor_manager = ray.remote(ActorManager).remote(queues_to_monitor, args)
@@ -623,6 +642,7 @@ def setup_vllm_engines(
max_model_len=max_model_len,
vllm_gpu_memory_utilization=args.vllm_gpu_memory_utilization,
single_gpu_mode=False,
+ inference_batch_size=args.inference_batch_size,
pg=pg,
tools={},
max_tool_calls=[0],
@@ -640,7 +660,11 @@ def generate_thread(vllm_engines: list[ray.actor.ActorHandle], stop_event: threa
"""Thread that repeatedly calls process_from_queue on vllm engines."""
logger.info("[Generate Thread] Starting generation thread")
while not stop_event.is_set():
- processed_results = ray.get([engine.process_from_queue.remote(timeout=20) for engine in vllm_engines])
+ processed_results = utils.ray_get_with_progress(
+ [engine.process_from_queue.remote(timeout=20) for engine in vllm_engines],
+ "[Generate Thread] Waiting for engines to process prompts",
+ enable=False,
+ )
num_processed = sum(int(result) for result in processed_results)
if num_processed == 0:
time.sleep(1)
@@ -657,31 +681,30 @@ def submission_thread(
start_batch_idx: int,
num_batches: int,
) -> None:
- """Thread that submits prompts to the queue."""
+ """Thread that submits individual prompts to the queue."""
logger.info("[Submission Thread] Starting prompt submission")
- for batch_idx in range(start_batch_idx, start_batch_idx + num_batches):
- if stop_event.is_set():
- logger.info("[Submission Thread] Stopped due to stop event")
- break
-
- # Get batch data from dataset
- start_idx = batch_idx * batch_size
- end_idx = min(start_idx + batch_size, len(dataset))
- batch_data = dataset[start_idx:end_idx]
- prompts = batch_data[dataset_transformation.INPUT_IDS_PROMPT_KEY]
-
- # Create list of dataset indices for this batch
- dataset_indices = list(range(start_idx, end_idx))
-
- param_prompt_Q.put(
- PromptRequest(
- prompts=prompts,
- dataset_index=dataset_indices,
- generation_config=generation_config,
- start_time=time.perf_counter(),
-            )
-        )
-    logger.info(f"[Submission Thread] All {num_batches} batches submitted")
+    total_prompts_submitted = 0
+    for batch_idx in range(num_batches):
+        if stop_event.is_set():
+            logger.info("[Submission Thread] Stopped due to stop event")
+            break
+        actual_batch_idx = start_batch_idx + batch_idx
+        for prompt_idx in range(batch_size):
+            dataset_idx = actual_batch_idx * batch_size + prompt_idx
+            if dataset_idx >= len(dataset):
+                break
+
+            prompt = dataset[dataset_idx][dataset_transformation.INPUT_IDS_PROMPT_KEY]
+            param_prompt_Q.put(
+                PromptRequest(
+                    prompt=prompt,
+                    generation_config=generation_config,
+                    training_step=actual_batch_idx,
+                    dataset_index=dataset_idx,
+                    is_eval=False,
+                    start_time=time.perf_counter(),
+                )
+            )
+            total_prompts_submitted += 1
+    logger.info(f"[Submission Thread] All {total_prompts_submitted} individual prompts submitted")
def run_benchmark(
@@ -713,37 +736,37 @@ def run_benchmark(
)
stop_event = threading.Event()
- executor = futures.ThreadPoolExecutor(max_workers=2, thread_name_prefix="benchmark")
+ executor = futures.ThreadPoolExecutor(max_workers=len(vllm_engines) + 1, thread_name_prefix="benchmark")
generation_future = executor.submit(generate_thread, vllm_engines, stop_event)
+ submission_future = None # Initialize to None for access in finally block
results = []
device_name = get_device_name(torch.cuda.get_device_name(0))
device_flops = GPU_SPECS[device_name]["flops"]
device_memory_bandwidth = GPU_SPECS[device_name]["memory_bandwidth"]
- # Submit warmup batch first
+ # Load model dims for FLOPS calculations
+ model_dims = load_model_dims(model_config.model_name_or_path)
+
logger.info("Submitting warmup batch...")
- warmup_start_idx = 0
- warmup_end_idx = min(args.num_unique_prompts_rollout, len(dataset))
- warmup_data = dataset[warmup_start_idx:warmup_end_idx]
- warmup_prompts = warmup_data[dataset_transformation.INPUT_IDS_PROMPT_KEY]
- warmup_dataset_indices = list(range(warmup_start_idx, warmup_end_idx))
- param_prompt_Q.put(
- PromptRequest(
- prompts=warmup_prompts,
- dataset_index=warmup_dataset_indices,
- generation_config=generation_config,
- start_time=time.perf_counter(),
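+    # Warmup submits one PromptRequest per prompt, matching the per-prompt submission path used for the main batches.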
+ for prompt_idx in range(args.num_unique_prompts_rollout):
+ param_prompt_Q.put(
+ PromptRequest(
+ prompt=dataset[prompt_idx][dataset_transformation.INPUT_IDS_PROMPT_KEY],
+ generation_config=generation_config,
+ training_step=0, # warmup is training step 0
+ dataset_index=prompt_idx,
+ is_eval=False,
+ )
)
- )
- model_dims = load_model_dims(model_config.model_name_or_path)
try:
logger.info("Running warmup batch...")
- warmup_result = inference_results_Q.get()
- logger.info(f"Warmup batch completed with {len(warmup_result.responses)} responses")
+ for _ in range(args.num_unique_prompts_rollout):
+ inference_results_Q.get()
+ logger.info("Warmup batch completed")
logger.info(f"Submitting {num_batches - 1} batches for main benchmark...")
submission_future = executor.submit(
submission_thread,
@@ -755,41 +778,51 @@ def run_benchmark(
1,
num_batches - 1,
)
- # Process remaining batches with timing
+
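+        # Wall-clock window covering all main-benchmark batches (warmup excluded).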
+ main_benchmark_start_time = time.perf_counter()
+
for batch_idx in range(1, num_batches):
- # Quick health check!
[future.result() for future in [submission_future, generation_future] if future.done()]
- result = inference_results_Q.get()
+
+ batch_results = []
+ batch_start_time = time.perf_counter()
+ for _ in range(args.num_unique_prompts_rollout):
+ batch_results.append(inference_results_Q.get())
completion_time = time.perf_counter()
- # Calculate generation time from when the request was enqueued
- batch_generation_time = completion_time - result.start_time if result.start_time else 0
+ batch_generation_time = completion_time - batch_start_time
+
+ total_new_tokens = 0
+ all_response_lengths = []
+ all_finish_reasons = []
- new_tokens = sum(len(response) for response in result.responses)
- tokens_per_second = new_tokens / batch_generation_time if batch_generation_time > 0 else 0
+            for result in batch_results:
+ result_tokens = sum(len(response) for response in result.responses)
+ total_new_tokens += result_tokens
+ all_response_lengths.extend([len(response) for response in result.responses])
+ all_finish_reasons.extend(result.finish_reasons)
+ tokens_per_second = total_new_tokens / batch_generation_time if batch_generation_time > 0 else 0
result_dict = {
"tokens_per_second": tokens_per_second,
"generation_time": batch_generation_time,
- "num_new_tokens": new_tokens,
- "finish_reasons": collections.Counter(result.finish_reasons),
- "response_lengths": [len(response) for response in result.responses],
- "dataset_indices": result.dataset_index,
+ "num_new_tokens": total_new_tokens,
+ "finish_reasons": collections.Counter(all_finish_reasons),
+ "response_lengths": all_response_lengths,
+ "dataset_indices": [r.dataset_index for r in batch_results],
}
- # Get prompt lengths using dataset indices from the result
- prompt_data = dataset[result.dataset_index]
- prompts = prompt_data[dataset_transformation.INPUT_IDS_PROMPT_KEY]
- prompt_lengths = [len(prompt) for prompt in prompts]
- response_lengths = [len(response) for response in result.responses]
-
- # Calculate total FLOPs for all prompts and responses in the batch
- # No need to expand prompt_lengths - the flops method now handles samples_per_prompt
+ prompt_lengths = []
+ response_lengths = all_response_lengths
+ for r in batch_results:
+ prompt_data = dataset[r.dataset_index]
+ prompt = prompt_data[dataset_transformation.INPUT_IDS_PROMPT_KEY]
+ prompt_lengths.append(len(prompt))
+
+ result_dict["prompt_lengths"] = prompt_lengths
model_flops = model_dims.flops(
prompt_lengths, response_lengths, samples_per_prompt=args.num_samples_per_prompt_rollout
)
-
- # MFU = (FLOPs / time) / peak_FLOPS * 100
- model_flops_per_second = model_flops / batch_generation_time if batch_generation_time > 0 else 0
+ model_flops_per_second = model_flops / batch_generation_time
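+            # MFU = achieved FLOPs/s as a percentage of the device's peak FLOPs/s.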
result_dict["mfu"] = 100 * model_flops_per_second / device_flops
# Calculate total memory bytes for all prompts and responses in the batch
@@ -809,18 +842,36 @@ def run_benchmark(
f"MFU: {result_dict['mfu']:.2f}%, "
f"MBU: {result_dict['mbu']:.2f}%, "
f"generation time: {batch_generation_time:.2f}s, "
- f"total new tokens: {new_tokens}"
+ f"total new tokens: {total_new_tokens}"
)
- # Calculate total time for main benchmark only
- main_benchmark_time = sum(r["generation_time"] for r in results)
+        main_benchmark_end_time = time.perf_counter()
+ main_benchmark_total_time = main_benchmark_end_time - main_benchmark_start_time
+
+ total_main_tokens = sum(r["num_new_tokens"] for r in results)
+ overall_tokens_per_second = (
+ total_main_tokens / main_benchmark_total_time if main_benchmark_total_time > 0 else 0
+ )
- print_summary(results, main_benchmark_time, args, model_config)
- save_benchmark_results_to_csv(results, main_benchmark_time, args, model_config)
+ logger.info("\nOverall main benchmark performance:")
+ logger.info(f" Total wall-clock time: {main_benchmark_total_time:.2f}s")
+ logger.info(f" Total tokens generated: {total_main_tokens}")
+ logger.info(f" Overall tokens/second: {overall_tokens_per_second:.2f}")
+ print_summary(results, main_benchmark_total_time, overall_tokens_per_second, args, model_config)
+ save_benchmark_results_to_csv(
+ results, main_benchmark_total_time, overall_tokens_per_second, args, model_config
+ )
finally:
+ logger.info("Starting cleanup...")
stop_event.set()
- executor.shutdown(wait=True)
+
+ logger.info("Waiting for threads to complete...")
+ try:
+            # submission_future is None if the run failed before the main batches were submitted.
+            if submission_future is not None:
+                submission_future.result(timeout=5)
+ generation_future.result(timeout=10)
+ finally:
+ executor.shutdown(wait=True)
logger.info("Threads cleaned up")
@@ -867,7 +918,11 @@ def aggregate_results(results: list[dict[str, Any]]) -> dict[str, Any]:
def print_summary(
- results: list[dict[str, Any]], total_time: float, args: grpo_fast.Args, model_config: model_utils.ModelConfig
+ results: list[dict[str, Any]],
+ total_time: float,
+ overall_tokens_per_second: float,
+ args: grpo_fast.Args,
+ model_config: model_utils.ModelConfig,
) -> None:
"""Print benchmark summary statistics."""
@@ -885,8 +940,11 @@ def print_summary(
print(f"Num rollouts: {args.num_samples_per_prompt_rollout}")
print(f"Max tokens: {args.response_length}")
print("-" * 60)
+ print(f"Total wall-clock time (main benchmark): {total_time:.2f}s")
+ print(f"Total new tokens generated: {agg_results['total_num_new_tokens']}")
print(f"Total time (main benchmark): {agg_results['total_generation_time']:.2f}s")
print(f"Total new tokens generated: {agg_results['total_num_new_tokens']}")
+ print(f"Overall tokens/second: {overall_tokens_per_second:.2f}")
print("-" * 60)
print(f"Average results over {len(results)} main benchmark batches:")
print(f"Average tokens/second: {agg_results['avg_tokens_per_second']:.2f}")
diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py
index 2ff67be1d..0f71fb8ff 100644
--- a/open_instruct/grpo_fast.py
+++ b/open_instruct/grpo_fast.py
@@ -202,6 +202,10 @@ class Args:
fused_optimizer: bool = False
"""Whether to use fused optimizer"""
+ # Progress reporting
+ update_every: int = 10
+ """How often to update the Beaker description progress bar."""
+
# Batch sizes
per_device_train_batch_size: int = 1
"""The forward batch size per device (local_micro_batch_size)"""
@@ -329,6 +333,8 @@ class Args:
on the first node and 4 learner processes on the second node; each process will have 1 GPU)"""
vllm_num_engines: int = 1
"""number of vLLM Engines, set to 0 to disable vLLM"""
+ inflight_updates: bool = False
+ """If True, return immediately even with pending work. If False, wait for all work to complete before exiting."""
inference_batch_size: Optional[int] = None
"""inference batch size per vLLM engine. If None, calculated as ceil(num_unique_prompts_rollout / vllm_num_engines) * num_samples_per_prompt_rollout"""
vllm_tensor_parallel_size: int = 1
@@ -343,6 +349,8 @@ class Args:
"""whether to enable prefix caching"""
vllm_top_p: float = 1.0
"""vLLM top p for nucleus sampling"""
deepspeed_stage: int = 0
"""the deepspeed stage"""
gather_whole_model: bool = True
@@ -433,6 +441,7 @@ def __post_init__(self):
assert self.num_samples_per_prompt_rollout > 0, "Number of samples per prompt must be greater than 0!"
if self.num_samples_per_prompt_rollout == 1:
logger.warning("num_samples_per_prompt_rollout is 1. This reduces GRPO to REINFORCE.")
+
assert self.apply_verifiable_reward or self.apply_r1_style_format_reward or self.non_stop_penalty, (
"At least one reward must be applied!"
)
@@ -444,6 +453,8 @@ def __post_init__(self):
# Initialize stop_strings if None
if self.stop_strings is None:
self.stop_strings = []
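+        # Derive a per-engine inference batch size when not set explicitly; ceiling division ensures every prompt is covered.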
+ if self.inference_batch_size is None:
+            self.inference_batch_size = math.ceil(self.num_unique_prompts_rollout / self.vllm_num_engines)
assert self.pack_length >= self.max_prompt_token_length + self.response_length, (
"The `pack_length` needs to be greater than the sum of `max_prompt_token_length` and `response_length`!"
)
@@ -534,6 +545,87 @@ def to_device_inplace(tensors_list: List[torch.Tensor], device: torch.device):
tensors_list[i] = tensors_list[i].to(device, non_blocking=True)
+def log_memory_usage(
+ rank: int, device: torch.device, stage: str, include_tensors: bool = False, top_k_tensors: int = 10
+):
+ """Log comprehensive memory usage information for debugging OOM errors.
+
+ Args:
+ rank: Process rank
+ device: CUDA device
+ stage: String describing the current stage of execution
+ include_tensors: Whether to log information about individual tensors
+ top_k_tensors: Number of largest tensors to log
+ """
+ import gc
+
+ import psutil
+
+ # System RAM usage
+ process = psutil.Process()
+ ram_usage_gb = process.memory_info().rss / (1024**3)
+ ram_percent = process.memory_percent()
+
+ # Get available system memory
+ vm = psutil.virtual_memory()
+ available_ram_gb = vm.available / (1024**3)
+ total_ram_gb = vm.total / (1024**3)
+
+ logger.info(f"[Rank {rank}] Memory @ {stage}:")
+ logger.info(
+ f" System RAM: {ram_usage_gb:.2f}GB ({ram_percent:.1f}%) | Available: {available_ram_gb:.2f}GB / {total_ram_gb:.2f}GB"
+ )
+
+ if torch.cuda.is_available():
+ # Force synchronization to get accurate memory stats
+ torch.cuda.synchronize(device)
+
+ # GPU memory usage
+ allocated_gb = torch.cuda.memory_allocated(device) / (1024**3)
+ reserved_gb = torch.cuda.memory_reserved(device) / (1024**3)
+ max_allocated_gb = torch.cuda.max_memory_allocated(device) / (1024**3)
+
+ # Get total GPU memory
+ total_gpu_memory = torch.cuda.get_device_properties(device).total_memory / (1024**3)
+ free_memory = (torch.cuda.get_device_properties(device).total_memory - torch.cuda.memory_reserved(device)) / (
+ 1024**3
+ )
+
+ logger.info(" GPU Memory:")
+ logger.info(
+ f" Allocated: {allocated_gb:.2f}GB / Reserved: {reserved_gb:.2f}GB / Total: {total_gpu_memory:.2f}GB"
+ )
+ logger.info(f" Max Allocated: {max_allocated_gb:.2f}GB | Free: {free_memory:.2f}GB")
+ logger.info(f" Utilization: {(reserved_gb / total_gpu_memory * 100):.1f}%")
+
+ # Memory summary from PyTorch
+ if hasattr(torch.cuda, "memory_summary"):
+ summary = torch.cuda.memory_summary(device, abbreviated=True)
+ logger.debug(f" PyTorch Memory Summary:\n{summary}")
+
+ if include_tensors:
+ # Find and log largest tensors
+ tensor_sizes = []
+ for obj in gc.get_objects():
+ try:
+ if torch.is_tensor(obj) and obj.is_cuda and obj.device == device:
+ size_mb = float(obj.element_size() * obj.nelement() / (1024**2))
+ tensor_sizes.append((size_mb, obj.shape, obj.dtype, obj.device))
+ except Exception:
+ pass
+
+ if tensor_sizes:
+ tensor_sizes.sort(reverse=True, key=lambda x: x[0])
+ logger.info(f" Top {min(top_k_tensors, len(tensor_sizes))} Tensors on GPU:")
+ for i, (size_mb, shape, dtype, dev) in enumerate(tensor_sizes[:top_k_tensors]):
+ logger.info(f" {i + 1}. {size_mb:.2f}MB | Shape: {shape} | Dtype: {dtype}")
+
+ total_tracked_mb = float(sum(x[0] for x in tensor_sizes))
+ logger.info(
+ f" Total tracked tensor memory: {total_tracked_mb:.2f}MB ({total_tracked_mb / 1024:.2f}GB)"
+ )
+
+
class ShufflingIterator:
def __init__(self, data: np.ndarray, batch_size: int, seed: Optional[int] = None):
self.data = data.copy()
@@ -630,6 +722,9 @@ def load(self, path: str, map_location=None):
dschf = None
logger.info(f"Deepspeed config: {dschf=}")
+ # Log memory before loading policy model
+ log_memory_usage(self.rank, self.device, "Before loading policy model", include_tensors=False)
+
self.policy: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
model_config.model_name_or_path,
revision=model_config.model_revision,
@@ -639,6 +734,9 @@ def load(self, path: str, map_location=None):
)
disable_dropout_in_model(self.policy)
self.policy.gradient_checkpointing_enable()
+
+ # Log memory after loading policy model
+ log_memory_usage(self.rank, self.device, "After loading policy model", include_tensors=False)
# AdamOptimizer = DeepSpeedCPUAdam if self.adam_offload else FusedAdam
# AdamOptimizer = FusedAdam
if args.set_weight_decay_on_bias_and_norm:
@@ -728,6 +826,9 @@ def load(self, path: str, map_location=None):
dschf = None
logger.info(f"DeepSpeed config: {dschf=}")
+ # Log memory before loading reference policy
+ log_memory_usage(self.rank, self.device, "Before loading ref_policy", include_tensors=False)
+
self.ref_policy: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
model_config.model_name_or_path,
revision=model_config.model_revision,
@@ -739,6 +840,9 @@ def load(self, path: str, map_location=None):
self.ref_policy, *_ = deepspeed.initialize(model=self.ref_policy, config=ds_config)
self.ref_policy.eval()
+ # Log memory after loading reference policy
+ log_memory_usage(self.rank, self.device, "After loading ref_policy", include_tensors=True, top_k_tensors=10)
+
# Load reference policy checkpoint if available
if hasattr(self, "ref_policy_checkpoint_path") and self.ref_policy_checkpoint_path:
state_dict = torch.load(self.ref_policy_checkpoint_path, map_location=self.device)
@@ -885,12 +989,19 @@ def train(
num_mini_batches: int,
):
args = self.args
+
+ # Log memory before moving tensors to device
+ log_memory_usage(self.rank, self.device, "Before to_device", include_tensors=False)
+
to_device_inplace(collated_query_responses, self.device)
to_device_inplace(collated_tool_masks, self.device)
to_device_inplace(collated_attention_masks, self.device)
to_device_inplace(collated_position_ids, self.device)
to_device_inplace(collated_advantages, self.device)
to_device_inplace(collated_response_masks, self.device)
+
+ # Log memory after moving tensors to device
+ log_memory_usage(self.rank, self.device, "After to_device", include_tensors=True, top_k_tensors=5)
# accumulation steps should always be at least 1
accumulation_steps = max(math.ceil(len(collated_query_responses) / num_mini_batches - 0.5), 1)
leftover = len(collated_query_responses) % accumulation_steps
@@ -963,6 +1074,9 @@ def train(
old_logprobs[i] = old_logprob
torch.cuda.empty_cache()
+ # Log memory before training loop
+ log_memory_usage(self.rank, self.device, "Before training loop", include_tensors=False)
+
local_step = 0
# Do multiple epochs of training on on-policy data (PPO-style), with a fresh random shuffle in each epoch
with Timer("[Training Processes] Loss calculation", noop=self.rank != 0):
@@ -1038,7 +1152,23 @@ def train(
# grpo change: directly subtract KL in loss (add)
loss = masked_mean(pg_loss_max + (args.beta * kl), mb_response_masks_bool, args.masked_mean_axis)
loss = loss / accumulation_steps
+
+ # Log memory before backward pass
+ log_memory_usage(
+ self.rank, self.device, f"Before backward (step {local_step})", include_tensors=False
+ )
+
self.model.backward(loss)
+
+ # Log memory after backward pass
+ log_memory_usage(
+ self.rank,
+ self.device,
+ f"After backward (step {local_step})",
+ include_tensors=True,
+ top_k_tensors=10,
+ )
+
if (local_step + 1) % accumulation_steps == 0:
self.model.step()
local_step += 1
@@ -1344,17 +1474,19 @@ def accumulate_inference_batches(
args: Args,
training_step: int,
generation_config,
+ num_prompts: int,
actor_manager=None,
timeout: Optional[float] = None,
) -> tuple[GenerationResult, Batch]:
"""Accumulate multiple inference results into a single training batch.
Args:
- inference_results_Q: Queue containing GenerationResult objects
+ inference_results_Q: Queue containing individual GenerationResult objects (one per prompt)
pending_queries_map: PendingQueriesMap instance for thread-safe query tracking
- args: Arguments containing vllm_num_engines
+ args: Arguments containing vllm_num_engines and batch size info
training_step: Current training step for error reporting
generation_config: Generation config containing n (number of samples per prompt)
+ num_prompts: Number of prompts to accumulate
timeout: Optional timeout in seconds for queue get operations. If None, blocks indefinitely.
Raises:
@@ -1363,15 +1495,15 @@ def accumulate_inference_batches(
Returns:
Tuple of (combined_result, Batch with queries, ground_truths, datasets) or (ShutdownSentinel, None) if shutdown signal received
"""
- # Collect results from all engines with non-blocking progress bar
results = []
all_queries = []
all_ground_truths = []
all_datasets = []
+
for i in tqdm(
- range(args.vllm_num_engines),
- total=args.vllm_num_engines,
- desc=f"Accumulating results from {args.vllm_num_engines} engines",
+ range(num_prompts),
+ total=num_prompts,
+ desc=f"Accumulating {num_prompts} results (each with {generation_config.n} completions)",
bar_format="{l_bar}{bar}{r_bar}\n",
disable=not args.verbose,
):
@@ -1379,37 +1511,17 @@ def accumulate_inference_batches(
if isinstance(result, ShutdownSentinel):
return result, None
- dataset_indices = result.dataset_index
-
- if dataset_indices is None:
- raise RuntimeError(f"Dataset indices is None for result {i}")
-
- # When generation_config.n > 1, vLLM generates multiple responses per prompt
- # but dataset_indices only contains the unique indices (not replicated)
- # So we expect: len(responses) == len(dataset_indices) * generation_config.n
- expected_responses = len(dataset_indices) * generation_config.n
- assert len(result.responses) == expected_responses, (
- f"Mismatch: number of responses ({len(result.responses)}) "
- f"doesn't match expected ({expected_responses}) for result {i}"
- f". {generation_config.n=}"
- f", {len(dataset_indices)=}"
- )
- # Get corresponding queries, ground_truths, datasets for each individual prompt
- batch_queries = []
- batch_ground_truths = []
- batch_datasets = []
-
- for dataset_idx in dataset_indices:
- query, ground_truth, dataset = pending_queries_map.pop(dataset_idx)
- batch_queries.append(query)
- batch_ground_truths.append(ground_truth)
- batch_datasets.append(dataset)
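+        # Each result now corresponds to a single prompt and carries generation_config.n sampled responses.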
+ dataset_index = result.dataset_index
+ assert len(result.responses) == generation_config.n, (
+ f"Result {i} has {len(result.responses)} responses, expected {generation_config.n}"
+ )
+ query, ground_truth, dataset = pending_queries_map.pop(dataset_index)
+ all_queries.append(query)
+ all_ground_truths.append(ground_truth)
+ all_datasets.append(dataset)
results.append(result)
- all_queries.extend(batch_queries)
- all_ground_truths.extend(batch_ground_truths)
- all_datasets.extend(batch_datasets)
# Combine all results into a single GenerationResult
combined_responses = []
@@ -1492,7 +1604,13 @@ def data_preparation_thread(
# Streaming accumulation: collect results as they arrive
with Timer("🚀 [Data Preparation Thread] Getting response ids") as timer:
result, batch = accumulate_inference_batches(
- inference_results_Q, pending_queries_map, args, training_step, generation_config, actor_manager
+ inference_results_Q,
+ pending_queries_map,
+ args,
+ training_step,
+ generation_config,
+ args.num_unique_prompts_rollout,
+            actor_manager,
)
if isinstance(result, ShutdownSentinel):
logger.info("[Data Preparation Thread] Received shutdown sentinel, exiting")
@@ -1626,10 +1744,6 @@ def data_preparation_thread(
finish_reasons += [finish_reasons[i] for i in sampled_indices]
- print(
- f"📊 Duplicated {need_to_fill_prompt} prompts from {len(sampled_indices)} total responses"
- )
-
with Timer("📦 [Data Preparation Thread] Packing sequences"):
packed_sequences = pack_sequences(
queries=batch.queries,
@@ -1656,7 +1770,8 @@ def data_preparation_thread(
shortfall = args.world_size - len(packed_sequences.query_responses)
if shortfall > 0:
logger.warning(
- f"Padding {shortfall} sequences for world size. In future, you should adjust your compute this."
+ f"[Data Preparation Thread] Step {training_step}: Padding {shortfall} sequences for world size. "
+ f"In future, you should adjust your compute this."
)
# construct "dummy" sequences for padding out the world size
dummy_qr = torch.tensor([tokenizer.pad_token_id, tokenizer.eos_token_id], dtype=torch.long)
@@ -1855,6 +1970,7 @@ def setup_experiment_tracking(args: Args, tc: TokenizerConfig, model_config: Mod
name=args.run_name,
save_code=True,
tags=[args.exp_name] + get_wandb_tags(),
+ settings={"quiet": False, "show_errors": True, "show_warnings": True, "silent": False},
)
wandb_url = wandb.run.get_url()
logger.info(f"Initial Beaker description update with wandb_url: {wandb_url}")
@@ -1923,7 +2039,6 @@ def create_model_and_optimizer(
ray_get_with_progress([pg.ready()], desc="Waiting for placement group")
inits = []
policy_group = ModelGroup(pg, PolicyTrainerRayProcess, args.num_learners_per_node, args.single_gpu_mode)
- wandb_url = wandb.run.get_url() if args.with_tracking else None
inits.extend(
model.from_pretrained.remote(args, model_config, beaker_config, wandb_url, tokenizer)
for model in policy_group.models
@@ -1976,6 +2091,7 @@ def create_model_and_optimizer(
max_len,
args.vllm_gpu_memory_utilization,
args.single_gpu_mode,
+ args.inference_batch_size,
pg=pg if args.single_gpu_mode else None,
tools=tool_objects,
max_tool_calls=args.max_tool_calls,
@@ -1983,6 +2099,7 @@ def create_model_and_optimizer(
results_queue=inference_results_Q,
eval_results_queue=evaluation_inference_results_Q,
actor_manager=actor_manager,
+ inflight_updates=args.inflight_updates,
)
resume_training_step = ray_get_with_progress(inits, desc="Initializing models")[0] + 1
@@ -2037,28 +2154,14 @@ def split_and_insert_batch(
is_eval: bool = False,
) -> None:
"""Split a batch into multiple inference batches and insert individual prompts into queues and mapping."""
- for batch_idx in range(vllm_num_engines):
- start_idx = batch_idx * args.inference_batch_size
- end_idx = min(start_idx + args.inference_batch_size, len(batch.queries))
-
- # Stop if we've distributed all queries
- if start_idx >= len(batch.queries):
- break
-
- sub_batch = batch[start_idx:end_idx]
-
- # Store prompts in the map using thread-safe insert_many
- pending_queries_map.insert_many(
- sub_batch.indices, sub_batch.queries, sub_batch.ground_truths, sub_batch.datasets
- )
-
- # Use PromptRequest for Ray queue with batch-specific dataset_index list
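+    # Register every prompt in the pending map, then enqueue one PromptRequest per prompt; results are matched back by dataset_index.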
+ pending_queries_map.insert_many(batch.indices, batch.queries, batch.ground_truths, batch.datasets)
+ for i, prompt in enumerate(batch.queries):
param_prompt_Q.put(
PromptRequest(
- prompts=sub_batch.queries,
+ prompt=prompt,
generation_config=generation_config,
training_step=training_step,
- dataset_index=sub_batch.indices,
+ dataset_index=batch.indices[i],
is_eval=is_eval,
)
)
@@ -2101,7 +2204,7 @@ def load_data_from_packing_thread(
logger.warning("[Main Thread] Stop event detected while waiting for packed sequences")
return None, {}, num_total_tokens
try:
- packed_data = packed_sequences_Q.get(timeout=30.0)
+ packed_data = packed_sequences_Q.get(timeout=60.0)
break
except Empty:
# check that everything is still alive
@@ -2182,9 +2285,11 @@ def generate_thread(args, vllm_engines, resume_training_step, stop_event, genera
enable=args.verbose,
)
num_processed = sum(int(result) for result in processed_results)
+ logger.info(f"[Generate Thread] vLLM engines returned, processed {num_processed} total requests")
# Suppress timing output if nothing was processed
if num_processed == 0:
timer.noop = True
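+            # Back off briefly when nothing was processed to avoid a busy polling loop.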
+ time.sleep(10)
if num_processed > 0:
try:
generate_metrics_Q.put_nowait({"time/generation": timer.duration})
@@ -2282,7 +2387,21 @@ def one_training_step(
if isinstance(value, np.ndarray) or isinstance(value, list):
if len(value) > 0:
metrics[key] = wandb.Histogram(value)
- wandb.log(metrics, step=episode)
+ logger.info(f"About to log to wandb... step={episode}")
+ result = wandb.log(metrics, step=episode)
+ logger.info(f"WandB log() returned: {result}, logged {metrics=} to wandb.")
+
+ # Debug: Check if wandb.run is active and what's in the summary
+ if wandb.run:
+ logger.info(f"WandB run is active. Run ID: {wandb.run.id}")
+ logger.info(f"WandB run name: {wandb.run.name}")
+ logger.info(f"WandB run URL: {wandb.run.get_url()}")
+ # Check if data was actually logged
+ summary_keys = list(wandb.run.summary.keys()) if wandb.run.summary else []
+ logger.info(f"WandB summary has {len(summary_keys)} keys: {summary_keys[:10]}...") # Show first 10 keys
+ else:
+ logger.error("WandB run is not active!")
+ logger.info("Done step.")
def maybe_evaluate(
@@ -2296,6 +2415,7 @@ def maybe_evaluate(
eval_pending_queries_map: PendingQueriesMap,
eval_generation_config,
generate_metrics_Q: Queue,
+ num_eval_prompts: int,
actor_manager=None,
):
"""Optionally evaluate the model."""
@@ -2311,6 +2431,7 @@ def maybe_evaluate(
args,
training_step,
eval_generation_config,
+            num_eval_prompts,
actor_manager,
timeout=timeout,
)
@@ -2548,6 +2669,7 @@ def run_training(
tokenizer,
train_dataset,
eval_batch,
+ num_eval_prompts,
policy_group,
vllm_engines,
generation_configs,
@@ -2641,7 +2763,7 @@ def health_check_fn():
if (
training_step == resume_training_step
- or training_step % 10 == 0
+ or training_step % args.update_every == 0
or training_step == args.num_training_steps
):
logger.info(f"Progress update for Beaker description: step {training_step}/{args.num_training_steps}")
@@ -2754,6 +2876,7 @@ def health_check_fn():
eval_pending_queries_map,
generation_configs["eval"],
generate_metrics_Q,
+ num_eval_prompts,
actor_manager,
)
@@ -2778,10 +2901,11 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa
ray.init(dashboard_host="0.0.0.0")
# Create Ray queues.
- queue_size = (args.async_steps + 1) * args.vllm_num_engines
+ queue_size = (args.async_steps + 1) * args.num_unique_prompts_rollout
inference_results_Q = ray_queue.Queue(maxsize=queue_size)
param_prompt_Q = ray_queue.Queue(maxsize=queue_size)
- evaluation_inference_results_Q = ray_queue.Queue(maxsize=args.vllm_num_engines)
+ # Size eval queue based on actual eval samples (num_eval_samples) with 2x overhead for buffering.
+ evaluation_inference_results_Q = ray_queue.Queue(maxsize=num_eval_samples * 2)
policy_group, vllm_engines, tool_objects, resume_training_step, episode, actor_manager = (
create_model_and_optimizer(
@@ -2826,9 +2950,11 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa
if eval_dataset is None:
eval_batch = None
+ num_eval_prompts = 0
else:
eval_dataset_indices = list(range(min(num_eval_samples, len(eval_dataset))))
eval_batch = next_batch(eval_dataset_indices, eval_dataset)
+ num_eval_prompts = len(eval_dataset_indices)
reward_fn = make_reward_fn(args)
stop_event = threading.Event()
@@ -2840,6 +2966,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa
tokenizer,
train_dataset,
eval_batch,
+ num_eval_prompts,
policy_group,
vllm_engines,
generation_configs,
@@ -2862,6 +2989,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa
actor_manager,
checkpoint_state,
)
+ logger.info("Done training!")
finally:
cleanup_training_resources(
stop_event, executor, [inference_results_Q, param_prompt_Q, evaluation_inference_results_Q], actor_manager
diff --git a/open_instruct/queue_types.py b/open_instruct/queue_types.py
index 87f247ad6..78bb5d760 100644
--- a/open_instruct/queue_types.py
+++ b/open_instruct/queue_types.py
@@ -31,7 +31,7 @@ class GenerationResult:
finish_reasons: List[str]
masks: List[List[int]]
request_info: RequestInfo
- dataset_index: Optional[List[int]] = None
+ dataset_index: Optional[int] = None
training_step: Optional[int] = None
token_statistics: Optional[TokenStatistics] = None
start_time: Optional[float] = None
@@ -46,9 +46,9 @@ class PromptRequest:
`_QueueActor`.
"""
- prompts: List[List[int]]
+ prompt: List[int]
generation_config: Any
training_step: Optional[int] = None
- dataset_index: Optional[List[int]] = None
+ dataset_index: Optional[int] = None
is_eval: bool = False
start_time: Optional[float] = None
diff --git a/open_instruct/test_grpo_fast.py b/open_instruct/test_grpo_fast.py
index be9cbd417..f2edbc323 100644
--- a/open_instruct/test_grpo_fast.py
+++ b/open_instruct/test_grpo_fast.py
@@ -1,5 +1,7 @@
import gc
import os
+import queue
+import time
import unittest
from unittest.mock import Mock
@@ -11,9 +13,8 @@
from transformers import AutoTokenizer
from vllm import SamplingParams
-from open_instruct import grpo_fast, model_utils, utils
+from open_instruct import grpo_fast, model_utils, tool_utils, utils, vllm_utils3
from open_instruct.queue_types import GenerationResult, PromptRequest, RequestInfo
-from open_instruct.vllm_utils3 import create_vllm_engines
class TestGrpoFastBase(unittest.TestCase):
@@ -67,19 +68,10 @@ def setUp(self):
# Initialize Ray for this test
ray.init(include_dashboard=False)
- def _cleanup_ray_queues(self):
- """Clean up all Ray queues created during the test."""
- for queue in self._ray_queues:
- try:
- queue.shutdown()
- except Exception as e:
- print(f"Warning: Failed to shutdown Ray queue: {e}")
- self._ray_queues.clear()
-
def tearDown(self):
"""Check for leaks and shutdown Ray."""
# Clean up Ray queues BEFORE shutting down Ray
- self._cleanup_ray_queues()
+        for rq in self._ray_queues:
+            rq.shutdown()
+        self._ray_queues.clear()
# Shutdown Ray
if ray.is_initialized():
@@ -126,17 +118,18 @@ def create_test_data(self, num_prompts, prefix="", start_idx=0):
datasets = [f"{prefix}dataset_{i}" for i in indices]
return queries, ground_truths, datasets, indices
- def create_mock_args(self, num_engines=4, num_samples=1):
+ def create_mock_args(self, num_engines=4, num_samples=1, num_prompts=16):
"""Create mock args object."""
mock_args = Mock()
mock_args.vllm_num_engines = num_engines
mock_args.num_samples_per_prompt_rollout = num_samples
+ mock_args.num_unique_prompts_rollout = num_prompts
+ mock_args.verbose = False
return mock_args
- def create_mock_result(self, dataset_indices, training_step, num_samples_per_prompt=1):
+ def create_mock_result(self, dataset_index, training_step, num_samples_per_prompt=1):
"""Create a mock GenerationResult."""
- batch_size = len(dataset_indices)
- total_responses = batch_size * num_samples_per_prompt
+ total_responses = num_samples_per_prompt
return GenerationResult(
responses=[[1, 2, 3] for _ in range(total_responses)],
@@ -150,13 +143,15 @@ def create_mock_result(self, dataset_indices, training_step, num_samples_per_pro
tool_runtimes=[0.0] * total_responses,
tool_calleds=[False] * total_responses,
),
- dataset_index=dataset_indices,
+ dataset_index=dataset_index,
)
def setup_and_split_batch(self, queries, ground_truths, datasets, indices, num_engines, training_step=1):
"""Setup queues and split batch - common pattern."""
- param_prompt_Q = ray_queue.Queue(maxsize=num_engines * 2)
- inference_results_Q = ray_queue.Queue(maxsize=num_engines * 2)
+ # Queue size should accommodate batches from all engines potentially multiple times
+ queue_size = num_engines * len(queries)
+ param_prompt_Q = ray_queue.Queue(maxsize=queue_size)
+ inference_results_Q = ray_queue.Queue(maxsize=queue_size)
pending_queries_map = grpo_fast.PendingQueriesMap()
# Track queues for cleanup
@@ -205,7 +200,7 @@ def test_vllm_queue_system_single_prompt(self):
self._ray_queues.extend([param_prompt_Q, inference_results_Q])
# Create vLLM engines with queues
- vllm_engines = create_vllm_engines(
+ vllm_engines = vllm_utils3.create_vllm_engines(
num_engines=1,
tensor_parallel_size=1,
enforce_eager=True,
@@ -268,43 +263,49 @@ def test_batch_splitting_and_engine_configurations(self, vllm_num_engines: int,
queries_next, ground_truths_next, datasets_next, dataset_indices, vllm_num_engines
)
- # Verify that we have individual prompts in the map (not batches)
+ # Verify that we have all prompts in the map
self.assertEqual(len(pending_queries_map), num_unique_prompts_rollout)
- # Verify that we have the expected number of items in the queue
- self.assertEqual(param_prompt_Q.qsize(), vllm_num_engines)
+ # Verify that we have individual prompts in the queue (changed from batches)
+ self.assertEqual(param_prompt_Q.qsize(), num_unique_prompts_rollout)
- # Simulate vLLM processing
- batch_idx = 0
+ # Simulate vLLM processing - each prompt gets processed individually
+ prompt_count = 0
while not param_prompt_Q.empty():
request = param_prompt_Q.get()
self.assertIsInstance(request, PromptRequest)
self.assertEqual(request.training_step, 1)
- self.assertIsInstance(request.dataset_index, list)
+ self.assertIsInstance(request.dataset_index, int) # Now individual prompts have single index
- mock_result = self.create_mock_result(request.dataset_index, request.training_step)
+ # Create result for this individual prompt with n=4 samples
+ mock_result = self.create_mock_result(
+ request.dataset_index, request.training_step, num_samples_per_prompt=4
+ )
inference_results_Q.put(mock_result)
- batch_idx += 1
+ prompt_count += 1
+
+ # Verify we processed the right number of individual prompts
+ self.assertEqual(prompt_count, num_unique_prompts_rollout)
- # Simulate streaming accumulation (simplified version for testing)
+ # Simulate accumulation
combined_responses = []
combined_queries = []
combined_ground_truths = []
combined_datasets = []
- for _ in range(vllm_num_engines):
+ # Process all results (we have num_unique_prompts_rollout individual results)
+ for _ in range(num_unique_prompts_rollout):
result = inference_results_Q.get()
- dataset_indices = result.dataset_index
+ dataset_index = result.dataset_index
- # Get queries from pending_queries_map
+ # Get query for this index
batch_queries = []
batch_ground_truths = []
batch_datasets = []
- for idx in dataset_indices:
- q, gt, d = pending_queries_map.pop(idx)
- batch_queries.append(q)
- batch_ground_truths.append(gt)
- batch_datasets.append(d)
+ q, gt, d = pending_queries_map.pop(dataset_index)
+ batch_queries.append(q)
+ batch_ground_truths.append(gt)
+ batch_datasets.append(d)
combined_responses.extend(result.responses)
combined_queries.extend(batch_queries)
@@ -333,9 +334,10 @@ def test_batch_splitting_and_engine_configurations(self, vllm_num_engines: int,
# Verify that the combined result has the correct structure
self.assertIsInstance(combined_result, GenerationResult)
- self.assertEqual(len(combined_result.responses), len(queries_next))
- self.assertEqual(len(combined_result.finish_reasons), len(queries_next))
- self.assertEqual(len(combined_result.masks), len(queries_next))
+ # With n=4 samples per prompt, we expect 4x the number of responses
+ self.assertEqual(len(combined_result.responses), len(queries_next) * 4)
+ self.assertEqual(len(combined_result.finish_reasons), len(queries_next) * 4)
+ self.assertEqual(len(combined_result.masks), len(queries_next) * 4)
# Verify that the pending_queries_map is empty after accumulation
self.assertEqual(len(pending_queries_map), 0)
@@ -358,28 +360,28 @@ def test_dataset_index_preservation_through_pipeline(self):
queries_next, ground_truths_next, datasets_next, dataset_indices, vllm_num_engines
)
- # Simulate vLLM processing
- batch_idx = 0
+ # Simulate vLLM processing - processes individual prompts
+ prompt_count = 0
while not param_prompt_Q.empty():
request = param_prompt_Q.get()
mock_result = self.create_mock_result(request.dataset_index, request.training_step)
inference_results_Q.put(mock_result)
- batch_idx += 1
+ prompt_count += 1
- # Simulate streaming accumulation
+ # Simulate accumulation
combined_queries = []
combined_ground_truths = []
combined_datasets = []
- for _ in range(vllm_num_engines):
+ # Process all individual results
+ for _ in range(prompt_count):
result = inference_results_Q.get()
- dataset_indices = result.dataset_index
+ dataset_index = result.dataset_index
- for idx in dataset_indices:
- q, gt, d = pending_queries_map.pop(idx)
- combined_queries.append(q)
- combined_ground_truths.append(gt)
- combined_datasets.append(d)
+ q, gt, d = pending_queries_map.pop(dataset_index)
+ combined_queries.append(q)
+ combined_ground_truths.append(gt)
+ combined_datasets.append(d)
# Verify results
self.assertEqual(combined_queries, queries_next)
@@ -402,32 +404,33 @@ def test_multiple_samples_per_prompt(self, vllm_num_engines: int, num_samples_pe
queries_next, ground_truths_next, datasets_next, dataset_indices, vllm_num_engines
)
- # Simulate vLLM processing with multiple samples
- batch_idx = 0
+ # Simulate vLLM processing with multiple samples - processes individual prompts
+ prompt_count = 0
while not param_prompt_Q.empty():
request = param_prompt_Q.get()
mock_result = self.create_mock_result(request.dataset_index, request.training_step, num_samples_per_prompt)
inference_results_Q.put(mock_result)
- batch_idx += 1
+ prompt_count += 1
- # Simulate streaming accumulation
+ # Simulate accumulation
combined_responses = []
combined_queries = []
combined_ground_truths = []
combined_datasets = []
- for _ in range(vllm_num_engines):
+ # Process all individual results
+ for _ in range(prompt_count):
result = inference_results_Q.get()
- dataset_indices = result.dataset_index
+ dataset_index = result.dataset_index
+ # Get query for this index
batch_queries = []
batch_ground_truths = []
batch_datasets = []
- for idx in dataset_indices:
- q, gt, d = pending_queries_map.pop(idx)
- batch_queries.append(q)
- batch_ground_truths.append(gt)
- batch_datasets.append(d)
+ q, gt, d = pending_queries_map.pop(dataset_index)
+ batch_queries.append(q)
+ batch_ground_truths.append(gt)
+ batch_datasets.append(d)
combined_responses.extend(result.responses)
combined_queries.extend(batch_queries)
@@ -521,12 +524,13 @@ def test_out_of_order_processing(self):
requests.append(param_prompt_Q.get())
# Put results back in REVERSE order to simulate out-of-order processing
+ # Create one result per prompt with n responses
for request in reversed(requests):
mock_result = self.create_mock_result(request.dataset_index, request.training_step, num_samples_per_prompt)
inference_results_Q.put(mock_result)
# Accumulate results
- mock_args = self.create_mock_args(num_engines, num_samples_per_prompt)
+ mock_args = self.create_mock_args(num_engines, num_samples_per_prompt, num_prompts)
# Create a mock generation config with n
mock_generation_config = Mock()
mock_generation_config.n = num_samples_per_prompt
@@ -537,6 +541,7 @@ def test_out_of_order_processing(self):
mock_args,
training_step=1,
generation_config=mock_generation_config,
+ num_prompts=num_prompts,
)
# Verify results work correctly even with out-of-order processing
@@ -592,8 +597,9 @@ def test_accumulate_waits_for_all_engines(self):
num_engines = 4
num_prompts = 16
- # Setup with results from only 3 engines
- inference_results_Q = ray_queue.Queue(maxsize=num_engines * 2)
+        # Setup with results for only a subset of the prompts (some missing)
+        # Queue size should accommodate one result per prompt
+ inference_results_Q = ray_queue.Queue(maxsize=num_engines * num_prompts)
# Track queue for cleanup
self._ray_queues.append(inference_results_Q)
@@ -604,15 +610,16 @@ def test_accumulate_waits_for_all_engines(self):
for i in range(num_prompts):
pending_queries_map.insert(i, f"q_{i}", f"t_{i}", f"d_{i}")
- # Add results from only 3 engines (missing one)
- for engine_id in range(3):
- indices = list(range(engine_id * 4, (engine_id + 1) * 4))
- mock_result = self.create_mock_result(indices, 1)
+        # Add individual results (one per prompt), but leave some missing:
+        # accumulate_inference_batches now expects one result per prompt (each carrying n responses).
+        # Add results for only 12 prompts (missing 4).
+ for i in range(12): # Only 12 prompts, missing 4
+ mock_result = self.create_mock_result(i, 1)
inference_results_Q.put(mock_result)
- mock_args = self.create_mock_args(num_engines)
+ mock_args = self.create_mock_args(num_engines, num_prompts=num_prompts)
- # Test that accumulate blocks when missing an engine
+    # Test that accumulate blocks while some per-prompt results are still missing
import threading
completed = threading.Event()
@@ -629,6 +636,7 @@ def run_accumulate():
mock_args,
training_step=1,
generation_config=mock_generation_config,
+ num_prompts=num_prompts,
)
completed.set()
except Exception:
@@ -637,14 +645,14 @@ def run_accumulate():
thread = threading.Thread(target=run_accumulate, daemon=True)
thread.start()
- # Should timeout waiting for 4th engine
+ # Should timeout waiting for missing results
self.assertFalse(completed.wait(timeout=1.0))
self.assertTrue(thread.is_alive())
- # Queue should be empty after consuming 3 results
+ # Queue should be empty after consuming 12 results
self.assertEqual(inference_results_Q.qsize(), 0)
- # Some entries should be removed
- self.assertLess(len(pending_queries_map), num_prompts)
+ # 12 entries should be removed from the map (4 still pending)
+ self.assertEqual(len(pending_queries_map), 4)
class TestStreamingAccumulation(TestGrpoFastBase):
@@ -695,19 +703,19 @@ def test_more_engines_than_queries(self):
while not param_prompt_Q.empty():
request = param_prompt_Q.get()
self.assertIsInstance(request, PromptRequest)
- self.assertEqual(len(request.prompts), 1, "Each batch should have exactly 1 prompt")
- batch_sizes.append(len(request.prompts))
+ self.assertIsNotNone(request.prompt, "Each request should have a prompt")
+ batch_sizes.append(1) # Each PromptRequest contains exactly 1 prompt
# All queries should be in the pending map
self.assertEqual(len(pending_queries_map), num_queries)
def test_uneven_distribution_no_empty_batches(self):
- """Test that uneven query distribution doesn't create empty batches."""
+ """Test that split_and_insert_batch creates one PromptRequest per query."""
num_engines = 3
- num_queries = 7 # 7/3 = ceil(2.33) = 3, so distribution should be [3, 3, 1]
+ num_queries = 7
queries, ground_truths, datasets, indices = self.create_test_data(num_queries)
- param_prompt_Q = ray_queue.Queue(maxsize=num_engines * 2)
+ param_prompt_Q = ray_queue.Queue(maxsize=num_queries * 2)
pending_queries_map = grpo_fast.PendingQueriesMap()
# Track queue for cleanup
@@ -735,24 +743,15 @@ def test_uneven_distribution_no_empty_batches(self):
args=mock_args,
)
- # Verify all batches have content and check distribution
- batch_sizes = []
+ # Verify we get one PromptRequest per query
+ request_count = 0
while not param_prompt_Q.empty():
request = param_prompt_Q.get()
- self.assertGreater(len(request.prompts), 0, "Found empty batch in queue!")
- batch_sizes.append(len(request.prompts))
-
- # Check the expected distribution
- self.assertEqual(sum(batch_sizes), num_queries, "Total queries should match")
- self.assertEqual(len(batch_sizes), num_engines, "Should have one batch per engine")
+ self.assertIsNotNone(request.prompt, "Each request should have a prompt")
+ request_count += 1
- # The distribution should be [3, 3, 1] for 7 queries across 3 engines with ceiling division
- expected_distribution = [3, 3, 1]
- self.assertEqual(
- sorted(batch_sizes, reverse=True),
- expected_distribution,
- f"Expected distribution {expected_distribution}, got {sorted(batch_sizes, reverse=True)}",
- )
+ # Should have exactly num_queries PromptRequests
+ self.assertEqual(request_count, num_queries, f"Should have {num_queries} PromptRequests")
def test_streaming_accumulation_basic(self):
"""Test basic streaming accumulation with in-order results."""
@@ -763,7 +762,8 @@ def test_streaming_accumulation_basic(self):
queries, ground_truths, datasets, indices = self.create_test_data(num_prompts)
# Create queues and maps
- inference_results_Q = ray_queue.Queue(maxsize=num_engines * 2)
+ # Queue size should accommodate results from engines
+ inference_results_Q = ray_queue.Queue(maxsize=num_engines * num_prompts)
pending_queries_map = grpo_fast.PendingQueriesMap()
# Track queue for cleanup
@@ -773,45 +773,40 @@ def test_streaming_accumulation_basic(self):
for i in range(num_prompts):
pending_queries_map.insert(i, queries[i], ground_truths[i], datasets[i])
- # Create mock results with batch indices
- batch_size = num_prompts // num_engines
- for batch_idx in range(num_engines):
- start = batch_idx * batch_size
- end = start + batch_size
- mock_result = self.create_mock_result(list(range(start, end)), training_step=1)
+ # Create mock results - one per prompt
+ for i in range(num_prompts):
+ mock_result = self.create_mock_result(i, training_step=1)
inference_results_Q.put(mock_result)
# Simulate streaming accumulation logic
results_list = []
queries_list = []
- expected_batches = num_engines
+ expected_results = num_prompts # Now expecting one result per prompt
- while len(results_list) < expected_batches:
+ while len(results_list) < expected_results:
result = inference_results_Q.get()
- batch_idx = len(results_list)
results_list.append(result)
- # Get queries for this batch
- dataset_indices = result.dataset_index
+ # Get query for this prompt
+ dataset_index = result.dataset_index
batch_queries = []
batch_ground_truths = []
batch_datasets = []
- for idx in dataset_indices:
- q, gt, d = pending_queries_map.pop(idx)
- batch_queries.append(q)
- batch_ground_truths.append(gt)
- batch_datasets.append(d)
+ q, gt, d = pending_queries_map.pop(dataset_index)
+ batch_queries.append(q)
+ batch_ground_truths.append(gt)
+ batch_datasets.append(d)
queries_list.append((batch_queries, batch_ground_truths, batch_datasets))
- # Verify all batches processed
- self.assertEqual(len(results_list), expected_batches)
+ # Verify all results processed
+ self.assertEqual(len(results_list), expected_results)
self.assertEqual(len(pending_queries_map), 0)
# Combine in order
combined_queries = []
- for i in range(num_engines):
+ for i in range(num_prompts):
q, _, _ = queries_list[i]
combined_queries.extend(q)
@@ -828,7 +823,8 @@ def test_streaming_with_multiple_samples(self):
queries, ground_truths, datasets, indices = self.create_test_data(num_prompts)
# Create queues and maps
- inference_results_Q = ray_queue.Queue(maxsize=num_engines * 2)
+ # Queue size should accommodate results from engines
+ inference_results_Q = ray_queue.Queue(maxsize=num_engines * num_prompts)
pending_queries_map = grpo_fast.PendingQueriesMap()
# Track queue for cleanup
@@ -839,15 +835,9 @@ def test_streaming_with_multiple_samples(self):
for _ in range(num_samples):
pending_queries_map.insert(i, queries[i], ground_truths[i], datasets[i])
- # Create results with multiple samples per prompt
- batch_size = num_prompts // num_engines
- for batch_idx in range(num_engines):
- start = batch_idx * batch_size
- end = start + batch_size
- dataset_indices = list(range(start, end))
-
- # Create result with num_samples responses per prompt
- mock_result = self.create_mock_result(dataset_indices, training_step=1, num_samples_per_prompt=num_samples)
+ # Create results - one per prompt with multiple samples
+ for i in range(num_prompts):
+ mock_result = self.create_mock_result(i, training_step=1, num_samples_per_prompt=num_samples)
inference_results_Q.put(mock_result)
# Process results
@@ -855,22 +845,125 @@ def test_streaming_with_multiple_samples(self):
while not inference_results_Q.empty():
result = inference_results_Q.get()
- # Verify number of responses matches num_samples * num_prompts_in_batch
- batch_prompts = len(result.dataset_index)
+ batch_prompts = 1 # Each result is for a single prompt now
expected_responses = batch_prompts * num_samples
self.assertEqual(len(result.responses), expected_responses)
total_responses += len(result.responses)
# Clean up pending_queries_map
- for idx in result.dataset_index:
- for _ in range(num_samples):
- if idx in pending_queries_map:
- pending_queries_map.pop(idx)
+ idx = result.dataset_index
+ for _ in range(num_samples):
+ if idx in pending_queries_map:
+ pending_queries_map.pop(idx)
# Verify total responses
self.assertEqual(total_responses, num_prompts * num_samples)
self.assertEqual(len(pending_queries_map), 0)
+ def test_vllm_tool_processing_completes(self):
+ """Test that tool processing completes without hanging using actual tool settings."""
+ # Check if CUDA is available
+ if not torch.cuda.is_available():
+ self.skipTest("CUDA is not available, skipping test")
+
+        # Create actual tools matching the script.
+        # NOTE: the <code>/<search> tag strings below are assumed placeholders; use the tools' real start/end markers.
+        tools = {
+            "</code>": tool_utils.CodeExecutionTool(
+                start_str="<code>",
+                end_str="</code>",
+                api_endpoint="https://open-instruct-tool-server-10554368204.us-central1.run.app/execute",
+            ),
+            "</search>": tool_utils.SearchTool(
+                start_str="<search>",
+                end_str="</search>",
+                api_endpoint="http://saturn-cs-aus-232.reviz.ai2.in:44177/search",
+            ),
+        }
+        max_tool_calls = {"</code>": 5, "</search>": 5}
+
+ # Create actual ActorManager via ray.remote
+ actor_manager = ray.remote(vllm_utils3.ActorManager).remote(should_stop=False)
+
+ # Create queues
+ prompt_queue = queue.Queue()
+ results_queue = queue.Queue()
+ eval_results_queue = queue.Queue()
+
+ # Create LLMRayActor
+ model_name = "EleutherAI/pythia-14m" # Small model for testing
+ actor = vllm_utils3.LLMRayActor(
+ model_name_or_path=model_name,
+ actor_id=0,
+ prompt_queue=prompt_queue,
+ results_queue=results_queue,
+ eval_results_queue=eval_results_queue,
+ actor_manager=actor_manager,
+ tools=tools,
+ max_tool_calls=max_tool_calls,
+ inference_batch_size=8,
+ vllm_kwargs={"gpu_memory_utilization": 0.3, "max_model_len": 512, "enable_prefix_caching": True},
+ )
+
+ tokenizer = actor.llm_engine.tokenizer
+
+        # Create test prompts that will trigger tools (tag markers assumed, matching the tool definitions above)
+        test_prompts = [
+            "Write code to print hello: <code>print('hello')</code>",
+            "Search for Python tutorials: <search>Python tutorial beginner</search>",
+            "Calculate 2+2: <code>print(2+2)</code>",
+            "Find vLLM documentation: <search>vLLM documentation</search>",
+            "Write a simple function: <code>def greet(): return 'hi'</code>",
+            "Search machine learning: <search>machine learning basics</search>",
+            "Debug this code: <code>x = 1; print(x)</code>",
+            "Look up PyTorch: <search>PyTorch tutorial</search>",
+        ]
+
+ # Create sampling params once
+ sampling_params = SamplingParams(
+ temperature=1.0, top_p=1.0, n=1, max_tokens=50, stop=["</code>", "</search>", "</answer>"]
+ )
+
+ # Add all requests to queue
+ for i in range(16):
+ prompt_text = test_prompts[i % len(test_prompts)]
+ prompt_ids = tokenizer.encode(prompt_text, add_special_tokens=True)
+
+ request = PromptRequest(
+ prompt=prompt_ids, generation_config=sampling_params, training_step=0, dataset_index=i, is_eval=False
+ )
+ prompt_queue.put(request)
+
+ # Time the processing call so we can assert it does not hang.
+ start_time = time.time()
+
+ # Process requests - this is what we're testing!
+ num_processed = actor.process_from_queue(timeout=30.0)
+
+ elapsed = time.time() - start_time
+
+ # Check results
+ results_received = []
+ while not results_queue.empty():
+ try:
+ result = results_queue.get_nowait()
+ results_received.append(result)
+ except queue.Empty:
+ break
+
+ # Verify we didn't hang
+ self.assertLess(elapsed, 60, "Should complete in less than 60 seconds")
+
+ # Verify we got results
+ self.assertGreater(num_processed, 0, "Should have processed some requests")
+ self.assertGreater(len(results_received), 0, "Should have received some results")
+
+ # Verify tool processing worked
+ for result in results_received:
+ self.assertIsNotNone(result.responses)
+ self.assertIsNotNone(result.request_info)
+ self.assertIsNotNone(result.request_info.tool_calleds)
+ # We deliberately don't assert that any tool was actually called: the tiny test model may never emit tool tags.
+
class TestShufflingIterator(unittest.TestCase):
"""Test ShufflingIterator state preservation functionality."""
diff --git a/open_instruct/vllm_utils3.py b/open_instruct/vllm_utils3.py
index bb427f391..e8d9d878e 100644
--- a/open_instruct/vllm_utils3.py
+++ b/open_instruct/vllm_utils3.py
@@ -43,7 +43,7 @@
from vllm.v1.core import kv_cache_utils
from open_instruct import logger_utils
-from open_instruct.queue_types import GenerationResult, RequestInfo, TokenStatistics
+from open_instruct.queue_types import GenerationResult, PromptRequest, RequestInfo, TokenStatistics
from open_instruct.tool_utils.tool_vllm import MaxCallsExceededTool, Tool
from open_instruct.utils import ray_get_with_progress
@@ -201,37 +201,31 @@ def _process_outputs_with_tools(
def _finalize_outputs(outputs, tracking, dataset_index, tools, token_statistics=None, start_time=None):
- """Prepare final outputs based on whether tools were used."""
+ """Prepare final outputs with unified approach for all requests."""
+ outputs.sort(key=lambda x: (x.request_id.split("-")[0], int(x.request_id.split("-")[1])))
+
if not tools:
- outputs.sort(key=lambda x: int(x.request_id.split("_")[-1]))
return _process_outputs(
outputs, dataset_index=dataset_index, token_statistics=token_statistics, start_time=start_time
)
# Tool mode: add metadata and merge completions
- for req_id in tracking["masks"]:
+ for output in outputs:
+ req_id = output.request_id
+ if req_id not in tracking["masks"]:
+ # If the request ID is not in masks, it means it didn't go through tool processing.
+ continue
assert req_id in tracking["concat_outputs"], f"req_id {req_id} not in concat_outputs!"
- output = tracking["concat_outputs"][req_id].outputs[0]
- setattr(output, "mask", tracking["masks"][req_id])
- setattr(output, "num_calls", tracking["num_calls"][req_id])
- setattr(output, "timeout", tracking["timeout"][req_id])
- setattr(output, "tool_error", tracking["tool_error"][req_id])
- setattr(output, "tool_output", tracking["tool_output"][req_id])
- setattr(output, "tool_runtime", tracking["tool_runtime"][req_id])
- setattr(output, "tool_called", tracking["tool_called"][req_id])
-
- # Merge n completions into the same outputs
- merged_outputs = {}
- for req_id in tracking["concat_outputs"]:
- real_req_id, _ = req_id.split("-")
- if real_req_id not in merged_outputs:
- merged_outputs[real_req_id] = tracking["concat_outputs"][req_id]
- else:
- merged_outputs[real_req_id].outputs.append(tracking["concat_outputs"][req_id].outputs[0])
-
- final_outputs = sorted(
- merged_outputs.values(), key=lambda x: (int(x.request_id.split("-")[0]), int(x.request_id.split("-")[1]))
- )
+ tool_output = tracking["concat_outputs"][req_id].outputs[0]
+ setattr(tool_output, "mask", tracking["masks"][req_id])
+ setattr(tool_output, "num_calls", tracking["num_calls"][req_id])
+ setattr(tool_output, "timeout", tracking["timeout"][req_id])
+ setattr(tool_output, "tool_error", tracking["tool_error"][req_id])
+ setattr(tool_output, "tool_output", tracking["tool_output"][req_id])
+ setattr(tool_output, "tool_runtime", tracking["tool_runtime"][req_id])
+ setattr(tool_output, "tool_called", tracking["tool_called"][req_id])
+ # Replace the output with the tool-processed one
+ output.outputs[0] = tool_output
return _process_outputs_with_tools(
- final_outputs, dataset_index=dataset_index, token_statistics=token_statistics, start_time=start_time
+ outputs, dataset_index=dataset_index, token_statistics=token_statistics, start_time=start_time
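Aside: a standalone illustration of the sort key above, using made-up request ids that follow the "<prefix>_<step>_<dataset_index>-<sample_j>" convention introduced by add_request later in this patch:

    from types import SimpleNamespace

    # Sort sub-request outputs by (base request id, sample index).
    outputs = [SimpleNamespace(request_id=rid) for rid in ["train_1_7-2", "train_1_7-0", "train_1_10-1", "train_1_7-1"]]
    outputs.sort(key=lambda x: (x.request_id.split("-")[0], int(x.request_id.split("-")[1])))
    print([o.request_id for o in outputs])
    # -> ['train_1_10-1', 'train_1_7-0', 'train_1_7-1', 'train_1_7-2']
    # The base id is compared as a string (lexical across dataset indices), while the
    # sample index within each prompt is ordered numerically.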
@@ -317,6 +311,36 @@ def init_process_group(
return pg
+def add_request(request: PromptRequest, llm_engine: vllm.LLMEngine, tools, request_metadata: dict = None):
+ """Add a request to the LLM engine."""
+ prefix = "eval" if request.is_eval else "train"
+ request_id = f"{prefix}_{request.training_step}_{request.dataset_index}"
+ metadata = {
+ "is_eval": request.is_eval,
+ "dataset_index": request.dataset_index,
+ "training_step": request.training_step,
+ "sampling_params": request.generation_config,
+ }
+
+ tokens_prompt = vllm.TokensPrompt(prompt_token_ids=request.prompt, cache_salt=request_id)
+
+ # Tool tracking requires each sample to be its own engine request, so we always
+ # duplicate requests manually (even without tools) to keep a single code path.
+ # Create sampling params with n=1 so each sub-request is tracked individually.
+ sampling_params = copy.deepcopy(request.generation_config)
+ sampling_params.n = 1
+ metadata["sampling_params"] = sampling_params
+ metadata["generation_config"] = request.generation_config
+ request_metadata[request_id] = metadata
+ for j in range(request.generation_config.n):
+ sub_request_id = f"{request_id}-{j}"
+ if request.generation_config.seed is not None:
+ # We need to seed each sub-request differently to avoid getting the same output.
+ sampling_params.seed = request.generation_config.seed + j
+ llm_engine.add_request(sub_request_id, tokens_prompt, sampling_params)
+ request_metadata[sub_request_id] = metadata
+
+
class LLMRayActor:
"""Ray actor for LLM generation with optional tool support."""
@@ -330,11 +354,14 @@ def __init__(
results_queue=None,
eval_results_queue=None,
actor_manager=None,
+ inference_batch_size: Optional[int] = None,
+ inflight_updates: bool = False,
**kwargs,
):
self.logger = logger_utils.setup_logger(__name__)
self.tools = tools or {}
self.max_tool_calls = max_tool_calls or {}
+ self.inflight_updates = inflight_updates
self.request_metadata = {}
if self.tools:
@@ -367,150 +394,255 @@ def __init__(
self.results_queue = results_queue
self.eval_results_queue = eval_results_queue
self.actor_manager = actor_manager
+ if inference_batch_size is None:
+ raise ValueError("inference_batch_size must be specified.")
+ self.inference_batch_size = inference_batch_size
+ self.request_metadata = {}
+
+ # For caching should_stop status.
+ self._last_should_stop_update = None
+ self._should_stop_value = None
+ self._should_stop_timeout_s = 5
+
+ def _should_stop(self) -> bool:
+ last_update = self._last_should_stop_update
+ if last_update is None or (time.perf_counter() - last_update) > self._should_stop_timeout_s:
+ should_stop_ref = self.actor_manager.should_stop.remote()
+ ready_refs, _ = ray.wait([should_stop_ref], timeout=0.1)
+ if ready_refs:
+ should_stop = ray.get(ready_refs[0])
+ else:
+ ray.cancel(should_stop_ref)
+ should_stop = False
+ self._last_should_stop_update = time.perf_counter()
+ self._should_stop_value = should_stop
+
+ return self._should_stop_value
+
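Aside: the should_stop handling above is a time-based cache around a remote call. A generic standalone sketch of the same pattern, where fetch stands in for the actor_manager.should_stop round-trip and ttl_s for _should_stop_timeout_s:

    import time

    class CachedFlag:
        """Refresh a flag at most once every ttl_s seconds; otherwise serve the cached value."""

        def __init__(self, fetch, ttl_s: float = 5.0):
            self._fetch = fetch
            self._ttl_s = ttl_s
            self._last_update = None
            self._value = False

        def get(self) -> bool:
            now = time.perf_counter()
            if self._last_update is None or (now - self._last_update) > self._ttl_s:
                self._value = self._fetch()
                self._last_update = now
            return self._value

    flag = CachedFlag(fetch=lambda: False, ttl_s=5.0)
    print(flag.get())  # -> False (re-fetched at most once every 5 seconds)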
+ def _insert_result_to_queue(self, result, is_eval: bool):
+ """Insert result into the appropriate queue with error handling."""
+ try:
+ results_queue = self.eval_results_queue if is_eval else self.results_queue
+ results_queue.put(result, timeout=10)
+ except queue.Full:
+ queue_name = "eval" if is_eval else "train"
+ self.logger.warning(f"{queue_name} results queue is full, discarding result.")
+
+ def _maybe_add_new_request(self):
+ """Try to add a new request from the prompt queue if not stopping.
+
+ Returns:
+ bool: True if a request was added, False otherwise.
+ """
+ if not self._should_stop():
+ try:
+ request = self.prompt_queue.get_nowait()
+ self.logger.debug(
+ f"[_maybe_add_new_request] Added new request during processing: "
+ f"is_eval={request.is_eval}, dataset_index={request.dataset_index}, "
+ f"training_step={request.training_step}"
+ )
+ add_request(request, self.llm_engine, self.tools, request_metadata=self.request_metadata)
+ return True
+ except queue.Empty:
+ pass # No new request available, continue processing
+ else:
+ self.logger.debug("[_maybe_add_new_request] Skipping due to should_stop signal")
+ return False
def process_from_queue(self, timeout: float = 60.0):
"""Run generation loop using LLMEngine directly, with optional tool support.
Returns:
- int: Number of requests processed (0 or 1)
+ int: Number of requests processed.
"""
- while True:
- # Non-blocking check for should_stop using ray.wait
- should_stop_ref = self.actor_manager.should_stop.remote()
- ready_refs, _ = ray.wait([should_stop_ref], timeout=0.1)
- if ready_refs and ray.get(ready_refs[0]):
- return 0
+ self.logger.info(f"[process_from_queue] Starting with inference_batch_size={self.inference_batch_size}")
+ num_processed = 0
+ overall_start_time = time.perf_counter()
+
+ tracking = _init_tool_tracking() if self.tools else None
+ tokenizer = self.llm_engine.tokenizer if self.tools else None
+
+ collected_outputs = defaultdict(list)
+ if self._should_stop():
+ self.logger.info("[process_from_queue] Early exit due to should_stop signal")
+ return num_processed
+
+ # Initial batch loading
+ batch_load_start = time.perf_counter()
+ initial_requests = 0
+ while initial_requests < self.inference_batch_size:
try:
request = self.prompt_queue.get(timeout=timeout)
+ self.logger.debug(
+ f"[process_from_queue] Got request from queue: "
+ f"is_eval={request.is_eval}, dataset_index={request.dataset_index}, "
+ f"training_step={request.training_step}"
+ )
+ add_request(request, self.llm_engine, self.tools, request_metadata=self.request_metadata)
+ initial_requests += 1
except queue.Empty:
- return 0
+ # Queue came up empty; if the engine already has requests, start processing them.
+ if self.llm_engine.has_unfinished_requests():
+ self.logger.debug(
+ f"[process_from_queue] Queue empty after {initial_requests} requests, starting processing"
+ )
+ break
+ # Otherwise continue trying to get more requests
- result = self._process_request(request)
+ batch_load_time = time.perf_counter() - batch_load_start
+ self.logger.info(
+ f"[process_from_queue] Initial batch loaded {initial_requests} requests in {batch_load_time:.2f}s"
+ )
- try:
- if request.is_eval:
- self.eval_results_queue.put(result, timeout=10)
- else:
- self.results_queue.put(result, timeout=10)
- return 1 # Successfully processed one request
- except queue.Full:
- self.logger.warning("Results queue is full, discarding result.")
- return 0
-
- def _process_request(self, request):
- """Unified processing for both tool and non-tool generation."""
- prompts = request.prompts
- sampling_params = request.generation_config
- start_time = request.start_time
-
- self.logger.info(f"[LLMRayActor] Processing request with {len(prompts)} prompts, tools={bool(self.tools)}")
+ # Main processing loop
+ loop_iteration = 0
+ # Timing accumulators for every 100 iterations
+ step_engine_time_acc = 0.0
+ output_processing_time_acc = 0.0
+ loop_block_start = time.perf_counter()
- if self.tools:
- # Need n=1 for individual tool tracking
- sampling_params = copy.deepcopy(sampling_params)
- original_n = request.generation_config.n
- sampling_params.n = 1
- tracking = _init_tool_tracking()
- tokenizer = self.llm_engine.tokenizer
- else:
- original_n = 1
- tracking = None
- tokenizer = None
+ while True:
+ loop_iteration += 1
- self._add_initial_requests(prompts, sampling_params, original_n, request.training_step)
+ if self._should_stop() and self.inflight_updates:
+ self.logger.info(
+ f"[process_from_queue] Stopping due to should_stop signal (inflight_updates=True) after {loop_iteration} iterations"
+ )
+ return num_processed
+ step_start = time.perf_counter()
+ outputs = self._step_engine(tracking, tokenizer)
+ step_engine_time = time.perf_counter() - step_start
+ step_engine_time_acc += step_engine_time
+
+ self.logger.debug(
+ f"[process_from_queue] Loop iteration {loop_iteration}: got {len(outputs)} outputs from step_engine"
+ )
+
+ output_start = time.perf_counter()
+ for output in outputs:
+ request_id = output.request_id.split("-")[0]
+ collected_outputs[request_id].append(output)
+ metadata = self.request_metadata[request_id]
+
+ self.logger.debug(
+ f"[process_from_queue] Collected output for {request_id}: "
+ f"{len(collected_outputs[request_id])}/{metadata['generation_config'].n} outputs"
+ )
+
+ if len(collected_outputs[request_id]) != metadata["generation_config"].n:
+ continue
+
+ outputs_to_finalize = collected_outputs[request_id]
+ num_processed += 1
+ self.logger.debug(
+ f"[process_from_queue] Finalizing outputs for {request_id}, total processed: {num_processed}"
+ )
+
+ result = _finalize_outputs(outputs_to_finalize, tracking, metadata["dataset_index"], self.tools)
+
+ self._insert_result_to_queue(result, metadata["is_eval"])
+ del collected_outputs[request_id]
+ self.request_metadata.pop(request_id, None)
+ for i in range(metadata["generation_config"].n):
+ self.request_metadata.pop(f"{request_id}-{i}", None)
+ output_processing_time = time.perf_counter() - output_start
+ output_processing_time_acc += output_processing_time
+
+ unfinished = self.llm_engine.has_unfinished_requests()
+ if self._should_stop() and not self.inflight_updates:
+ pending_tool_futures = tracking["pending_tool_futures"] if self.tools else {}
+ if not unfinished and not pending_tool_futures:
+ total_time = time.perf_counter() - overall_start_time
+ self.logger.info(
+ f"[process_from_queue] Stopping: no unfinished requests or pending tools, "
+ f"processed {num_processed} requests in {loop_iteration} iterations, total_time={total_time:.2f}s"
+ )
+ break
+
+ # Log timing summary every 100 iterations
+ if loop_iteration % 100 == 0:
+ loop_block_time = time.perf_counter() - loop_block_start
+ self.logger.info(
+ f"[process_from_queue] Timing (iters {loop_iteration - 99}-{loop_iteration}): "
+ f"total={loop_block_time:.2f}s, "
+ f"step_engine={step_engine_time_acc:.2f}s ({step_engine_time_acc / loop_block_time * 100:.1f}%), "
+ f"output_processing={output_processing_time_acc:.2f}s ({output_processing_time_acc / loop_block_time * 100:.1f}%), "
+ f"processed={num_processed}"
+ )
+ # Reset accumulators
+ step_engine_time_acc = 0.0
+ output_processing_time_acc = 0.0
+ loop_block_start = time.perf_counter()
+
+ total_time = time.perf_counter() - overall_start_time
+ self.logger.info(
+ f"[process_from_queue] Completed: processed {num_processed} requests in {total_time:.2f}s (avg {total_time / num_processed:.2f}s per request)"
+ if num_processed > 0
+ else "[process_from_queue] Completed: no requests processed"
+ )
+ return num_processed
+
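Aside: the loop above buffers sub-request outputs per base request id and finalizes a prompt only once all n samples have arrived. A simplified, engine-free sketch of that grouping:

    from collections import defaultdict
    from types import SimpleNamespace

    def group_until_complete(step_outputs, expected_n):
        """step_outputs: objects with .request_id like 'train_0_5-2'; expected_n: base id -> n."""
        collected = defaultdict(list)
        completed = []
        for output in step_outputs:
            base_id = output.request_id.split("-")[0]
            collected[base_id].append(output)
            if len(collected[base_id]) == expected_n[base_id]:
                # All samples for this prompt have arrived; hand them off for finalization.
                completed.append((base_id, collected.pop(base_id)))
        return completed, dict(collected)

    outs = [SimpleNamespace(request_id=f"train_0_5-{j}") for j in range(2)]
    done, pending = group_until_complete(outs, {"train_0_5": 2})
    print([base for base, _ in done], pending)  # -> ['train_0_5'] {}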
+ def _step_engine(self, tracking, tokenizer):
+ """Unified processing for both tool and non-tool generation.
+
+ Returns:
+ List of completed outputs.
+ """
outputs = []
- iteration = 0
- while True:
- iteration += 1
-
- # Poll tool futures first (matching ToolUseLLM order)
- if tracking and tracking.get("pending_tool_futures"):
- self._poll_tool_futures(tracking, sampling_params, tokenizer)
-
- # Process engine steps - ONLY if there are unfinished requests (matching ToolUseLLM)
- if self.llm_engine.has_unfinished_requests():
- step_outputs = list(self.llm_engine.step())
- for output in step_outputs:
- if output.finished:
- result = _handle_output(
- output, self.tools, tracking, sampling_params, self.max_tool_calls, self.executor
- )
- if result is not None:
- outputs.append(result)
-
- # Check termination condition (matching ToolUseLLM exactly)
- pending_count = len(tracking["pending_tool_futures"]) if tracking else 0
- if not self.llm_engine.has_unfinished_requests() and pending_count == 0:
- self.logger.info(f"[LLMRayActor] Terminating after {iteration} iterations with {len(outputs)} outputs")
+ if self.tools and tracking["pending_tool_futures"]:
+ tool_outputs = self._poll_tool_futures(tracking, tokenizer)
+ outputs.extend(tool_outputs)
+ if tool_outputs:
+ self.logger.debug(f"[_step_engine] Got {len(tool_outputs)} outputs from tool futures")
+
+ if self.llm_engine.has_unfinished_requests():
+ num_unfinished = self.llm_engine.get_num_unfinished_requests()
+ self.logger.debug(f"[_step_engine] Stepping engine with {num_unfinished} unfinished requests")
+ step_outputs = list(self.llm_engine.step())
+ self.logger.debug(f"[_step_engine] Engine step returned {len(step_outputs)} outputs")
+
+ for output in step_outputs:
+ if output.finished:
+ sampling_params = self.request_metadata[output.request_id]["sampling_params"]
+ result = _handle_output(
+ output, self.tools, tracking, sampling_params, self.max_tool_calls, self.executor
+ )
+ if result is not None:
+ outputs.append(result)
+ self.logger.debug(f"[_step_engine] Finished output for request {output.request_id}")
+
+ # Try to keep the engine full
+ added_count = 0
+ while self.llm_engine.get_num_unfinished_requests() < self.inference_batch_size:
+ if not self._maybe_add_new_request():
break
+ added_count += 1
- end_time = time.time()
- total_prompt_tokens = 0
- total_generation_tokens = 0
- earliest_start_time = float("inf")
+ return outputs
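Aside: a stub-engine sketch of the top-up behaviour intended by the loop above, i.e. keep adding prompts until the engine holds inference_batch_size requests or the prompt queue runs dry (names here are illustrative):

    class StubEngine:
        def __init__(self, unfinished):
            self.unfinished = unfinished
        def get_num_unfinished_requests(self):
            return self.unfinished

    def top_up(engine, batch_size, try_add):
        """Add prompts while the engine is below batch_size and the queue still has work."""
        while engine.get_num_unfinished_requests() < batch_size:
            if not try_add():
                break
            engine.unfinished += 1

    engine = StubEngine(unfinished=3)
    top_up(engine, batch_size=8, try_add=lambda: True)
    print(engine.get_num_unfinished_requests())  # -> 8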
- for output in outputs:
- request_id = output.request_id
- if request_id in self.request_metadata:
- metadata = self.request_metadata[request_id]
- total_prompt_tokens += metadata["prompt_tokens"]
- earliest_start_time = min(earliest_start_time, metadata["start_time"])
-
- for completion in output.outputs:
- total_generation_tokens += len(completion.token_ids)
-
- generation_time = end_time - earliest_start_time
-
- for output in outputs:
- self.request_metadata.pop(output.request_id, None)
-
- result = _finalize_outputs(
- outputs,
- tracking,
- request.dataset_index,
- self.tools,
- token_statistics=TokenStatistics(
- num_prompt_tokens=total_prompt_tokens,
- num_response_tokens=total_generation_tokens,
- generation_time=generation_time,
- ),
- start_time=start_time,
- )
- return result
-
- def _add_initial_requests(self, prompts, sampling_params, n_samples, training_step):
- """Add initial requests to the engine."""
- for i, prompt in enumerate(prompts):
- if self.tools:
- # Create individual requests for each sample when using tools
- for j in range(n_samples):
- request_id = f"{training_step}_{i}-{j}"
- self.request_metadata[request_id] = {"start_time": time.time(), "prompt_tokens": len(prompt)}
- tokens_prompt = vllm.TokensPrompt(prompt_token_ids=prompt, cache_salt=f"{training_step}_{i}")
- self.llm_engine.add_request(request_id, tokens_prompt, sampling_params)
- else:
- # Standard request format for non-tool mode
- request_id = f"batch_{training_step}_{i}"
- self.request_metadata[request_id] = {"start_time": time.time(), "prompt_tokens": len(prompt)}
- tokens_prompt = vllm.TokensPrompt(prompt_token_ids=prompt, cache_salt=request_id)
- self.llm_engine.add_request(request_id, tokens_prompt, sampling_params)
-
- def _poll_tool_futures(self, tracking, sampling_params, tokenizer):
- """Poll and handle completed tool executions."""
+
+ def _poll_tool_futures(self, tracking, tokenizer):
+ """Poll and handle completed tool executions.
+
+ Returns:
+ List of completed outputs that can't continue generation.
+ """
if not self.tools or not tracking["pending_tool_futures"]:
- return
+ return []
dict_keys_to_delete = []
+ completed_outputs = []
for req_id, (future, last_o, last_output) in tracking["pending_tool_futures"].items():
if not future.done():
continue
- # Tool future is done, process it
- tool_result = future.result() # Get the tool result
+ # Tool future is done, process it.
+ tool_result = future.result()
last_prompt_token_ids = last_output.prompt_token_ids
last_token_ids = last_o.token_ids
@@ -534,34 +666,47 @@ def _poll_tool_futures(self, tracking, sampling_params, tokenizer):
can_make_new_request = True
# Edge case 2: clip against per-request max_tokens
- remaining = sampling_params.max_tokens - len(tracking["masks"][req_id])
- if remaining <= 0:
- tool_output_token_ids = []
- elif len(tool_output_token_ids) > remaining:
- tool_output_token_ids = tool_output_token_ids[:remaining]
+ req_sampling_params = self.request_metadata[req_id]["sampling_params"]
+ if req_sampling_params:
+ remaining = req_sampling_params.max_tokens - len(tracking["masks"][req_id])
+ if remaining <= 0:
+ tool_output_token_ids = []
+ elif len(tool_output_token_ids) > remaining:
+ tool_output_token_ids = tool_output_token_ids[:remaining]
tracking["concat_outputs"][req_id].outputs[0].token_ids.extend(tool_output_token_ids)
tracking["masks"][req_id].extend([0] * len(tool_output_token_ids))
- new_sample_tokens = sampling_params.max_tokens - len(tracking["masks"][req_id])
- can_make_new_request = can_make_new_request and new_sample_tokens > 0
+
+ if req_sampling_params:
+ new_sample_tokens = req_sampling_params.max_tokens - len(tracking["masks"][req_id])
+ can_make_new_request = can_make_new_request and new_sample_tokens > 0
+ else:
+ new_sample_tokens = 0
+ can_make_new_request = False
if can_make_new_request:
- new_sampling_params = copy.deepcopy(sampling_params)
+ new_sampling_params = copy.deepcopy(req_sampling_params)
new_sampling_params.max_tokens = new_sample_tokens
try:
- self.llm_engine.add_request(
- req_id, vllm.TokensPrompt(prompt_token_ids=prompt_and_tool_output_token), new_sampling_params
- )
+ prompt = vllm.TokensPrompt(prompt_token_ids=prompt_and_tool_output_token, cache_salt=req_id)
+ self.llm_engine.add_request(req_id, prompt, new_sampling_params)
+ # Update the sampling params in request_metadata for the restarted request
+ if req_id in self.request_metadata:
+ self.request_metadata[req_id]["sampling_params"] = new_sampling_params
except Exception as e:
# Match original ToolUseLLM behavior - just log and continue
self.logger.error(f"[_poll_tool_futures] Error adding request {req_id}: {e}")
+ completed_outputs.append(tracking["concat_outputs"][req_id])
+ else:
+ completed_outputs.append(tracking["concat_outputs"][req_id])
dict_keys_to_delete.append(req_id)
for req_id in dict_keys_to_delete:
- if req_id in tracking["pending_tool_futures"]:
- del tracking["pending_tool_futures"][req_id]
+ tracking["pending_tool_futures"].pop(req_id, None)
+
+ return completed_outputs
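Aside: the max_tokens clipping above is easiest to see with a small worked example; this standalone helper mirrors the arithmetic, with mask_len standing in for len(tracking["masks"][req_id]):

    def clip_tool_output(tool_output_token_ids, mask_len, max_tokens):
        """Truncate tool output so generated + tool tokens never exceed the request budget."""
        remaining = max_tokens - mask_len
        if remaining <= 0:
            return []
        return tool_output_token_ids[:remaining]

    # e.g. with max_tokens=50 and 47 tokens already produced, at most 3 tool tokens survive:
    print(clip_tool_output(list(range(10)), mask_len=47, max_tokens=50))  # -> [0, 1, 2]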
def init_process_group(
self,
@@ -662,6 +807,7 @@ def create_vllm_engines(
max_model_len: int,
vllm_gpu_memory_utilization: float = 0.9,
single_gpu_mode: bool = False,
+ inference_batch_size: Optional[int] = None,
pg: Optional[ray.util.placement_group] = None,
vllm_enable_sleep=False,
tools: Optional[Dict[str, Tool]] = None,
@@ -670,6 +816,7 @@ def create_vllm_engines(
results_queue=None,
eval_results_queue=None,
actor_manager=None,
+ inflight_updates: bool = False,
) -> list[LLMRayActor]:
# Convert max_tool_calls to a dict mapping tool end strings to their limits
if tools:
@@ -750,6 +897,8 @@ def create_vllm_engines(
actor_manager=actor_manager,
tools=tools,
max_tool_calls=max_tool_calls_dict,
+ inference_batch_size=inference_batch_size,
+ inflight_updates=inflight_updates,
)
)
diff --git a/scripts/gantry_run_benchmark.sh b/scripts/gantry_run_benchmark.sh
index 135ae62d9..abcca82cf 100755
--- a/scripts/gantry_run_benchmark.sh
+++ b/scripts/gantry_run_benchmark.sh
@@ -15,7 +15,7 @@ fi
git_hash=$(git rev-parse --short HEAD)
git_branch=$(git rev-parse --abbrev-ref HEAD)
-model_name_or_path="hamishivi/qwen2_5_openthoughts2" \
+model_name_or_path="hamishivi/qwen2_5_openthoughts2"
gantry run \
--name open_instruct-benchmark_generators \
diff --git a/scripts/train/debug/large_test_script.sh b/scripts/train/debug/large_test_script.sh
index da349efcc..326a04370 100755
--- a/scripts/train/debug/large_test_script.sh
+++ b/scripts/train/debug/large_test_script.sh
@@ -26,6 +26,7 @@ uv run python mason.py \
--learning_rate 5e-7 \
--per_device_train_batch_size 1 \
--kl_estimator kl3 \
+ --verbose True \
--dataset_mixer_list saurabh5/rlvr_acecoder_filtered ${num_prompts} saurabh5/open-code-reasoning-rlvr-stdio ${num_prompts} \
--dataset_mixer_list_splits train \
--dataset_mixer_eval_list saurabh5/rlvr_acecoder_filtered 8 saurabh5/open-code-reasoning-rlvr-stdio 8 \
@@ -34,7 +35,7 @@ uv run python mason.py \
--max_prompt_token_length 2048 \
--response_length 4096 \
--pack_length 20480 \
- --model_name_or_path Qwen/Qwen2.5-7B \
+ --model_name_or_path hamishivi/qwen2_5_openthoughts2 \
--chat_template_name tulu_thinker \
--stop_strings "" \
--non_stop_penalty False \
@@ -45,6 +46,8 @@ uv run python mason.py \
--total_episodes 10_000 \
--deepspeed_stage 2 \
--num_learners_per_node 8 \
+ --inflight_updates True \
+ --update_every 1 \
--vllm_num_engines 8 \
--vllm_tensor_parallel_size 1 \
--lr_scheduler_type constant \
diff --git a/scripts/train/debug/single_gpu_on_beaker.sh b/scripts/train/debug/single_gpu_on_beaker.sh
index dfeb6ec33..9db5af0f2 100755
--- a/scripts/train/debug/single_gpu_on_beaker.sh
+++ b/scripts/train/debug/single_gpu_on_beaker.sh
@@ -37,6 +37,7 @@ uv run python mason.py \
--apply_r1_style_format_reward \
--apply_verifiable_reward true \
--temperature 0.7 \
+ --inflight_updates True \
--ground_truths_key ground_truth \
--chat_template_name r1_simple_chat_postpend_think \
--learning_rate 3e-7 \
@@ -49,6 +50,7 @@ uv run python mason.py \
--beta 0.01 \
--seed 3 \
--local_eval_every 1 \
+ --vllm_enable_prefix_caching True \
--vllm_sync_backend gloo \
--vllm_gpu_memory_utilization 0.3 \
--save_traces \