diff --git a/open_instruct/benchmark_generators.py b/open_instruct/benchmark_generators.py index c81ec9996..ec94e4cc2 100644 --- a/open_instruct/benchmark_generators.py +++ b/open_instruct/benchmark_generators.py @@ -127,13 +127,27 @@ def get_git_commit() -> str: def save_benchmark_results_to_csv( - results: list[dict[str, Any]], total_time: float, args: grpo_fast.Args, model_config: model_utils.ModelConfig + results: list[dict[str, Any]], + total_time: float, + overall_tokens_per_second: float, + args: grpo_fast.Args, + model_config: model_utils.ModelConfig, ) -> None: """Save benchmark results to CSV file.""" git_commit = get_git_commit() agg_results = aggregate_results(results) csv_path: pathlib.Path = DATA_DIR / "generator_benchmark_results.csv" + device_name = get_device_name(torch.cuda.get_device_name(0)) + device_flops = GPU_SPECS[device_name]["flops"] + model_dims = load_model_dims(model_config.model_name_or_path) + all_prompt_lengths = [result["prompt_lengths"] for result in results] + all_response_lengths = [result["response_lengths"] for result in results] + + total_flops = model_dims.flops( + all_prompt_lengths, all_response_lengths, samples_per_prompt=args.num_samples_per_prompt_rollout + ) + row_data = { "git_commit": git_commit, "model": model_config.model_name_or_path, @@ -152,6 +166,7 @@ def save_benchmark_results_to_csv( "avg_generation_time_per_batch": agg_results["avg_generation_time"], "avg_new_tokens_per_sample": agg_results["total_num_new_tokens"] / (len(results) * args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout), + "overall_mfu": 100 * (total_flops / total_time) / device_flops, } csv_path: pathlib.Path = DATA_DIR / "generator_benchmark_results.csv" @@ -591,7 +606,7 @@ def setup_dataset(args: grpo_fast.Args, tokenizer_config: dataset_transformation def setup_vllm_engines( - args: grpo_fast.Args, model_config: model_utils.ModelConfig, max_model_len: int = 20480 + args: grpo_fast.Args, model_config: model_utils.ModelConfig, max_model_len: int = 20480, num_batches: int = 5 ) -> tuple[list[ray.actor.ActorHandle], ray_queue.Queue, ray_queue.Queue]: """Set up vLLM engines and queues.""" logger.info("Setting up vLLM engines...") @@ -605,8 +620,12 @@ def setup_vllm_engines( pg = ray.util.placement_group(bundles, strategy="PACK") ray.get(pg.ready()) - param_prompt_Q = ray_queue.Queue(maxsize=10) - inference_results_Q = ray_queue.Queue(maxsize=10) + # Queue size needs to accommodate all individual prompts across all batches. + # Each batch has num_unique_prompts_rollout prompts, and we submit them individually. 
+ # Total individual prompts = num_unique_prompts_rollout * num_batches + queue_size = args.num_unique_prompts_rollout * num_batches + param_prompt_Q = ray_queue.Queue(maxsize=queue_size) + inference_results_Q = ray_queue.Queue(maxsize=queue_size) queues_to_monitor = {"Param Prompt Queue": param_prompt_Q, "Inference Results Queue": inference_results_Q} actor_manager = ray.remote(ActorManager).remote(queues_to_monitor, args) @@ -623,6 +642,7 @@ def setup_vllm_engines( max_model_len=max_model_len, vllm_gpu_memory_utilization=args.vllm_gpu_memory_utilization, single_gpu_mode=False, + inference_batch_size=args.inference_batch_size, pg=pg, tools={}, max_tool_calls=[0], @@ -640,7 +660,11 @@ def generate_thread(vllm_engines: list[ray.actor.ActorHandle], stop_event: threa """Thread that repeatedly calls process_from_queue on vllm engines.""" logger.info("[Generate Thread] Starting generation thread") while not stop_event.is_set(): - processed_results = ray.get([engine.process_from_queue.remote(timeout=20) for engine in vllm_engines]) + processed_results = utils.ray_get_with_progress( + [engine.process_from_queue.remote(timeout=20) for engine in vllm_engines], + "[Generate Thread] Waiting for engines to process prompts", + enable=False, + ) num_processed = sum(int(result) for result in processed_results) if num_processed == 0: time.sleep(1) @@ -657,31 +681,30 @@ def submission_thread( start_batch_idx: int, num_batches: int, ) -> None: - """Thread that submits prompts to the queue.""" + """Thread that submits individual prompts to the queue.""" logger.info("[Submission Thread] Starting prompt submission") - for batch_idx in range(start_batch_idx, start_batch_idx + num_batches): - if stop_event.is_set(): - logger.info("[Submission Thread] Stopped due to stop event") - break - - # Get batch data from dataset - start_idx = batch_idx * batch_size - end_idx = min(start_idx + batch_size, len(dataset)) - batch_data = dataset[start_idx:end_idx] - prompts = batch_data[dataset_transformation.INPUT_IDS_PROMPT_KEY] - - # Create list of dataset indices for this batch - dataset_indices = list(range(start_idx, end_idx)) - - param_prompt_Q.put( - PromptRequest( - prompts=prompts, - dataset_index=dataset_indices, - generation_config=generation_config, - start_time=time.perf_counter(), + total_prompts_submitted = 0 + for batch_idx in range(num_batches): + for prompt_idx in range(batch_size): + dataset_idx = start_batch_idx * batch_size + batch_idx * batch_size + prompt_idx + if dataset_idx >= len(dataset): + break + + prompt = dataset[dataset_idx][dataset_transformation.INPUT_IDS_PROMPT_KEY] + actual_batch_idx = start_batch_idx + batch_idx + actual_dataset_index = dataset_idx + param_prompt_Q.put( + PromptRequest( + prompt=prompt, + generation_config=generation_config, + training_step=actual_batch_idx, + dataset_index=actual_dataset_index, + is_eval=False, + start_time=time.perf_counter(), + ) ) - ) - logger.info(f"[Submission Thread] All {num_batches} batches submitted") + total_prompts_submitted += 1 + logger.info(f"[Submission Thread] All {total_prompts_submitted} individual prompts submitted") def run_benchmark( @@ -713,37 +736,37 @@ def run_benchmark( ) stop_event = threading.Event() - executor = futures.ThreadPoolExecutor(max_workers=2, thread_name_prefix="benchmark") + executor = futures.ThreadPoolExecutor(max_workers=len(vllm_engines) + 1, thread_name_prefix="benchmark") generation_future = executor.submit(generate_thread, vllm_engines, stop_event) + submission_future = None # Initialize to None for access in 
finally block results = [] device_name = get_device_name(torch.cuda.get_device_name(0)) device_flops = GPU_SPECS[device_name]["flops"] device_memory_bandwidth = GPU_SPECS[device_name]["memory_bandwidth"] - # Submit warmup batch first + # Load model dims for FLOPS calculations + model_dims = load_model_dims(model_config.model_name_or_path) + logger.info("Submitting warmup batch...") - warmup_start_idx = 0 - warmup_end_idx = min(args.num_unique_prompts_rollout, len(dataset)) - warmup_data = dataset[warmup_start_idx:warmup_end_idx] - warmup_prompts = warmup_data[dataset_transformation.INPUT_IDS_PROMPT_KEY] - warmup_dataset_indices = list(range(warmup_start_idx, warmup_end_idx)) - param_prompt_Q.put( - PromptRequest( - prompts=warmup_prompts, - dataset_index=warmup_dataset_indices, - generation_config=generation_config, - start_time=time.perf_counter(), + for prompt_idx in range(args.num_unique_prompts_rollout): + param_prompt_Q.put( + PromptRequest( + prompt=dataset[prompt_idx][dataset_transformation.INPUT_IDS_PROMPT_KEY], + generation_config=generation_config, + training_step=0, # warmup is training step 0 + dataset_index=prompt_idx, + is_eval=False, + ) ) - ) - model_dims = load_model_dims(model_config.model_name_or_path) try: logger.info("Running warmup batch...") - warmup_result = inference_results_Q.get() - logger.info(f"Warmup batch completed with {len(warmup_result.responses)} responses") + for _ in range(args.num_unique_prompts_rollout): + inference_results_Q.get() + logger.info("Warmup batch completed") logger.info(f"Submitting {num_batches - 1} batches for main benchmark...") submission_future = executor.submit( submission_thread, @@ -755,41 +778,51 @@ def run_benchmark( 1, num_batches - 1, ) - # Process remaining batches with timing + + main_benchmark_start_time = time.perf_counter() + for batch_idx in range(1, num_batches): - # Quick health check! 
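+            # Surface any exception raised by the submission or generation threads as soon as they finish.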
             [future.result() for future in [submission_future, generation_future] if future.done()]
-            result = inference_results_Q.get()
+
+            batch_results = []
+            batch_start_time = time.perf_counter()
+            for _ in range(args.num_unique_prompts_rollout):
+                batch_results.append(inference_results_Q.get())
             completion_time = time.perf_counter()
-            # Calculate generation time from when the request was enqueued
-            batch_generation_time = completion_time - result.start_time if result.start_time else 0
+            batch_generation_time = completion_time - batch_start_time
+
+            total_new_tokens = 0
+            all_response_lengths = []
+            all_finish_reasons = []
-            new_tokens = sum(len(response) for response in result.responses)
-            tokens_per_second = new_tokens / batch_generation_time if batch_generation_time > 0 else 0
+            for result in batch_results:
+                result_tokens = sum(len(response) for response in result.responses)
+                total_new_tokens += result_tokens
+                all_response_lengths.extend([len(response) for response in result.responses])
+                all_finish_reasons.extend(result.finish_reasons)
+            tokens_per_second = total_new_tokens / batch_generation_time if batch_generation_time > 0 else 0
 
             result_dict = {
                 "tokens_per_second": tokens_per_second,
                 "generation_time": batch_generation_time,
-                "num_new_tokens": new_tokens,
-                "finish_reasons": collections.Counter(result.finish_reasons),
-                "response_lengths": [len(response) for response in result.responses],
-                "dataset_indices": result.dataset_index,
+                "num_new_tokens": total_new_tokens,
+                "finish_reasons": collections.Counter(all_finish_reasons),
+                "response_lengths": all_response_lengths,
+                "dataset_indices": [r.dataset_index for r in batch_results],
             }
-            # Get prompt lengths using dataset indices from the result
-            prompt_data = dataset[result.dataset_index]
-            prompts = prompt_data[dataset_transformation.INPUT_IDS_PROMPT_KEY]
-            prompt_lengths = [len(prompt) for prompt in prompts]
-            response_lengths = [len(response) for response in result.responses]
-
-            # Calculate total FLOPs for all prompts and responses in the batch
-            # No need to expand prompt_lengths - the flops method now handles samples_per_prompt
+            prompt_lengths = []
+            response_lengths = all_response_lengths
+            for r in batch_results:
+                prompt_data = dataset[r.dataset_index]
+                prompt = prompt_data[dataset_transformation.INPUT_IDS_PROMPT_KEY]
+                prompt_lengths.append(len(prompt))
+
+            result_dict["prompt_lengths"] = prompt_lengths
             model_flops = model_dims.flops(
                 prompt_lengths, response_lengths, samples_per_prompt=args.num_samples_per_prompt_rollout
             )
-
-            # MFU = (FLOPs / time) / peak_FLOPS * 100
-            model_flops_per_second = model_flops / batch_generation_time if batch_generation_time > 0 else 0
+            model_flops_per_second = model_flops / batch_generation_time
             result_dict["mfu"] = 100 * model_flops_per_second / device_flops
 
             # Calculate total memory bytes for all prompts and responses in the batch
@@ -809,18 +842,36 @@ def run_benchmark(
                 f"MFU: {result_dict['mfu']:.2f}%, "
                 f"MBU: {result_dict['mbu']:.2f}%, "
                 f"generation time: {batch_generation_time:.2f}s, "
-                f"total new tokens: {new_tokens}"
+                f"total new tokens: {total_new_tokens}"
             )
 
-        # Calculate total time for main benchmark only
-        main_benchmark_time = sum(r["generation_time"] for r in results)
+        main_benchmark_end_time = time.perf_counter()
+        main_benchmark_total_time = main_benchmark_end_time - main_benchmark_start_time
+
+        total_main_tokens = sum(r["num_new_tokens"] for r in results)
+        overall_tokens_per_second = (
+            total_main_tokens / main_benchmark_total_time if main_benchmark_total_time
> 0 else 0 + ) - print_summary(results, main_benchmark_time, args, model_config) - save_benchmark_results_to_csv(results, main_benchmark_time, args, model_config) + logger.info("\nOverall main benchmark performance:") + logger.info(f" Total wall-clock time: {main_benchmark_total_time:.2f}s") + logger.info(f" Total tokens generated: {total_main_tokens}") + logger.info(f" Overall tokens/second: {overall_tokens_per_second:.2f}") + print_summary(results, main_benchmark_total_time, overall_tokens_per_second, args, model_config) + save_benchmark_results_to_csv( + results, main_benchmark_total_time, overall_tokens_per_second, args, model_config + ) finally: + logger.info("Starting cleanup...") stop_event.set() - executor.shutdown(wait=True) + + logger.info("Waiting for threads to complete...") + try: + submission_future.result(timeout=5) + generation_future.result(timeout=10) + finally: + executor.shutdown(wait=True) logger.info("Threads cleaned up") @@ -867,7 +918,11 @@ def aggregate_results(results: list[dict[str, Any]]) -> dict[str, Any]: def print_summary( - results: list[dict[str, Any]], total_time: float, args: grpo_fast.Args, model_config: model_utils.ModelConfig + results: list[dict[str, Any]], + total_time: float, + overall_tokens_per_second: float, + args: grpo_fast.Args, + model_config: model_utils.ModelConfig, ) -> None: """Print benchmark summary statistics.""" @@ -885,8 +940,11 @@ def print_summary( print(f"Num rollouts: {args.num_samples_per_prompt_rollout}") print(f"Max tokens: {args.response_length}") print("-" * 60) + print(f"Total wall-clock time (main benchmark): {total_time:.2f}s") + print(f"Total new tokens generated: {agg_results['total_num_new_tokens']}") print(f"Total time (main benchmark): {agg_results['total_generation_time']:.2f}s") print(f"Total new tokens generated: {agg_results['total_num_new_tokens']}") + print(f"Overall tokens/second: {overall_tokens_per_second:.2f}") print("-" * 60) print(f"Average results over {len(results)} main benchmark batches:") print(f"Average tokens/second: {agg_results['avg_tokens_per_second']:.2f}") diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 2ff67be1d..0f71fb8ff 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -202,6 +202,10 @@ class Args: fused_optimizer: bool = False """Whether to use fused optimizer""" + # Progress reporting + update_every: int = 10 + """How often to update the Beaker description progress bar.""" + # Batch sizes per_device_train_batch_size: int = 1 """The forward batch size per device (local_micro_batch_size)""" @@ -329,6 +333,8 @@ class Args: on the first node and 4 learner processes on the second node; each process will have 1 GPU)""" vllm_num_engines: int = 1 """number of vLLM Engines, set to 0 to disable vLLM""" + inflight_updates: bool = False + """If True, return immediately even with pending work. If False, wait for all work to complete before exiting.""" inference_batch_size: Optional[int] = None """inference batch size per vLLM engine. 
If None, calculated as ceil(num_unique_prompts_rollout / vllm_num_engines) * num_samples_per_prompt_rollout""" vllm_tensor_parallel_size: int = 1 @@ -343,6 +349,8 @@ class Args: """whether to enable prefix caching""" vllm_top_p: float = 1.0 """vLLM top p for nucleus sampling""" + inference_batch_size: Optional[int] = None + """Number of inference requests to batch together for vLLM processing""" deepspeed_stage: int = 0 """the deepspeed stage""" gather_whole_model: bool = True @@ -433,6 +441,7 @@ def __post_init__(self): assert self.num_samples_per_prompt_rollout > 0, "Number of samples per prompt must be greater than 0!" if self.num_samples_per_prompt_rollout == 1: logger.warning("num_samples_per_prompt_rollout is 1. This reduces GRPO to REINFORCE.") + assert self.apply_verifiable_reward or self.apply_r1_style_format_reward or self.non_stop_penalty, ( "At least one reward must be applied!" ) @@ -444,6 +453,8 @@ def __post_init__(self): # Initialize stop_strings if None if self.stop_strings is None: self.stop_strings = [] + if self.inference_batch_size is None: + self.inference_batch_size = self.num_unique_prompts_rollout // self.vllm_num_engines assert self.pack_length >= self.max_prompt_token_length + self.response_length, ( "The `pack_length` needs to be greater than the sum of `max_prompt_token_length` and `response_length`!" ) @@ -534,6 +545,87 @@ def to_device_inplace(tensors_list: List[torch.Tensor], device: torch.device): tensors_list[i] = tensors_list[i].to(device, non_blocking=True) +def log_memory_usage( + rank: int, device: torch.device, stage: str, include_tensors: bool = False, top_k_tensors: int = 10 +): + """Log comprehensive memory usage information for debugging OOM errors. + + Args: + rank: Process rank + device: CUDA device + stage: String describing the current stage of execution + include_tensors: Whether to log information about individual tensors + top_k_tensors: Number of largest tensors to log + """ + import gc + + import psutil + + # System RAM usage + process = psutil.Process() + ram_usage_gb = process.memory_info().rss / (1024**3) + ram_percent = process.memory_percent() + + # Get available system memory + vm = psutil.virtual_memory() + available_ram_gb = vm.available / (1024**3) + total_ram_gb = vm.total / (1024**3) + + logger.info(f"[Rank {rank}] Memory @ {stage}:") + logger.info( + f" System RAM: {ram_usage_gb:.2f}GB ({ram_percent:.1f}%) | Available: {available_ram_gb:.2f}GB / {total_ram_gb:.2f}GB" + ) + + if torch.cuda.is_available(): + # Force synchronization to get accurate memory stats + torch.cuda.synchronize(device) + + # GPU memory usage + allocated_gb = torch.cuda.memory_allocated(device) / (1024**3) + reserved_gb = torch.cuda.memory_reserved(device) / (1024**3) + max_allocated_gb = torch.cuda.max_memory_allocated(device) / (1024**3) + + # Get total GPU memory + total_gpu_memory = torch.cuda.get_device_properties(device).total_memory / (1024**3) + free_memory = (torch.cuda.get_device_properties(device).total_memory - torch.cuda.memory_reserved(device)) / ( + 1024**3 + ) + + logger.info(" GPU Memory:") + logger.info( + f" Allocated: {allocated_gb:.2f}GB / Reserved: {reserved_gb:.2f}GB / Total: {total_gpu_memory:.2f}GB" + ) + logger.info(f" Max Allocated: {max_allocated_gb:.2f}GB | Free: {free_memory:.2f}GB") + logger.info(f" Utilization: {(reserved_gb / total_gpu_memory * 100):.1f}%") + + # Memory summary from PyTorch + if hasattr(torch.cuda, "memory_summary"): + summary = torch.cuda.memory_summary(device, abbreviated=True) + logger.debug(f" PyTorch 
Memory Summary:\n{summary}") + + if include_tensors: + # Find and log largest tensors + tensor_sizes = [] + for obj in gc.get_objects(): + try: + if torch.is_tensor(obj) and obj.is_cuda and obj.device == device: + size_mb = float(obj.element_size() * obj.nelement() / (1024**2)) + tensor_sizes.append((size_mb, obj.shape, obj.dtype, obj.device)) + except Exception: + pass + + if tensor_sizes: + tensor_sizes.sort(reverse=True, key=lambda x: x[0]) + logger.info(f" Top {min(top_k_tensors, len(tensor_sizes))} Tensors on GPU:") + for i, (size_mb, shape, dtype, dev) in enumerate(tensor_sizes[:top_k_tensors]): + logger.info(f" {i + 1}. {size_mb:.2f}MB | Shape: {shape} | Dtype: {dtype}") + + total_tracked_mb = float(sum(x[0] for x in tensor_sizes)) + logger.info( + f" Total tracked tensor memory: {total_tracked_mb:.2f}MB ({total_tracked_mb / 1024:.2f}GB)" + ) + + class ShufflingIterator: def __init__(self, data: np.ndarray, batch_size: int, seed: Optional[int] = None): self.data = data.copy() @@ -630,6 +722,9 @@ def load(self, path: str, map_location=None): dschf = None logger.info(f"Deepspeed config: {dschf=}") + # Log memory before loading policy model + log_memory_usage(self.rank, self.device, "Before loading policy model", include_tensors=False) + self.policy: PreTrainedModel = AutoModelForCausalLM.from_pretrained( model_config.model_name_or_path, revision=model_config.model_revision, @@ -639,6 +734,9 @@ def load(self, path: str, map_location=None): ) disable_dropout_in_model(self.policy) self.policy.gradient_checkpointing_enable() + + # Log memory after loading policy model + log_memory_usage(self.rank, self.device, "After loading policy model", include_tensors=False) # AdamOptimizer = DeepSpeedCPUAdam if self.adam_offload else FusedAdam # AdamOptimizer = FusedAdam if args.set_weight_decay_on_bias_and_norm: @@ -728,6 +826,9 @@ def load(self, path: str, map_location=None): dschf = None logger.info(f"DeepSpeed config: {dschf=}") + # Log memory before loading reference policy + log_memory_usage(self.rank, self.device, "Before loading ref_policy", include_tensors=False) + self.ref_policy: PreTrainedModel = AutoModelForCausalLM.from_pretrained( model_config.model_name_or_path, revision=model_config.model_revision, @@ -739,6 +840,9 @@ def load(self, path: str, map_location=None): self.ref_policy, *_ = deepspeed.initialize(model=self.ref_policy, config=ds_config) self.ref_policy.eval() + # Log memory after loading reference policy + log_memory_usage(self.rank, self.device, "After loading ref_policy", include_tensors=True, top_k_tensors=10) + # Load reference policy checkpoint if available if hasattr(self, "ref_policy_checkpoint_path") and self.ref_policy_checkpoint_path: state_dict = torch.load(self.ref_policy_checkpoint_path, map_location=self.device) @@ -885,12 +989,19 @@ def train( num_mini_batches: int, ): args = self.args + + # Log memory before moving tensors to device + log_memory_usage(self.rank, self.device, "Before to_device", include_tensors=False) + to_device_inplace(collated_query_responses, self.device) to_device_inplace(collated_tool_masks, self.device) to_device_inplace(collated_attention_masks, self.device) to_device_inplace(collated_position_ids, self.device) to_device_inplace(collated_advantages, self.device) to_device_inplace(collated_response_masks, self.device) + + # Log memory after moving tensors to device + log_memory_usage(self.rank, self.device, "After to_device", include_tensors=True, top_k_tensors=5) # accumulation steps should always be at least 1 accumulation_steps = 
max(math.ceil(len(collated_query_responses) / num_mini_batches - 0.5), 1) leftover = len(collated_query_responses) % accumulation_steps @@ -963,6 +1074,9 @@ def train( old_logprobs[i] = old_logprob torch.cuda.empty_cache() + # Log memory before training loop + log_memory_usage(self.rank, self.device, "Before training loop", include_tensors=False) + local_step = 0 # Do multiple epochs of training on on-policy data (PPO-style), with a fresh random shuffle in each epoch with Timer("[Training Processes] Loss calculation", noop=self.rank != 0): @@ -1038,7 +1152,23 @@ def train( # grpo change: directly subtract KL in loss (add) loss = masked_mean(pg_loss_max + (args.beta * kl), mb_response_masks_bool, args.masked_mean_axis) loss = loss / accumulation_steps + + # Log memory before backward pass + log_memory_usage( + self.rank, self.device, f"Before backward (step {local_step})", include_tensors=False + ) + self.model.backward(loss) + + # Log memory after backward pass + log_memory_usage( + self.rank, + self.device, + f"After backward (step {local_step})", + include_tensors=True, + top_k_tensors=10, + ) + if (local_step + 1) % accumulation_steps == 0: self.model.step() local_step += 1 @@ -1344,17 +1474,19 @@ def accumulate_inference_batches( args: Args, training_step: int, generation_config, + num_prompts: int, actor_manager=None, timeout: Optional[float] = None, ) -> tuple[GenerationResult, Batch]: """Accumulate multiple inference results into a single training batch. Args: - inference_results_Q: Queue containing GenerationResult objects + inference_results_Q: Queue containing individual GenerationResult objects (one per prompt) pending_queries_map: PendingQueriesMap instance for thread-safe query tracking - args: Arguments containing vllm_num_engines + args: Arguments containing vllm_num_engines and batch size info training_step: Current training step for error reporting generation_config: Generation config containing n (number of samples per prompt) + num_prompts: Number of prompts to accumulate timeout: Optional timeout in seconds for queue get operations. If None, blocks indefinitely. 
Raises: @@ -1363,15 +1495,15 @@ def accumulate_inference_batches( Returns: Tuple of (combined_result, Batch with queries, ground_truths, datasets) or (ShutdownSentinel, None) if shutdown signal received """ - # Collect results from all engines with non-blocking progress bar results = [] all_queries = [] all_ground_truths = [] all_datasets = [] + for i in tqdm( - range(args.vllm_num_engines), - total=args.vllm_num_engines, - desc=f"Accumulating results from {args.vllm_num_engines} engines", + range(num_prompts), + total=num_prompts, + desc=f"Accumulating {num_prompts} results (each with {generation_config.n} completions)", bar_format="{l_bar}{bar}{r_bar}\n", disable=not args.verbose, ): @@ -1379,37 +1511,17 @@ def accumulate_inference_batches( if isinstance(result, ShutdownSentinel): return result, None - dataset_indices = result.dataset_index - - if dataset_indices is None: - raise RuntimeError(f"Dataset indices is None for result {i}") - - # When generation_config.n > 1, vLLM generates multiple responses per prompt - # but dataset_indices only contains the unique indices (not replicated) - # So we expect: len(responses) == len(dataset_indices) * generation_config.n - expected_responses = len(dataset_indices) * generation_config.n - assert len(result.responses) == expected_responses, ( - f"Mismatch: number of responses ({len(result.responses)}) " - f"doesn't match expected ({expected_responses}) for result {i}" - f". {generation_config.n=}" - f", {len(dataset_indices)=}" - ) - # Get corresponding queries, ground_truths, datasets for each individual prompt - batch_queries = [] - batch_ground_truths = [] - batch_datasets = [] - - for dataset_idx in dataset_indices: - query, ground_truth, dataset = pending_queries_map.pop(dataset_idx) - batch_queries.append(query) - batch_ground_truths.append(ground_truth) - batch_datasets.append(dataset) + dataset_index = result.dataset_index + assert len(result.responses) == generation_config.n, ( + f"Result {i} has {len(result.responses)} responses, expected {generation_config.n}" + ) + query, ground_truth, dataset = pending_queries_map.pop(dataset_index) + all_queries.append(query) + all_ground_truths.append(ground_truth) + all_datasets.append(dataset) results.append(result) - all_queries.extend(batch_queries) - all_ground_truths.extend(batch_ground_truths) - all_datasets.extend(batch_datasets) # Combine all results into a single GenerationResult combined_responses = [] @@ -1492,7 +1604,13 @@ def data_preparation_thread( # Streaming accumulation: collect results as they arrive with Timer("🚀 [Data Preparation Thread] Getting response ids") as timer: result, batch = accumulate_inference_batches( - inference_results_Q, pending_queries_map, args, training_step, generation_config, actor_manager + inference_results_Q, + pending_queries_map, + args, + training_step, + generation_config, + args.num_unique_prompts_rollout, + actor_manager ) if isinstance(result, ShutdownSentinel): logger.info("[Data Preparation Thread] Received shutdown sentinel, exiting") @@ -1626,10 +1744,6 @@ def data_preparation_thread( finish_reasons += [finish_reasons[i] for i in sampled_indices] - print( - f"📊 Duplicated {need_to_fill_prompt} prompts from {len(sampled_indices)} total responses" - ) - with Timer("📦 [Data Preparation Thread] Packing sequences"): packed_sequences = pack_sequences( queries=batch.queries, @@ -1656,7 +1770,8 @@ def data_preparation_thread( shortfall = args.world_size - len(packed_sequences.query_responses) if shortfall > 0: logger.warning( - f"Padding {shortfall} 
sequences for world size. In future, you should adjust your compute this." + f"[Data Preparation Thread] Step {training_step}: Padding {shortfall} sequences for world size. " + f"In future, you should adjust your compute this." ) # construct "dummy" sequences for padding out the world size dummy_qr = torch.tensor([tokenizer.pad_token_id, tokenizer.eos_token_id], dtype=torch.long) @@ -1855,6 +1970,7 @@ def setup_experiment_tracking(args: Args, tc: TokenizerConfig, model_config: Mod name=args.run_name, save_code=True, tags=[args.exp_name] + get_wandb_tags(), + settings={"quiet": False, "show_errors": True, "show_warnings": True, "silent": False}, ) wandb_url = wandb.run.get_url() logger.info(f"Initial Beaker description update with wandb_url: {wandb_url}") @@ -1923,7 +2039,6 @@ def create_model_and_optimizer( ray_get_with_progress([pg.ready()], desc="Waiting for placement group") inits = [] policy_group = ModelGroup(pg, PolicyTrainerRayProcess, args.num_learners_per_node, args.single_gpu_mode) - wandb_url = wandb.run.get_url() if args.with_tracking else None inits.extend( model.from_pretrained.remote(args, model_config, beaker_config, wandb_url, tokenizer) for model in policy_group.models @@ -1976,6 +2091,7 @@ def create_model_and_optimizer( max_len, args.vllm_gpu_memory_utilization, args.single_gpu_mode, + args.inference_batch_size, pg=pg if args.single_gpu_mode else None, tools=tool_objects, max_tool_calls=args.max_tool_calls, @@ -1983,6 +2099,7 @@ def create_model_and_optimizer( results_queue=inference_results_Q, eval_results_queue=evaluation_inference_results_Q, actor_manager=actor_manager, + inflight_updates=args.inflight_updates, ) resume_training_step = ray_get_with_progress(inits, desc="Initializing models")[0] + 1 @@ -2037,28 +2154,14 @@ def split_and_insert_batch( is_eval: bool = False, ) -> None: """Split a batch into multiple inference batches and insert individual prompts into queues and mapping.""" - for batch_idx in range(vllm_num_engines): - start_idx = batch_idx * args.inference_batch_size - end_idx = min(start_idx + args.inference_batch_size, len(batch.queries)) - - # Stop if we've distributed all queries - if start_idx >= len(batch.queries): - break - - sub_batch = batch[start_idx:end_idx] - - # Store prompts in the map using thread-safe insert_many - pending_queries_map.insert_many( - sub_batch.indices, sub_batch.queries, sub_batch.ground_truths, sub_batch.datasets - ) - - # Use PromptRequest for Ray queue with batch-specific dataset_index list + pending_queries_map.insert_many(batch.indices, batch.queries, batch.ground_truths, batch.datasets) + for i, prompt in enumerate(batch.queries): param_prompt_Q.put( PromptRequest( - prompts=sub_batch.queries, + prompt=prompt, generation_config=generation_config, training_step=training_step, - dataset_index=sub_batch.indices, + dataset_index=batch.indices[i], is_eval=is_eval, ) ) @@ -2101,7 +2204,7 @@ def load_data_from_packing_thread( logger.warning("[Main Thread] Stop event detected while waiting for packed sequences") return None, {}, num_total_tokens try: - packed_data = packed_sequences_Q.get(timeout=30.0) + packed_data = packed_sequences_Q.get(timeout=60.0) break except Empty: # check that everything is still alive @@ -2182,9 +2285,11 @@ def generate_thread(args, vllm_engines, resume_training_step, stop_event, genera enable=args.verbose, ) num_processed = sum(int(result) for result in processed_results) + logger.info(f"[Generate Thread] vLLM engines returned, processed {num_processed} total requests") # Suppress timing output 
if nothing was processed
             if num_processed == 0:
                 timer.noop = True
+                time.sleep(10)
         if num_processed > 0:
             try:
                 generate_metrics_Q.put_nowait({"time/generation": timer.duration})
@@ -2282,7 +2387,21 @@ def one_training_step(
             if isinstance(value, np.ndarray) or isinstance(value, list):
                 if len(value) > 0:
                     metrics[key] = wandb.Histogram(value)
-        wandb.log(metrics, step=episode)
+        logger.info(f"About to log to wandb... step={episode}")
+        result = wandb.log(metrics, step=episode)
+        logger.info(f"WandB log() returned: {result}, logged {metrics=} to wandb.")
+
+        # Debug: Check if wandb.run is active and what's in the summary
+        if wandb.run:
+            logger.info(f"WandB run is active. Run ID: {wandb.run.id}")
+            logger.info(f"WandB run name: {wandb.run.name}")
+            logger.info(f"WandB run URL: {wandb.run.get_url()}")
+            # Check if data was actually logged
+            summary_keys = list(wandb.run.summary.keys()) if wandb.run.summary else []
+            logger.info(f"WandB summary has {len(summary_keys)} keys: {summary_keys[:10]}...")  # Show first 10 keys
+        else:
+            logger.error("WandB run is not active!")
+        logger.info("Done step.")
 
 
 def maybe_evaluate(
@@ -2296,6 +2415,7 @@ def maybe_evaluate(
     eval_pending_queries_map: PendingQueriesMap,
     eval_generation_config,
     generate_metrics_Q: Queue,
+    num_eval_prompts: int,
     actor_manager=None,
 ):
     """Optionally evaluate the model."""
@@ -2311,6 +2431,7 @@ def maybe_evaluate(
             args,
             training_step,
             eval_generation_config,
+            num_eval_prompts,
             actor_manager,
             timeout=timeout,
         )
@@ -2548,6 +2669,7 @@ def run_training(
    tokenizer,
    train_dataset,
    eval_batch,
+    num_eval_prompts,
    policy_group,
    vllm_engines,
    generation_configs,
@@ -2641,7 +2763,7 @@ def health_check_fn():
 
         if (
             training_step == resume_training_step
-            or training_step % 10 == 0
+            or training_step % args.update_every == 0
             or training_step == args.num_training_steps
         ):
             logger.info(f"Progress update for Beaker description: step {training_step}/{args.num_training_steps}")
@@ -2754,6 +2876,7 @@ def health_check_fn():
            eval_pending_queries_map,
            generation_configs["eval"],
            generate_metrics_Q,
+            num_eval_prompts,
            actor_manager,
        )
 
@@ -2778,10 +2901,11 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa
     ray.init(dashboard_host="0.0.0.0")
 
     # Create Ray queues.
-    queue_size = (args.async_steps + 1) * args.vllm_num_engines
+    queue_size = (args.async_steps + 1) * args.num_unique_prompts_rollout
     inference_results_Q = ray_queue.Queue(maxsize=queue_size)
     param_prompt_Q = ray_queue.Queue(maxsize=queue_size)
-    evaluation_inference_results_Q = ray_queue.Queue(maxsize=args.vllm_num_engines)
+    # Size eval queue based on actual eval samples (num_eval_samples) with 2x overhead for buffering.
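+    # Each eval prompt now returns its own GenerationResult, so the queue may need to hold num_eval_samples results.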
+ evaluation_inference_results_Q = ray_queue.Queue(maxsize=num_eval_samples * 2) policy_group, vllm_engines, tool_objects, resume_training_step, episode, actor_manager = ( create_model_and_optimizer( @@ -2826,9 +2950,11 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa if eval_dataset is None: eval_batch = None + num_eval_prompts = 0 else: eval_dataset_indices = list(range(min(num_eval_samples, len(eval_dataset)))) eval_batch = next_batch(eval_dataset_indices, eval_dataset) + num_eval_prompts = len(eval_dataset_indices) reward_fn = make_reward_fn(args) stop_event = threading.Event() @@ -2840,6 +2966,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa tokenizer, train_dataset, eval_batch, + num_eval_prompts, policy_group, vllm_engines, generation_configs, @@ -2862,6 +2989,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa actor_manager, checkpoint_state, ) + logger.info("Done training!") finally: cleanup_training_resources( stop_event, executor, [inference_results_Q, param_prompt_Q, evaluation_inference_results_Q], actor_manager diff --git a/open_instruct/queue_types.py b/open_instruct/queue_types.py index 87f247ad6..78bb5d760 100644 --- a/open_instruct/queue_types.py +++ b/open_instruct/queue_types.py @@ -31,7 +31,7 @@ class GenerationResult: finish_reasons: List[str] masks: List[List[int]] request_info: RequestInfo - dataset_index: Optional[List[int]] = None + dataset_index: Optional[int] = None training_step: Optional[int] = None token_statistics: Optional[TokenStatistics] = None start_time: Optional[float] = None @@ -46,9 +46,9 @@ class PromptRequest: `_QueueActor`. """ - prompts: List[List[int]] + prompt: List[int] generation_config: Any training_step: Optional[int] = None - dataset_index: Optional[List[int]] = None + dataset_index: Optional[int] = None is_eval: bool = False start_time: Optional[float] = None diff --git a/open_instruct/test_grpo_fast.py b/open_instruct/test_grpo_fast.py index be9cbd417..f2edbc323 100644 --- a/open_instruct/test_grpo_fast.py +++ b/open_instruct/test_grpo_fast.py @@ -1,5 +1,7 @@ import gc import os +import queue +import time import unittest from unittest.mock import Mock @@ -11,9 +13,8 @@ from transformers import AutoTokenizer from vllm import SamplingParams -from open_instruct import grpo_fast, model_utils, utils +from open_instruct import grpo_fast, model_utils, tool_utils, utils, vllm_utils3 from open_instruct.queue_types import GenerationResult, PromptRequest, RequestInfo -from open_instruct.vllm_utils3 import create_vllm_engines class TestGrpoFastBase(unittest.TestCase): @@ -67,19 +68,10 @@ def setUp(self): # Initialize Ray for this test ray.init(include_dashboard=False) - def _cleanup_ray_queues(self): - """Clean up all Ray queues created during the test.""" - for queue in self._ray_queues: - try: - queue.shutdown() - except Exception as e: - print(f"Warning: Failed to shutdown Ray queue: {e}") - self._ray_queues.clear() - def tearDown(self): """Check for leaks and shutdown Ray.""" # Clean up Ray queues BEFORE shutting down Ray - self._cleanup_ray_queues() + [rq.shutdown() for rq in self._ray_queues] # Shutdown Ray if ray.is_initialized(): @@ -126,17 +118,18 @@ def create_test_data(self, num_prompts, prefix="", start_idx=0): datasets = [f"{prefix}dataset_{i}" for i in indices] return queries, ground_truths, datasets, indices - def create_mock_args(self, num_engines=4, num_samples=1): + def create_mock_args(self, num_engines=4, num_samples=1, 
num_prompts=16): """Create mock args object.""" mock_args = Mock() mock_args.vllm_num_engines = num_engines mock_args.num_samples_per_prompt_rollout = num_samples + mock_args.num_unique_prompts_rollout = num_prompts + mock_args.verbose = False return mock_args - def create_mock_result(self, dataset_indices, training_step, num_samples_per_prompt=1): + def create_mock_result(self, dataset_index, training_step, num_samples_per_prompt=1): """Create a mock GenerationResult.""" - batch_size = len(dataset_indices) - total_responses = batch_size * num_samples_per_prompt + total_responses = num_samples_per_prompt return GenerationResult( responses=[[1, 2, 3] for _ in range(total_responses)], @@ -150,13 +143,15 @@ def create_mock_result(self, dataset_indices, training_step, num_samples_per_pro tool_runtimes=[0.0] * total_responses, tool_calleds=[False] * total_responses, ), - dataset_index=dataset_indices, + dataset_index=dataset_index, ) def setup_and_split_batch(self, queries, ground_truths, datasets, indices, num_engines, training_step=1): """Setup queues and split batch - common pattern.""" - param_prompt_Q = ray_queue.Queue(maxsize=num_engines * 2) - inference_results_Q = ray_queue.Queue(maxsize=num_engines * 2) + # Queue size should accommodate batches from all engines potentially multiple times + queue_size = num_engines * len(queries) + param_prompt_Q = ray_queue.Queue(maxsize=queue_size) + inference_results_Q = ray_queue.Queue(maxsize=queue_size) pending_queries_map = grpo_fast.PendingQueriesMap() # Track queues for cleanup @@ -205,7 +200,7 @@ def test_vllm_queue_system_single_prompt(self): self._ray_queues.extend([param_prompt_Q, inference_results_Q]) # Create vLLM engines with queues - vllm_engines = create_vllm_engines( + vllm_engines = vllm_utils3.create_vllm_engines( num_engines=1, tensor_parallel_size=1, enforce_eager=True, @@ -268,43 +263,49 @@ def test_batch_splitting_and_engine_configurations(self, vllm_num_engines: int, queries_next, ground_truths_next, datasets_next, dataset_indices, vllm_num_engines ) - # Verify that we have individual prompts in the map (not batches) + # Verify that we have all prompts in the map self.assertEqual(len(pending_queries_map), num_unique_prompts_rollout) - # Verify that we have the expected number of items in the queue - self.assertEqual(param_prompt_Q.qsize(), vllm_num_engines) + # Verify that we have individual prompts in the queue (changed from batches) + self.assertEqual(param_prompt_Q.qsize(), num_unique_prompts_rollout) - # Simulate vLLM processing - batch_idx = 0 + # Simulate vLLM processing - each prompt gets processed individually + prompt_count = 0 while not param_prompt_Q.empty(): request = param_prompt_Q.get() self.assertIsInstance(request, PromptRequest) self.assertEqual(request.training_step, 1) - self.assertIsInstance(request.dataset_index, list) + self.assertIsInstance(request.dataset_index, int) # Now individual prompts have single index - mock_result = self.create_mock_result(request.dataset_index, request.training_step) + # Create result for this individual prompt with n=4 samples + mock_result = self.create_mock_result( + request.dataset_index, request.training_step, num_samples_per_prompt=4 + ) inference_results_Q.put(mock_result) - batch_idx += 1 + prompt_count += 1 + + # Verify we processed the right number of individual prompts + self.assertEqual(prompt_count, num_unique_prompts_rollout) - # Simulate streaming accumulation (simplified version for testing) + # Simulate accumulation combined_responses = [] combined_queries = [] 
combined_ground_truths = [] combined_datasets = [] - for _ in range(vllm_num_engines): + # Process all results (we have num_unique_prompts_rollout individual results) + for _ in range(num_unique_prompts_rollout): result = inference_results_Q.get() - dataset_indices = result.dataset_index + dataset_index = result.dataset_index - # Get queries from pending_queries_map + # Get query for this index batch_queries = [] batch_ground_truths = [] batch_datasets = [] - for idx in dataset_indices: - q, gt, d = pending_queries_map.pop(idx) - batch_queries.append(q) - batch_ground_truths.append(gt) - batch_datasets.append(d) + q, gt, d = pending_queries_map.pop(dataset_index) + batch_queries.append(q) + batch_ground_truths.append(gt) + batch_datasets.append(d) combined_responses.extend(result.responses) combined_queries.extend(batch_queries) @@ -333,9 +334,10 @@ def test_batch_splitting_and_engine_configurations(self, vllm_num_engines: int, # Verify that the combined result has the correct structure self.assertIsInstance(combined_result, GenerationResult) - self.assertEqual(len(combined_result.responses), len(queries_next)) - self.assertEqual(len(combined_result.finish_reasons), len(queries_next)) - self.assertEqual(len(combined_result.masks), len(queries_next)) + # With n=4 samples per prompt, we expect 4x the number of responses + self.assertEqual(len(combined_result.responses), len(queries_next) * 4) + self.assertEqual(len(combined_result.finish_reasons), len(queries_next) * 4) + self.assertEqual(len(combined_result.masks), len(queries_next) * 4) # Verify that the pending_queries_map is empty after accumulation self.assertEqual(len(pending_queries_map), 0) @@ -358,28 +360,28 @@ def test_dataset_index_preservation_through_pipeline(self): queries_next, ground_truths_next, datasets_next, dataset_indices, vllm_num_engines ) - # Simulate vLLM processing - batch_idx = 0 + # Simulate vLLM processing - processes individual prompts + prompt_count = 0 while not param_prompt_Q.empty(): request = param_prompt_Q.get() mock_result = self.create_mock_result(request.dataset_index, request.training_step) inference_results_Q.put(mock_result) - batch_idx += 1 + prompt_count += 1 - # Simulate streaming accumulation + # Simulate accumulation combined_queries = [] combined_ground_truths = [] combined_datasets = [] - for _ in range(vllm_num_engines): + # Process all individual results + for _ in range(prompt_count): result = inference_results_Q.get() - dataset_indices = result.dataset_index + dataset_index = result.dataset_index - for idx in dataset_indices: - q, gt, d = pending_queries_map.pop(idx) - combined_queries.append(q) - combined_ground_truths.append(gt) - combined_datasets.append(d) + q, gt, d = pending_queries_map.pop(dataset_index) + combined_queries.append(q) + combined_ground_truths.append(gt) + combined_datasets.append(d) # Verify results self.assertEqual(combined_queries, queries_next) @@ -402,32 +404,33 @@ def test_multiple_samples_per_prompt(self, vllm_num_engines: int, num_samples_pe queries_next, ground_truths_next, datasets_next, dataset_indices, vllm_num_engines ) - # Simulate vLLM processing with multiple samples - batch_idx = 0 + # Simulate vLLM processing with multiple samples - processes individual prompts + prompt_count = 0 while not param_prompt_Q.empty(): request = param_prompt_Q.get() mock_result = self.create_mock_result(request.dataset_index, request.training_step, num_samples_per_prompt) inference_results_Q.put(mock_result) - batch_idx += 1 + prompt_count += 1 - # Simulate streaming 
accumulation + # Simulate accumulation combined_responses = [] combined_queries = [] combined_ground_truths = [] combined_datasets = [] - for _ in range(vllm_num_engines): + # Process all individual results + for _ in range(prompt_count): result = inference_results_Q.get() - dataset_indices = result.dataset_index + dataset_index = result.dataset_index + # Get query for this index batch_queries = [] batch_ground_truths = [] batch_datasets = [] - for idx in dataset_indices: - q, gt, d = pending_queries_map.pop(idx) - batch_queries.append(q) - batch_ground_truths.append(gt) - batch_datasets.append(d) + q, gt, d = pending_queries_map.pop(dataset_index) + batch_queries.append(q) + batch_ground_truths.append(gt) + batch_datasets.append(d) combined_responses.extend(result.responses) combined_queries.extend(batch_queries) @@ -521,12 +524,13 @@ def test_out_of_order_processing(self): requests.append(param_prompt_Q.get()) # Put results back in REVERSE order to simulate out-of-order processing + # Create one result per prompt with n responses for request in reversed(requests): mock_result = self.create_mock_result(request.dataset_index, request.training_step, num_samples_per_prompt) inference_results_Q.put(mock_result) # Accumulate results - mock_args = self.create_mock_args(num_engines, num_samples_per_prompt) + mock_args = self.create_mock_args(num_engines, num_samples_per_prompt, num_prompts) # Create a mock generation config with n mock_generation_config = Mock() mock_generation_config.n = num_samples_per_prompt @@ -537,6 +541,7 @@ def test_out_of_order_processing(self): mock_args, training_step=1, generation_config=mock_generation_config, + num_prompts=num_prompts, ) # Verify results work correctly even with out-of-order processing @@ -592,8 +597,9 @@ def test_accumulate_waits_for_all_engines(self): num_engines = 4 num_prompts = 16 - # Setup with results from only 3 engines - inference_results_Q = ray_queue.Queue(maxsize=num_engines * 2) + # Setup with results from only 3 engines (missing one) + # Queue size should accommodate results from engines + inference_results_Q = ray_queue.Queue(maxsize=num_engines * num_prompts) # Track queue for cleanup self._ray_queues.append(inference_results_Q) @@ -604,15 +610,16 @@ def test_accumulate_waits_for_all_engines(self): for i in range(num_prompts): pending_queries_map.insert(i, f"q_{i}", f"t_{i}", f"d_{i}") - # Add results from only 3 engines (missing one) - for engine_id in range(3): - indices = list(range(engine_id * 4, (engine_id + 1) * 4)) - mock_result = self.create_mock_result(indices, 1) + # Add individual results (one per prompt) but missing some + # accumulate_inference_batches now expects individual results (num_prompts * n) + # Add results for only 12 prompts (missing 4) + for i in range(12): # Only 12 prompts, missing 4 + mock_result = self.create_mock_result(i, 1) inference_results_Q.put(mock_result) - mock_args = self.create_mock_args(num_engines) + mock_args = self.create_mock_args(num_engines, num_prompts=num_prompts) - # Test that accumulate blocks when missing an engine + # Test that accumulate blocks when missing results from the 4th engine import threading completed = threading.Event() @@ -629,6 +636,7 @@ def run_accumulate(): mock_args, training_step=1, generation_config=mock_generation_config, + num_prompts=num_prompts, ) completed.set() except Exception: @@ -637,14 +645,14 @@ def run_accumulate(): thread = threading.Thread(target=run_accumulate, daemon=True) thread.start() - # Should timeout waiting for 4th engine + # Should timeout 
waiting for missing results self.assertFalse(completed.wait(timeout=1.0)) self.assertTrue(thread.is_alive()) - # Queue should be empty after consuming 3 results + # Queue should be empty after consuming 12 results self.assertEqual(inference_results_Q.qsize(), 0) - # Some entries should be removed - self.assertLess(len(pending_queries_map), num_prompts) + # 12 entries should be removed from the map (4 still pending) + self.assertEqual(len(pending_queries_map), 4) class TestStreamingAccumulation(TestGrpoFastBase): @@ -695,19 +703,19 @@ def test_more_engines_than_queries(self): while not param_prompt_Q.empty(): request = param_prompt_Q.get() self.assertIsInstance(request, PromptRequest) - self.assertEqual(len(request.prompts), 1, "Each batch should have exactly 1 prompt") - batch_sizes.append(len(request.prompts)) + self.assertIsNotNone(request.prompt, "Each request should have a prompt") + batch_sizes.append(1) # Each PromptRequest contains exactly 1 prompt # All queries should be in the pending map self.assertEqual(len(pending_queries_map), num_queries) def test_uneven_distribution_no_empty_batches(self): - """Test that uneven query distribution doesn't create empty batches.""" + """Test that split_and_insert_batch creates one PromptRequest per query.""" num_engines = 3 - num_queries = 7 # 7/3 = ceil(2.33) = 3, so distribution should be [3, 3, 1] + num_queries = 7 queries, ground_truths, datasets, indices = self.create_test_data(num_queries) - param_prompt_Q = ray_queue.Queue(maxsize=num_engines * 2) + param_prompt_Q = ray_queue.Queue(maxsize=num_queries * 2) pending_queries_map = grpo_fast.PendingQueriesMap() # Track queue for cleanup @@ -735,24 +743,15 @@ def test_uneven_distribution_no_empty_batches(self): args=mock_args, ) - # Verify all batches have content and check distribution - batch_sizes = [] + # Verify we get one PromptRequest per query + request_count = 0 while not param_prompt_Q.empty(): request = param_prompt_Q.get() - self.assertGreater(len(request.prompts), 0, "Found empty batch in queue!") - batch_sizes.append(len(request.prompts)) - - # Check the expected distribution - self.assertEqual(sum(batch_sizes), num_queries, "Total queries should match") - self.assertEqual(len(batch_sizes), num_engines, "Should have one batch per engine") + self.assertIsNotNone(request.prompt, "Each request should have a prompt") + request_count += 1 - # The distribution should be [3, 3, 1] for 7 queries across 3 engines with ceiling division - expected_distribution = [3, 3, 1] - self.assertEqual( - sorted(batch_sizes, reverse=True), - expected_distribution, - f"Expected distribution {expected_distribution}, got {sorted(batch_sizes, reverse=True)}", - ) + # Should have exactly num_queries PromptRequests + self.assertEqual(request_count, num_queries, f"Should have {num_queries} PromptRequests") def test_streaming_accumulation_basic(self): """Test basic streaming accumulation with in-order results.""" @@ -763,7 +762,8 @@ def test_streaming_accumulation_basic(self): queries, ground_truths, datasets, indices = self.create_test_data(num_prompts) # Create queues and maps - inference_results_Q = ray_queue.Queue(maxsize=num_engines * 2) + # Queue size should accommodate results from engines + inference_results_Q = ray_queue.Queue(maxsize=num_engines * num_prompts) pending_queries_map = grpo_fast.PendingQueriesMap() # Track queue for cleanup @@ -773,45 +773,40 @@ def test_streaming_accumulation_basic(self): for i in range(num_prompts): pending_queries_map.insert(i, queries[i], ground_truths[i], 
datasets[i]) - # Create mock results with batch indices - batch_size = num_prompts // num_engines - for batch_idx in range(num_engines): - start = batch_idx * batch_size - end = start + batch_size - mock_result = self.create_mock_result(list(range(start, end)), training_step=1) + # Create mock results - one per prompt + for i in range(num_prompts): + mock_result = self.create_mock_result(i, training_step=1) inference_results_Q.put(mock_result) # Simulate streaming accumulation logic results_list = [] queries_list = [] - expected_batches = num_engines + expected_results = num_prompts # Now expecting one result per prompt - while len(results_list) < expected_batches: + while len(results_list) < expected_results: result = inference_results_Q.get() - batch_idx = len(results_list) results_list.append(result) - # Get queries for this batch - dataset_indices = result.dataset_index + # Get query for this prompt + dataset_index = result.dataset_index batch_queries = [] batch_ground_truths = [] batch_datasets = [] - for idx in dataset_indices: - q, gt, d = pending_queries_map.pop(idx) - batch_queries.append(q) - batch_ground_truths.append(gt) - batch_datasets.append(d) + q, gt, d = pending_queries_map.pop(dataset_index) + batch_queries.append(q) + batch_ground_truths.append(gt) + batch_datasets.append(d) queries_list.append((batch_queries, batch_ground_truths, batch_datasets)) - # Verify all batches processed - self.assertEqual(len(results_list), expected_batches) + # Verify all results processed + self.assertEqual(len(results_list), expected_results) self.assertEqual(len(pending_queries_map), 0) # Combine in order combined_queries = [] - for i in range(num_engines): + for i in range(num_prompts): q, _, _ = queries_list[i] combined_queries.extend(q) @@ -828,7 +823,8 @@ def test_streaming_with_multiple_samples(self): queries, ground_truths, datasets, indices = self.create_test_data(num_prompts) # Create queues and maps - inference_results_Q = ray_queue.Queue(maxsize=num_engines * 2) + # Queue size should accommodate results from engines + inference_results_Q = ray_queue.Queue(maxsize=num_engines * num_prompts) pending_queries_map = grpo_fast.PendingQueriesMap() # Track queue for cleanup @@ -839,15 +835,9 @@ def test_streaming_with_multiple_samples(self): for _ in range(num_samples): pending_queries_map.insert(i, queries[i], ground_truths[i], datasets[i]) - # Create results with multiple samples per prompt - batch_size = num_prompts // num_engines - for batch_idx in range(num_engines): - start = batch_idx * batch_size - end = start + batch_size - dataset_indices = list(range(start, end)) - - # Create result with num_samples responses per prompt - mock_result = self.create_mock_result(dataset_indices, training_step=1, num_samples_per_prompt=num_samples) + # Create results - one per prompt with multiple samples + for i in range(num_prompts): + mock_result = self.create_mock_result(i, training_step=1, num_samples_per_prompt=num_samples) inference_results_Q.put(mock_result) # Process results @@ -855,22 +845,125 @@ def test_streaming_with_multiple_samples(self): while not inference_results_Q.empty(): result = inference_results_Q.get() - # Verify number of responses matches num_samples * num_prompts_in_batch - batch_prompts = len(result.dataset_index) + batch_prompts = 1 # Each result is for a single prompt now expected_responses = batch_prompts * num_samples self.assertEqual(len(result.responses), expected_responses) total_responses += len(result.responses) # Clean up pending_queries_map - for idx in 
result.dataset_index: - for _ in range(num_samples): - if idx in pending_queries_map: - pending_queries_map.pop(idx) + idx = result.dataset_index + for _ in range(num_samples): + if idx in pending_queries_map: + pending_queries_map.pop(idx) # Verify total responses self.assertEqual(total_responses, num_prompts * num_samples) self.assertEqual(len(pending_queries_map), 0) + def test_vllm_tool_processing_completes(self): + """Test that tool processing completes without hanging using actual tool settings.""" + # Check if CUDA is available + if not torch.cuda.is_available(): + self.skipTest("CUDA is not available, skipping test") + + # Create actual tools matching the script + tools = { + "": tool_utils.CodeExecutionTool( + start_str="", + end_str="", + api_endpoint="https://open-instruct-tool-server-10554368204.us-central1.run.app/execute", + ), + "": tool_utils.SearchTool( + start_str="", + end_str="", + api_endpoint="http://saturn-cs-aus-232.reviz.ai2.in:44177/search", + ), + } + max_tool_calls = {"": 5, "": 5} + + # Create actual ActorManager via ray.remote + actor_manager = ray.remote(vllm_utils3.ActorManager).remote(should_stop=False) + + # Create queues + prompt_queue = queue.Queue() + results_queue = queue.Queue() + eval_results_queue = queue.Queue() + + # Create LLMRayActor + model_name = "EleutherAI/pythia-14m" # Small model for testing + actor = vllm_utils3.LLMRayActor( + model_name_or_path=model_name, + actor_id=0, + prompt_queue=prompt_queue, + results_queue=results_queue, + eval_results_queue=eval_results_queue, + actor_manager=actor_manager, + tools=tools, + max_tool_calls=max_tool_calls, + inference_batch_size=8, + vllm_kwargs={"gpu_memory_utilization": 0.3, "max_model_len": 512, "enable_prefix_caching": True}, + ) + + tokenizer = actor.llm_engine.tokenizer + + # Create test prompts that will trigger tools + test_prompts = [ + "Write code to print hello: print('hello')", + "Search for Python tutorials: Python tutorial beginner", + "Calculate 2+2: print(2+2)", + "Find vLLM documentation: vLLM documentation", + "Write a simple function: def greet(): return 'hi'", + "Search machine learning: machine learning basics", + "Debug this code: x = 1; print(x)", + "Look up PyTorch: PyTorch tutorial", + ] + + # Create sampling params once + sampling_params = SamplingParams( + temperature=1.0, top_p=1.0, n=1, max_tokens=50, stop=["", "", ""] + ) + + # Add all requests to queue + for i in range(16): + prompt_text = test_prompts[i % len(test_prompts)] + prompt_ids = tokenizer.encode(prompt_text, add_special_tokens=True) + + request = PromptRequest( + prompt=prompt_ids, generation_config=sampling_params, training_step=0, dataset_index=i, is_eval=False + ) + prompt_queue.put(request) + + # Set a timeout and process + start_time = time.time() + + # Process requests - this is what we're testing! 
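+        # process_from_queue should drain the queued PromptRequests, run any tool calls, and push GenerationResults.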
+        num_processed = actor.process_from_queue(timeout=30.0)
+
+        elapsed = time.time() - start_time
+
+        # Check results
+        results_received = []
+        while not results_queue.empty():
+            try:
+                result = results_queue.get_nowait()
+                results_received.append(result)
+            except queue.Empty:
+                break
+
+        # Verify we didn't hang
+        self.assertLess(elapsed, 60, "Should complete in less than 60 seconds")
+
+        # Verify we got results
+        self.assertGreater(num_processed, 0, "Should have processed some requests")
+        self.assertGreater(len(results_received), 0, "Should have received some results")
+
+        # Verify tool processing worked
+        for result in results_received:
+            self.assertIsNotNone(result.responses)
+            self.assertIsNotNone(result.request_info)
+            self.assertIsNotNone(result.request_info.tool_calleds)
+            # Check that at least some tools were called
+
 
 class TestShufflingIterator(unittest.TestCase):
     """Test ShufflingIterator state preservation functionality."""
diff --git a/open_instruct/vllm_utils3.py b/open_instruct/vllm_utils3.py
index bb427f391..e8d9d878e 100644
--- a/open_instruct/vllm_utils3.py
+++ b/open_instruct/vllm_utils3.py
@@ -43,7 +43,7 @@ from vllm.v1.core import kv_cache_utils
 
 from open_instruct import logger_utils
-from open_instruct.queue_types import GenerationResult, RequestInfo, TokenStatistics
+from open_instruct.queue_types import GenerationResult, PromptRequest, RequestInfo, TokenStatistics
 from open_instruct.tool_utils.tool_vllm import MaxCallsExceededTool, Tool
 from open_instruct.utils import ray_get_with_progress
@@ -201,37 +201,31 @@ def _process_outputs_with_tools(
 
 def _finalize_outputs(outputs, tracking, dataset_index, tools, token_statistics=None, start_time=None):
-    """Prepare final outputs based on whether tools were used."""
+    """Prepare final outputs with unified approach for all requests."""
+    outputs.sort(key=lambda x: (x.request_id.split("-")[0], int(x.request_id.split("-")[1])))
+
     if not tools:
-        outputs.sort(key=lambda x: int(x.request_id.split("_")[-1]))
         return _process_outputs(
             outputs, dataset_index=dataset_index, token_statistics=token_statistics, start_time=start_time
         )
 
     # Tool mode: add metadata and merge completions
-    for req_id in tracking["masks"]:
+    for output in outputs:
+        req_id = output.request_id
+        if req_id not in tracking["masks"]:
+            # If the request ID is not in masks, it means it didn't go through tool processing.
+            continue
         assert req_id in tracking["concat_outputs"], f"req_id {req_id} not in concat_outputs!"
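+        # tracking["concat_outputs"] holds the accumulated completion for this
+        # sub-request (generation plus tool output); attach the tool metadata to it
+        # before swapping it into the engine output below.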
-        output = tracking["concat_outputs"][req_id].outputs[0]
-        setattr(output, "mask", tracking["masks"][req_id])
-        setattr(output, "num_calls", tracking["num_calls"][req_id])
-        setattr(output, "timeout", tracking["timeout"][req_id])
-        setattr(output, "tool_error", tracking["tool_error"][req_id])
-        setattr(output, "tool_output", tracking["tool_output"][req_id])
-        setattr(output, "tool_runtime", tracking["tool_runtime"][req_id])
-        setattr(output, "tool_called", tracking["tool_called"][req_id])
-
-    # Merge n completions into the same outputs
-    merged_outputs = {}
-    for req_id in tracking["concat_outputs"]:
-        real_req_id, _ = req_id.split("-")
-        if real_req_id not in merged_outputs:
-            merged_outputs[real_req_id] = tracking["concat_outputs"][req_id]
-        else:
-            merged_outputs[real_req_id].outputs.append(tracking["concat_outputs"][req_id].outputs[0])
-
-    final_outputs = sorted(
-        merged_outputs.values(), key=lambda x: (int(x.request_id.split("-")[0]), int(x.request_id.split("-")[1]))
-    )
+        tool_output = tracking["concat_outputs"][req_id].outputs[0]
+        setattr(tool_output, "mask", tracking["masks"][req_id])
+        setattr(tool_output, "num_calls", tracking["num_calls"][req_id])
+        setattr(tool_output, "timeout", tracking["timeout"][req_id])
+        setattr(tool_output, "tool_error", tracking["tool_error"][req_id])
+        setattr(tool_output, "tool_output", tracking["tool_output"][req_id])
+        setattr(tool_output, "tool_runtime", tracking["tool_runtime"][req_id])
+        setattr(tool_output, "tool_called", tracking["tool_called"][req_id])
+        # Replace the output with the tool-processed one
+        output.outputs[0] = tool_output
 
     return _process_outputs_with_tools(
-        final_outputs, dataset_index=dataset_index, token_statistics=token_statistics, start_time=start_time
+        outputs, dataset_index=dataset_index, token_statistics=token_statistics, start_time=start_time
@@ -317,6 +311,36 @@ def init_process_group(
     return pg
 
 
+def add_request(request: PromptRequest, llm_engine: vllm.LLMEngine, tools, request_metadata: dict = None):
+    """Add a request to the LLM engine."""
+    prefix = "eval" if request.is_eval else "train"
+    request_id = f"{prefix}_{request.training_step}_{request.dataset_index}"
+    metadata = {
+        "is_eval": request.is_eval,
+        "dataset_index": request.dataset_index,
+        "training_step": request.training_step,
+        "sampling_params": request.generation_config,
+    }
+
+    tokens_prompt = vllm.TokensPrompt(prompt_token_ids=request.prompt, cache_salt=request_id)
+
+    # We *have* to manually duplicate requests to properly handle tool tracking,
+    # so we always do it to only have one code path.
+    # Create sampling params with n=1 for individual tracking
+    sampling_params = copy.deepcopy(request.generation_config)
+    sampling_params.n = 1
+    metadata["sampling_params"] = sampling_params
+    metadata["generation_config"] = request.generation_config
+    request_metadata[request_id] = metadata
+    for j in range(request.generation_config.n):
+        sub_request_id = f"{request_id}-{j}"
+        if request.generation_config.seed is not None:
+            # We need to seed each sub-request differently to avoid getting the same output.
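+            # (a shared seed across the n sub-requests would make them all sample
+            # identical completions)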
+            sampling_params.seed = request.generation_config.seed + j
+        llm_engine.add_request(sub_request_id, tokens_prompt, sampling_params)
+        request_metadata[sub_request_id] = metadata
+
+
 class LLMRayActor:
     """Ray actor for LLM generation with optional tool support."""
@@ -330,11 +354,14 @@ def __init__(
         results_queue=None,
         eval_results_queue=None,
         actor_manager=None,
+        inference_batch_size: Optional[int] = None,
+        inflight_updates: bool = False,
         **kwargs,
     ):
         self.logger = logger_utils.setup_logger(__name__)
         self.tools = tools or {}
         self.max_tool_calls = max_tool_calls or {}
+        self.inflight_updates = inflight_updates
         self.request_metadata = {}
 
         if self.tools:
@@ -367,150 +394,255 @@ def __init__(
         self.results_queue = results_queue
         self.eval_results_queue = eval_results_queue
         self.actor_manager = actor_manager
+        if inference_batch_size is None:
+            raise ValueError("inference_batch_size must be specified.")
+        self.inference_batch_size = inference_batch_size
+        self.request_metadata = {}
+
+        # For caching should_stop status.
+        self._last_should_stop_update = None
+        self._should_stop_value = None
+        self._should_stop_timeout_s = 5
+
+    def _should_stop(self) -> bool:
+        last_update = self._last_should_stop_update
+        if last_update is None or (time.perf_counter() - last_update) > self._should_stop_timeout_s:
+            should_stop_ref = self.actor_manager.should_stop.remote()
+            ready_refs, _ = ray.wait([should_stop_ref], timeout=0.1)
+            if ready_refs:
+                should_stop = ray.get(ready_refs[0])
+            else:
+                ray.cancel(should_stop_ref)
+                should_stop = False
+            self._last_should_stop_update = time.perf_counter()
+            self._should_stop_value = should_stop
+
+        return self._should_stop_value
+
+    def _insert_result_to_queue(self, result, is_eval: bool):
+        """Insert result into the appropriate queue with error handling."""
+        try:
+            results_queue = self.eval_results_queue if is_eval else self.results_queue
+            results_queue.put(result, timeout=10)
+        except queue.Full:
+            queue_name = "eval" if is_eval else "train"
+            self.logger.warning(f"{queue_name} results queue is full, discarding result.")
+
+    def _maybe_add_new_request(self):
+        """Try to add a new request from the prompt queue if not stopping.
+
+        Returns:
+            bool: True if a request was added, False otherwise.
+        """
+        if not self._should_stop():
+            try:
+                request = self.prompt_queue.get_nowait()
+                self.logger.debug(
+                    f"[_maybe_add_new_request] Added new request during processing: "
+                    f"is_eval={request.is_eval}, dataset_index={request.dataset_index}, "
+                    f"training_step={request.training_step}"
+                )
+                add_request(request, self.llm_engine, self.tools, request_metadata=self.request_metadata)
+                return True
+            except queue.Empty:
+                pass  # No new request available, continue processing
+        else:
+            self.logger.debug("[_maybe_add_new_request] Skipping due to should_stop signal")
+        return False
 
     def process_from_queue(self, timeout: float = 60.0):
         """Run generation loop using LLMEngine directly, with optional tool support.
 
         Returns:
-            int: Number of requests processed (0 or 1)
+            int: Number of requests processed.
         """
-        while True:
-            # Non-blocking check for should_stop using ray.wait
-            should_stop_ref = self.actor_manager.should_stop.remote()
-            ready_refs, _ = ray.wait([should_stop_ref], timeout=0.1)
-            if ready_refs and ray.get(ready_refs[0]):
-                return 0
+        self.logger.info(f"[process_from_queue] Starting with inference_batch_size={self.inference_batch_size}")
+        num_processed = 0
+        overall_start_time = time.perf_counter()
+
+        tracking = _init_tool_tracking() if self.tools else None
+        tokenizer = self.llm_engine.tokenizer if self.tools else None
+
+        collected_outputs = defaultdict(list)
+        if self._should_stop():
+            self.logger.info("[process_from_queue] Early exit due to should_stop signal")
+            return num_processed
+
+        # Initial batch loading
+        batch_load_start = time.perf_counter()
+        initial_requests = 0
+        while initial_requests < self.inference_batch_size:
             try:
                 request = self.prompt_queue.get(timeout=timeout)
+                self.logger.debug(
+                    f"[process_from_queue] Got request from queue: "
+                    f"is_eval={request.is_eval}, dataset_index={request.dataset_index}, "
+                    f"training_step={request.training_step}"
+                )
+                add_request(request, self.llm_engine, self.tools, request_metadata=self.request_metadata)
+                initial_requests += 1
             except queue.Empty:
-                return 0
+                # If we couldn't get a request quickly and have some requests, start processing
+                if self.llm_engine.has_unfinished_requests():
+                    self.logger.debug(
+                        f"[process_from_queue] Queue empty after {initial_requests} requests, starting processing"
+                    )
+                    break
+                # Otherwise continue trying to get more requests
 
-            result = self._process_request(request)
+        batch_load_time = time.perf_counter() - batch_load_start
+        self.logger.info(
+            f"[process_from_queue] Initial batch loaded {initial_requests} requests in {batch_load_time:.2f}s"
+        )
 
-            try:
-                if request.is_eval:
-                    self.eval_results_queue.put(result, timeout=10)
-                else:
-                    self.results_queue.put(result, timeout=10)
-                return 1  # Successfully processed one request
-            except queue.Full:
-                self.logger.warning("Results queue is full, discarding result.")
-                return 0
-
-    def _process_request(self, request):
-        """Unified processing for both tool and non-tool generation."""
-        prompts = request.prompts
-        sampling_params = request.generation_config
-        start_time = request.start_time
-
-        self.logger.info(f"[LLMRayActor] Processing request with {len(prompts)} prompts, tools={bool(self.tools)}")
+        # Main processing loop
+        loop_iteration = 0
+        # Timing accumulators for every 100 iterations
+        step_engine_time_acc = 0.0
+        output_processing_time_acc = 0.0
+        loop_block_start = time.perf_counter()
 
-        if self.tools:
-            # Need n=1 for individual tool tracking
-            sampling_params = copy.deepcopy(sampling_params)
-            original_n = request.generation_config.n
-            sampling_params.n = 1
-            tracking = _init_tool_tracking()
-            tokenizer = self.llm_engine.tokenizer
-        else:
-            original_n = 1
-            tracking = None
-            tokenizer = None
+        while True:
+            loop_iteration += 1
 
-        self._add_initial_requests(prompts, sampling_params, original_n, request.training_step)
+            if self._should_stop() and self.inflight_updates:
+                self.logger.info(
+                    f"[process_from_queue] Stopping due to should_stop signal (inflight_updates=True) after {loop_iteration} iterations"
+                )
+                return num_processed
 
+            step_start = time.perf_counter()
+            outputs = self._step_engine(tracking, tokenizer)
+            step_engine_time = time.perf_counter() - step_start
+            step_engine_time_acc += step_engine_time
+
+            self.logger.debug(
+                f"[process_from_queue] Loop iteration {loop_iteration}: got {len(outputs)} outputs from step_engine"
+            )
+
+            output_start = time.perf_counter()
+            for output in outputs:
+                request_id = output.request_id.split("-")[0]
+                collected_outputs[request_id].append(output)
+                metadata = self.request_metadata[request_id]
+
+                self.logger.debug(
+                    f"[process_from_queue] Collected output for {request_id}: "
+                    f"{len(collected_outputs[request_id])}/{metadata['generation_config'].n} outputs"
+                )
+
+                if len(collected_outputs[request_id]) != metadata["generation_config"].n:
+                    continue
+
+                outputs_to_finalize = collected_outputs[request_id]
+                num_processed += 1
+                self.logger.debug(
+                    f"[process_from_queue] Finalizing outputs for {request_id}, total processed: {num_processed}"
+                )
+
+                result = _finalize_outputs(outputs_to_finalize, tracking, metadata["dataset_index"], self.tools)
+
+                self._insert_result_to_queue(result, metadata["is_eval"])
+                del collected_outputs[request_id]
+                self.request_metadata.pop(request_id, None)
+                for i in range(metadata["generation_config"].n):
+                    self.request_metadata.pop(f"{request_id}-{i}", None)
+            output_processing_time = time.perf_counter() - output_start
+            output_processing_time_acc += output_processing_time
+
+            unfinished = self.llm_engine.has_unfinished_requests()
+            if self._should_stop() and not self.inflight_updates:
+                pending_tool_futures = tracking["pending_tool_futures"] if self.tools else {}
+                if not unfinished and not pending_tool_futures:
+                    total_time = time.perf_counter() - overall_start_time
+                    self.logger.info(
+                        f"[process_from_queue] Stopping: no unfinished requests or pending tools, "
+                        f"processed {num_processed} requests in {loop_iteration} iterations, total_time={total_time:.2f}s"
+                    )
+                    break
+
+            # Log timing summary every 100 iterations
+            if loop_iteration % 100 == 0:
+                loop_block_time = time.perf_counter() - loop_block_start
+                self.logger.info(
+                    f"[process_from_queue] Timing (iters {loop_iteration - 99}-{loop_iteration}): "
+                    f"total={loop_block_time:.2f}s, "
+                    f"step_engine={step_engine_time_acc:.2f}s ({step_engine_time_acc / loop_block_time * 100:.1f}%), "
+                    f"output_processing={output_processing_time_acc:.2f}s ({output_processing_time_acc / loop_block_time * 100:.1f}%), "
+                    f"processed={num_processed}"
+                )
+                # Reset accumulators
+                step_engine_time_acc = 0.0
+                output_processing_time_acc = 0.0
+                loop_block_start = time.perf_counter()
+
+        total_time = time.perf_counter() - overall_start_time
+        self.logger.info(
+            f"[process_from_queue] Completed: processed {num_processed} requests in {total_time:.2f}s (avg {total_time / num_processed:.2f}s per request)"
+            if num_processed > 0
+            else "[process_from_queue] Completed: no requests processed"
+        )
+        return num_processed
+
+    def _step_engine(self, tracking, tokenizer):
+        """Unified processing for both tool and non-tool generation.
+
+        Returns:
+            List of completed outputs.
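+                Each call polls any finished tool futures, advances the engine by one
+                step, and tops the engine back up from the prompt queue.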
+        """
         outputs = []
-        iteration = 0
-        while True:
-            iteration += 1
-
-            # Poll tool futures first (matching ToolUseLLM order)
-            if tracking and tracking.get("pending_tool_futures"):
-                self._poll_tool_futures(tracking, sampling_params, tokenizer)
-
-            # Process engine steps - ONLY if there are unfinished requests (matching ToolUseLLM)
-            if self.llm_engine.has_unfinished_requests():
-                step_outputs = list(self.llm_engine.step())
-                for output in step_outputs:
-                    if output.finished:
-                        result = _handle_output(
-                            output, self.tools, tracking, sampling_params, self.max_tool_calls, self.executor
-                        )
-                        if result is not None:
-                            outputs.append(result)
-
-            # Check termination condition (matching ToolUseLLM exactly)
-            pending_count = len(tracking["pending_tool_futures"]) if tracking else 0
-            if not self.llm_engine.has_unfinished_requests() and pending_count == 0:
-                self.logger.info(f"[LLMRayActor] Terminating after {iteration} iterations with {len(outputs)} outputs")
+        if self.tools and tracking["pending_tool_futures"]:
+            tool_outputs = self._poll_tool_futures(tracking, tokenizer)
+            outputs.extend(tool_outputs)
+            if tool_outputs:
+                self.logger.debug(f"[_step_engine] Got {len(tool_outputs)} outputs from tool futures")
+
+        if self.llm_engine.has_unfinished_requests():
+            num_unfinished = self.llm_engine.get_num_unfinished_requests()
+            self.logger.debug(f"[_step_engine] Stepping engine with {num_unfinished} unfinished requests")
+            step_outputs = list(self.llm_engine.step())
+            self.logger.debug(f"[_step_engine] Engine step returned {len(step_outputs)} outputs")
+
+            for output in step_outputs:
                if output.finished:
+                    sampling_params = self.request_metadata[output.request_id]["sampling_params"]
+                    result = _handle_output(
+                        output, self.tools, tracking, sampling_params, self.max_tool_calls, self.executor
+                    )
+                    if result is not None:
+                        outputs.append(result)
+                        self.logger.debug(f"[_step_engine] Finished output for request {output.request_id}")
+
+        # Try to keep the engine full
+        added_count = 0
+        while self.llm_engine.get_num_unfinished_requests() < self.inference_batch_size:
+            if not self._maybe_add_new_request():
                break
+            added_count += 1
 
-        end_time = time.time()
-        total_prompt_tokens = 0
-        total_generation_tokens = 0
-        earliest_start_time = float("inf")
+        return outputs
 
-        for output in outputs:
-            request_id = output.request_id
-            if request_id in self.request_metadata:
-                metadata = self.request_metadata[request_id]
-                total_prompt_tokens += metadata["prompt_tokens"]
-                earliest_start_time = min(earliest_start_time, metadata["start_time"])
-
-            for completion in output.outputs:
-                total_generation_tokens += len(completion.token_ids)
-
-        generation_time = end_time - earliest_start_time
-
-        for output in outputs:
-            self.request_metadata.pop(output.request_id, None)
-
-        result = _finalize_outputs(
-            outputs,
-            tracking,
-            request.dataset_index,
-            self.tools,
-            token_statistics=TokenStatistics(
-                num_prompt_tokens=total_prompt_tokens,
-                num_response_tokens=total_generation_tokens,
-                generation_time=generation_time,
-            ),
-            start_time=start_time,
-        )
-        return result
-
-    def _add_initial_requests(self, prompts, sampling_params, n_samples, training_step):
-        """Add initial requests to the engine."""
-        for i, prompt in enumerate(prompts):
-            if self.tools:
-                # Create individual requests for each sample when using tools
-                for j in range(n_samples):
-                    request_id = f"{training_step}_{i}-{j}"
-                    self.request_metadata[request_id] = {"start_time": time.time(), "prompt_tokens": len(prompt)}
-                    tokens_prompt = vllm.TokensPrompt(prompt_token_ids=prompt, cache_salt=f"{training_step}_{i}")
-                    self.llm_engine.add_request(request_id, tokens_prompt, sampling_params)
-            else:
-                # Standard request format for non-tool mode
-                request_id = f"batch_{training_step}_{i}"
-                self.request_metadata[request_id] = {"start_time": time.time(), "prompt_tokens": len(prompt)}
-                tokens_prompt = vllm.TokensPrompt(prompt_token_ids=prompt, cache_salt=request_id)
-                self.llm_engine.add_request(request_id, tokens_prompt, sampling_params)
-
-    def _poll_tool_futures(self, tracking, sampling_params, tokenizer):
-        """Poll and handle completed tool executions."""
+
+    def _poll_tool_futures(self, tracking, tokenizer):
+        """Poll and handle completed tool executions.
+
+        Returns:
+            List of completed outputs that can't continue generation.
+        """
         if not self.tools or not tracking["pending_tool_futures"]:
-            return
+            return []
 
         dict_keys_to_delete = []
+        completed_outputs = []
         for req_id, (future, last_o, last_output) in tracking["pending_tool_futures"].items():
            if not future.done():
                continue
 
-            # Tool future is done, process it
-            tool_result = future.result()  # Get the tool result
+            # Tool future is done, process it.
+            tool_result = future.result()
             last_prompt_token_ids = last_output.prompt_token_ids
             last_token_ids = last_o.token_ids
@@ -534,34 +666,47 @@ def _poll_tool_futures(self, tracking, sampling_params, tokenizer):
             can_make_new_request = True
 
             # Edge case 2: clip against per-request max_tokens
-            remaining = sampling_params.max_tokens - len(tracking["masks"][req_id])
-            if remaining <= 0:
-                tool_output_token_ids = []
-            elif len(tool_output_token_ids) > remaining:
-                tool_output_token_ids = tool_output_token_ids[:remaining]
+            req_sampling_params = self.request_metadata[req_id]["sampling_params"]
+            if req_sampling_params:
+                remaining = req_sampling_params.max_tokens - len(tracking["masks"][req_id])
+                if remaining <= 0:
+                    tool_output_token_ids = []
+                elif len(tool_output_token_ids) > remaining:
+                    tool_output_token_ids = tool_output_token_ids[:remaining]
             tracking["concat_outputs"][req_id].outputs[0].token_ids.extend(tool_output_token_ids)
             tracking["masks"][req_id].extend([0] * len(tool_output_token_ids))
-            new_sample_tokens = sampling_params.max_tokens - len(tracking["masks"][req_id])
-            can_make_new_request = can_make_new_request and new_sample_tokens > 0
+
+            if req_sampling_params:
+                new_sample_tokens = req_sampling_params.max_tokens - len(tracking["masks"][req_id])
+                can_make_new_request = can_make_new_request and new_sample_tokens > 0
+            else:
+                new_sample_tokens = 0
+                can_make_new_request = False
 
             if can_make_new_request:
-                new_sampling_params = copy.deepcopy(sampling_params)
+                new_sampling_params = copy.deepcopy(req_sampling_params)
                 new_sampling_params.max_tokens = new_sample_tokens
                 try:
-                    self.llm_engine.add_request(
-                        req_id, vllm.TokensPrompt(prompt_token_ids=prompt_and_tool_output_token), new_sampling_params
-                    )
+                    prompt = vllm.TokensPrompt(prompt_token_ids=prompt_and_tool_output_token, cache_salt=req_id)
+                    self.llm_engine.add_request(req_id, prompt, new_sampling_params)
+                    # Update the sampling params in request_metadata for the restarted request
+                    if req_id in self.request_metadata:
+                        self.request_metadata[req_id]["sampling_params"] = new_sampling_params
                except Exception as e:
                    # Match original ToolUseLLM behavior - just log and continue
                    self.logger.error(f"[_poll_tool_futures] Error adding request {req_id}: {e}")
+                    completed_outputs.append(tracking["concat_outputs"][req_id])
+            else:
+                completed_outputs.append(tracking["concat_outputs"][req_id])
            dict_keys_to_delete.append(req_id)
 
        for req_id in dict_keys_to_delete:
-            if req_id in tracking["pending_tool_futures"]:
-                del tracking["pending_tool_futures"][req_id]
+            tracking["pending_tool_futures"].pop(req_id, None)
+
+        return completed_outputs
 
     def init_process_group(
         self,
@@ -662,6 +807,7 @@ def create_vllm_engines(
     max_model_len: int,
     vllm_gpu_memory_utilization: float = 0.9,
     single_gpu_mode: bool = False,
+    inference_batch_size: Optional[int] = None,
     pg: Optional[ray.util.placement_group] = None,
     vllm_enable_sleep=False,
     tools: Optional[Dict[str, Tool]] = None,
@@ -670,6 +816,7 @@ def create_vllm_engines(
     results_queue=None,
     eval_results_queue=None,
     actor_manager=None,
+    inflight_updates: bool = False,
 ) -> list[LLMRayActor]:
     # Convert max_tool_calls to a dict mapping tool end strings to their limits
     if tools:
@@ -750,6 +897,8 @@ def create_vllm_engines(
                actor_manager=actor_manager,
                tools=tools,
                max_tool_calls=max_tool_calls_dict,
+                inference_batch_size=inference_batch_size,
+                inflight_updates=inflight_updates,
            )
        )
diff --git a/scripts/gantry_run_benchmark.sh b/scripts/gantry_run_benchmark.sh
index 135ae62d9..abcca82cf 100755
--- a/scripts/gantry_run_benchmark.sh
+++ b/scripts/gantry_run_benchmark.sh
@@ -15,7 +15,7 @@ fi
 
 git_hash=$(git rev-parse --short HEAD)
 git_branch=$(git rev-parse --abbrev-ref HEAD)
-model_name_or_path="hamishivi/qwen2_5_openthoughts2" \
+model_name_or_path="hamishivi/qwen2_5_openthoughts2"
 
 gantry run \
    --name open_instruct-benchmark_generators \
diff --git a/scripts/train/debug/large_test_script.sh b/scripts/train/debug/large_test_script.sh
index da349efcc..326a04370 100755
--- a/scripts/train/debug/large_test_script.sh
+++ b/scripts/train/debug/large_test_script.sh
@@ -26,6 +26,7 @@ uv run python mason.py \
    --learning_rate 5e-7 \
    --per_device_train_batch_size 1 \
    --kl_estimator kl3 \
+    --verbose True \
    --dataset_mixer_list saurabh5/rlvr_acecoder_filtered ${num_prompts} saurabh5/open-code-reasoning-rlvr-stdio ${num_prompts} \
    --dataset_mixer_list_splits train \
    --dataset_mixer_eval_list saurabh5/rlvr_acecoder_filtered 8 saurabh5/open-code-reasoning-rlvr-stdio 8 \
@@ -34,7 +35,7 @@ uv run python mason.py \
    --max_prompt_token_length 2048 \
    --response_length 4096 \
    --pack_length 20480 \
-    --model_name_or_path Qwen/Qwen2.5-7B \
+    --model_name_or_path hamishivi/qwen2_5_openthoughts2 \
    --chat_template_name tulu_thinker \
    --stop_strings "" \
    --non_stop_penalty False \
@@ -45,6 +46,8 @@ uv run python mason.py \
    --total_episodes 10_000 \
    --deepspeed_stage 2 \
    --num_learners_per_node 8 \
+    --inflight_updates True \
+    --update_every 1 \
    --vllm_num_engines 8 \
    --vllm_tensor_parallel_size 1 \
    --lr_scheduler_type constant \
diff --git a/scripts/train/debug/single_gpu_on_beaker.sh b/scripts/train/debug/single_gpu_on_beaker.sh
index dfeb6ec33..9db5af0f2 100755
--- a/scripts/train/debug/single_gpu_on_beaker.sh
+++ b/scripts/train/debug/single_gpu_on_beaker.sh
@@ -37,6 +37,7 @@ uv run python mason.py \
    --apply_r1_style_format_reward \
    --apply_verifiable_reward true \
    --temperature 0.7 \
+    --inflight_updates True \
    --ground_truths_key ground_truth \
    --chat_template_name r1_simple_chat_postpend_think \
    --learning_rate 3e-7 \
@@ -49,6 +50,7 @@ uv run python mason.py \
    --beta 0.01 \
    --seed 3 \
    --local_eval_every 1 \
+    --vllm_enable_prefix_caching True \
    --vllm_sync_backend gloo \
    --vllm_gpu_memory_utilization 0.3 \
    --save_traces \