diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 996c91cb2..3adf901b6 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -873,7 +873,7 @@ def broadcast_to_vllm(self): shape = param.shape if self.args.deepspeed_stage != 3 else param.ds_shape refs = [ engine.update_weight.remote( - name, dtype=param.dtype, shape=shape, empty_cache=count == num_params + name, dtype=str(param.dtype), shape=shape, empty_cache=count == num_params ) for engine in self.vllm_engines ] @@ -887,7 +887,7 @@ def broadcast_to_vllm(self): shape = param.shape if self.args.deepspeed_stage != 3 else param.ds_shape refs = [ engine.update_weight.remote( - name, dtype=param.dtype, shape=shape, empty_cache=count == num_params + name, dtype=str(param.dtype), shape=shape, empty_cache=count == num_params ) for engine in self.vllm_engines ] diff --git a/open_instruct/test_vllm_utils3.py b/open_instruct/test_vllm_utils3.py index 93b4096cc..236ab05e8 100644 --- a/open_instruct/test_vllm_utils3.py +++ b/open_instruct/test_vllm_utils3.py @@ -63,7 +63,7 @@ def create_mock_logprobs(token_ids): "dataset_index": 43039, "epoch_number": 0, "training_step": 1, - "prompt_tokens": 10, + "prompt_token_ids": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], "start_time": 1000.0, } } @@ -75,7 +75,6 @@ def create_mock_logprobs(token_ids): result, is_eval = process_completed_request( request_id="train_1_43039", outs=[mock_request_output], - tracking={}, # Not used for this test current_time=1001.0, tools=tools, request_metadata=request_metadata, @@ -134,7 +133,7 @@ def create_mock_logprobs(token_ids): "dataset_index": 200, "epoch_number": 0, "training_step": 2, - "prompt_tokens": 5, + "prompt_token_ids": [1, 2, 3, 4, 5], "start_time": 2000.0, } } @@ -143,7 +142,6 @@ def create_mock_logprobs(token_ids): result, is_eval = process_completed_request( request_id="eval_2_200", outs=[mock_request_output], - tracking={}, # Not used for this test current_time=2000.5, tools=None, request_metadata=request_metadata, diff --git a/open_instruct/vllm_utils3.py b/open_instruct/vllm_utils3.py index 18b3726a1..c16e725c0 100644 --- a/open_instruct/vllm_utils3.py +++ b/open_instruct/vllm_utils3.py @@ -16,7 +16,6 @@ """This file is copied from https://github.com/OpenRLHF/OpenRLHF""" import asyncio -import dataclasses import os import queue import sys @@ -26,7 +25,7 @@ from collections import defaultdict from concurrent import futures from datetime import timedelta -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union import ray import torch @@ -45,7 +44,6 @@ default_pg_timeout, rendezvous, ) -from vllm.v1 import kv_cache_interface from vllm.v1.core import kv_cache_utils from open_instruct import logger_utils @@ -55,7 +53,10 @@ logger = logger_utils.setup_logger(__name__) -WEIGHT_UPDATE_SLEEP_INTERVAL_S = 0.1 +NUM_PREFETCH_WORKERS = 2 +NUM_TOOL_WORKERS = 20 +DRAIN_ACTIVE_TASKS_SLEEP_S = 1 +SHOULD_STOP_TIMEOUT_S = 0.1 def assert_threaded_actor(instance): @@ -83,6 +84,159 @@ def assert_threaded_actor(instance): return +def _truncate_tool_output_tokens( + tool_output_token_ids: List[int], + current_prompt_token_ids: List[int], + accumulated_tokens: List[int], + max_model_len: int, + max_tokens: int, + current_mask_len: int, +) -> Tuple[List[int], int, List[int]]: + prompt_and_tool_output = current_prompt_token_ids + accumulated_tokens + tool_output_token_ids + excess = len(prompt_and_tool_output) - max_model_len + if excess > 0: + tool_output_token_ids = 
tool_output_token_ids[:-excess] + + remaining = max_tokens - current_mask_len + if remaining <= 0: + return [], excess, prompt_and_tool_output + elif len(tool_output_token_ids) > remaining: + return tool_output_token_ids[:remaining], excess, prompt_and_tool_output + + return tool_output_token_ids, excess, prompt_and_tool_output + + +async def process_request_async( + actor: "LLMRayActor", + sub_request_id: str, + base_request_id: str, + prompt: vllm.TokensPrompt, + sampling_params: vllm.SamplingParams, +): + """Process a single async request with tool support, awaiting tools inline.""" + accumulated_tokens = [] + accumulated_logprobs = [] + masks = [] + num_calls = 0 + timeout = False + tool_error = "" + tool_output = "" + tool_runtime = 0.0 + tool_called = False + + current_prompt = prompt + current_prompt_token_ids = actor.request_metadata[base_request_id]["prompt_token_ids"] + current_sampling_params = sampling_params.clone() + final_prompt_token_ids = None + iteration = 0 + + while True: + iteration_request_id = f"{sub_request_id}_iter{iteration}" + outputs = [ + o + async for o in actor.llm_engine.generate(current_prompt, current_sampling_params, iteration_request_id) + if o.finished + ] + assert len(outputs) == 1, f"Expected exactly 1 output, got {len(outputs)} for request {iteration_request_id}" + request_output = outputs[0] + iteration += 1 + output = request_output.outputs[0] + + if final_prompt_token_ids is None: + final_prompt_token_ids = request_output.prompt_token_ids + + accumulated_tokens.extend(output.token_ids) + accumulated_logprobs.extend(output.logprobs) + masks.extend([1] * len(output.token_ids)) + + if not actor.tools or not actor.max_tool_calls: + break + + triggered_tool, stop_str = get_triggered_tool( + output.text, actor.tools, actor.max_tool_calls, num_calls, sampling_params + ) + if triggered_tool is None: + break + + assert actor.executor is not None, f"executor is None for request {sub_request_id}" + + loop = asyncio.get_running_loop() + tool_result = await loop.run_in_executor(actor.executor, triggered_tool, output.text) + + tool_called = True + num_calls += 1 + timeout = timeout or tool_result.timeout + tool_error += "" if tool_result.error is None else tool_result.error + tool_output += tool_result.output + tool_runtime += tool_result.runtime + + tool_output_token_ids = actor.llm_engine.tokenizer.encode( + "\n" + tool_result.output + "\n", add_special_tokens=False + ) + + tool_output_token_ids, excess, prompt_and_tool_output = _truncate_tool_output_tokens( + tool_output_token_ids, + current_prompt_token_ids, + accumulated_tokens, + actor.llm_engine.model_config.max_model_len, + sampling_params.max_tokens, + len(masks), + ) + + accumulated_tokens.extend(tool_output_token_ids) + accumulated_logprobs.extend( + [{token_id: types.SimpleNamespace(logprob=0.0)} for token_id in tool_output_token_ids] + ) + masks.extend([0] * len(tool_output_token_ids)) + + new_sample_tokens = sampling_params.max_tokens - len(masks) + if excess > 0 or new_sample_tokens <= 0: + break + + current_prompt = vllm.TokensPrompt(prompt_token_ids=prompt_and_tool_output, cache_salt=base_request_id) + current_prompt_token_ids = prompt_and_tool_output + final_prompt_token_ids = prompt_and_tool_output + current_sampling_params = sampling_params.clone() + current_sampling_params.max_tokens = new_sample_tokens + + complete_output = vllm.CompletionOutput( + index=split_request_id(sub_request_id)["request_index"], + text="", + token_ids=accumulated_tokens, + 
cumulative_logprob=output.cumulative_logprob, + logprobs=accumulated_logprobs, + finish_reason=output.finish_reason, + stop_reason=output.stop_reason, + ) + + if actor.tools: + setattr(complete_output, "mask", masks) + setattr(complete_output, "num_calls", num_calls) + setattr(complete_output, "timeout", timeout) + setattr(complete_output, "tool_error", tool_error) + setattr(complete_output, "tool_output", tool_output) + setattr(complete_output, "tool_runtime", tool_runtime) + setattr(complete_output, "tool_called", tool_called) + + actor.active_tasks.pop(sub_request_id, None) + + actor.completion_queue.put( + { + "base_request_id": base_request_id, + "expected_n": actor.request_metadata[base_request_id]["original_sampling_params"].n, + "request_output": vllm.RequestOutput( + request_id=sub_request_id, + prompt=request_output.prompt, + prompt_token_ids=final_prompt_token_ids, + prompt_logprobs=request_output.prompt_logprobs, + outputs=[complete_output], + finished=True, + ), + "tools": actor.tools, + } + ) + + # Edited from: https://github.com/OpenRLHF/OpenRLHF/pull/971/files # Turns out Ray doesnt necessarily place bundles together, # so this function is used to get the bundle indices of a placement group @@ -101,36 +255,22 @@ def get_bundle_indices_list(placement_group: ray.util.placement_group) -> List[i return flattened_bundle_indices -def _init_tool_tracking(): - """Initialize tracking variables for tool mode.""" - return { - "num_calls": defaultdict(int), - "timeout": defaultdict(bool), - "tool_error": defaultdict(str), - "tool_output": defaultdict(str), - "tool_runtime": defaultdict(float), - "tool_called": defaultdict(bool), - "concat_outputs": {}, - "masks": defaultdict(list), - "pending_tool_futures": {}, - } - - def make_request_id(request: PromptRequest) -> str: """Generate a unique tracking key for a request.""" prefix = "eval" if request.is_eval else "train" return f"{prefix}_{request.training_step}_{request.dataset_index}" -def _extract_base_request_id(full_request_id: str) -> str: - """Extract base request ID by removing the sample suffix. +def split_request_id(full_request_id: str) -> dict: + """Split request ID into base ID and request index. - >>> _extract_base_request_id("train_1_43039_0") - 'train_1_43039' - >>> _extract_base_request_id("eval_5_12345_2") - 'eval_5_12345' + >>> split_request_id("train_1_43039_0") + {'base_id': 'train_1_43039', 'request_index': 0} + >>> split_request_id("eval_5_12345_2") + {'base_id': 'eval_5_12345', 'request_index': 2} """ - return "_".join(full_request_id.split("_")[:-1]) + parts = full_request_id.split("_") + return {"base_id": "_".join(parts[:-1]), "request_index": int(parts[-1])} def get_triggered_tool( @@ -161,46 +301,12 @@ def get_triggered_tool( return None, None -def _handle_output(output, tools, tracking, sampling_params, max_tool_calls, executor): - """ - Handle a finished output. Returns the output if it should be added to results, - or None if it's being held for tool processing. - - This is a free function to keep the processing logic separate from the actor state. 
- """ - if not tools: - return output - - assert len(output.outputs) <= 1, f"{len(output.outputs)=}" # In tool mode, sampling_params.n == 1 - o = output.outputs[0] - - if output.request_id in tracking["concat_outputs"]: - stored = tracking["concat_outputs"][output.request_id].outputs[0] - stored.token_ids.extend(o.token_ids) - stored.logprobs.extend(o.logprobs) - else: - tracking["concat_outputs"][output.request_id] = output - - tracking["masks"][output.request_id].extend([1] * len(o.token_ids)) - - tool, stop_str = get_triggered_tool( - o.text, tools, max_tool_calls, tracking["num_calls"][output.request_id], sampling_params - ) - if tool is None: - return output - - future = executor.submit(tool, o.text) - tracking["pending_tool_futures"][output.request_id] = (future, o, output) - return None - - -def process_completed_request(request_id, outs, tracking, current_time, tools, request_metadata): +def process_completed_request(request_id, outs, current_time, tools, request_metadata): """Process a completed request with all its samples and return the result. Args: request_id: The base request ID outs: List of vllm.RequestOutput objects for all sub-requests - tracking: Dictionary containing tool tracking information current_time: Current timestamp for performance metrics tools: Dictionary of available tools (may be None or empty) request_metadata: Dictionary containing metadata for all requests @@ -270,7 +376,7 @@ def process_completed_request(request_id, outs, tracking, current_time, tools, r dataset_index=metadata["dataset_index"], epoch_number=metadata["epoch_number"], token_statistics=TokenStatistics( - num_prompt_tokens=metadata["prompt_tokens"], + num_prompt_tokens=len(metadata["prompt_token_ids"]), num_response_tokens=total_generation_tokens, generation_time=current_time - metadata["start_time"], ), @@ -359,38 +465,43 @@ def init_process_group( return pg -def add_request( - request: PromptRequest, - llm_engine: vllm.LLMEngine, - tools: Dict[str, Tool], - request_metadata: dict, - vllm_active_requests: dict, -) -> int: - """Add a request to the LLM engine.""" +def _prefetch_worker(actor: "LLMRayActor") -> None: + while True: + if actor._should_stop() or len(actor.active_tasks) >= actor.inference_batch_size: + time.sleep(DRAIN_ACTIVE_TASKS_SLEEP_S) + continue + + request = actor.prompt_queue.get() + add_request(actor, request) + + +def add_request(actor: "LLMRayActor", request: PromptRequest) -> None: request_id = make_request_id(request) + sampling_params = request.generation_config.clone() sampling_params.n = 1 # Use n=1 for tool processing - request_metadata[request_id] = { + + actor.request_metadata[request_id] = { "is_eval": request.is_eval, "dataset_index": request.dataset_index, "epoch_number": request.epoch_number, "training_step": request.training_step, "sampling_params": sampling_params, "original_sampling_params": request.generation_config, - "prompt_tokens": len(request.prompt), + "prompt_token_ids": list(request.prompt), "start_time": time.perf_counter(), } tokens_prompt = vllm.TokensPrompt(prompt_token_ids=request.prompt, cache_salt=request_id) + for j in range(request.generation_config.n): - sub_sampling_params = sampling_params.clone() # Already has n=1 + sub_sampling_params = sampling_params.clone() if request.generation_config.seed is not None: sub_sampling_params.seed = request.generation_config.seed + j sub_request_id = f"{request_id}_{j}" - llm_engine.add_request(sub_request_id, tokens_prompt, sub_sampling_params) - vllm_active_requests.add(sub_request_id) - - return 
request.generation_config.n + actor.active_tasks[sub_request_id] = asyncio.run_coroutine_threadsafe( + process_request_async(actor, sub_request_id, request_id, tokens_prompt, sub_sampling_params), actor.loop + ) class LLMRayActor: @@ -413,21 +524,12 @@ def __init__( assert_threaded_actor(self) self._init_config(tools, max_tool_calls, inference_batch_size, inflight_updates) self._init_queues(prompt_queue, results_queue, eval_results_queue, actor_manager) + self._init_executor() noset_visible_devices = kwargs.pop("noset_visible_devices") distributed_executor_backend = kwargs.get("distributed_executor_backend") self._setup_gpu_visibility(noset_visible_devices, distributed_executor_backend) - - self._setup_engine_args(args, bundle_indices, kwargs) - - self.tracking = _init_tool_tracking() - self.request_outputs = {} - self._threads_started = threading.Event() - - max_workers = 22 if self.tools else 2 # 2 for background threads + 20 for tool execution if tools enabled - self.executor = futures.ThreadPoolExecutor(max_workers=max_workers) - self._prefetch_future = self.executor.submit(self._prefetch_worker) - self._process_future = self.executor.submit(self._process_from_queue) + self._setup_and_start_async_engine(args, bundle_indices, kwargs) def _init_config( self, @@ -436,23 +538,30 @@ def _init_config( inference_batch_size: Optional[int], inflight_updates: bool, ) -> None: - self.logger = logger_utils.setup_logger(__name__) self.tools = tools or {} self.max_tool_calls = max_tool_calls or {} self.inference_batch_size = inference_batch_size self.inflight_updates = inflight_updates self.request_metadata = {} - self.vllm_active_requests = set() + self.active_tasks = {} + self.request_outputs = {} def _init_queues(self, prompt_queue, results_queue, eval_results_queue, actor_manager) -> None: + self.completion_queue = queue.Queue() self.prompt_queue = prompt_queue self.results_queue = results_queue self.eval_results_queue = eval_results_queue self.actor_manager = actor_manager + # For caching should_stop status. self._last_should_stop_update = float("-inf") self._should_stop_value = False - self._should_stop_timeout_s = 5 + + def _init_executor(self) -> None: + max_workers = NUM_PREFETCH_WORKERS + (NUM_TOOL_WORKERS if self.tools else 0) + self.executor = futures.ThreadPoolExecutor(max_workers=max_workers) + self._prefetch_future = self.executor.submit(_prefetch_worker, self) + self._process_future = self.executor.submit(self.process_from_queue) def _setup_gpu_visibility(self, noset_visible_devices: bool, distributed_executor_backend: str) -> None: # a hack to make the script work. @@ -467,22 +576,38 @@ def _setup_gpu_visibility(self, noset_visible_devices: bool, distributed_executo # RAY_EXPERIMENTAL_NOSET_*_VISIBLE_DEVICES is set. os.environ["CUDA_VISIBLE_DEVICES"] = str(ray.get_gpu_ids()[0]) - def _setup_engine_args(self, args, bundle_indices, kwargs) -> None: + def _setup_and_start_async_engine(self, args, bundle_indices, kwargs) -> None: num_gpus = kwargs.pop("num_gpus") if bundle_indices is not None: os.environ["VLLM_RAY_PER_WORKER_GPUS"] = str(num_gpus) os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join(map(str, bundle_indices)) logger.debug(f"creating LLM with bundle_indices={bundle_indices}") - engine_args = vllm.EngineArgs(*args, **kwargs) - # Log stats causes a crash in the engine at assert outputs.scheduler_stats is not None when we call step() and there is nothing to step. 
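+        # Keep stats logging and cascade attention disabled (cascade attention has
+        # known performance issues: https://github.com/vllm-project/vllm/issues/17652).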
+ engine_args = vllm.AsyncEngineArgs(*args, **kwargs) engine_args.disable_log_stats = True - # Cascade attention has known performance issues: https://github.com/vllm-project/vllm/issues/17652 engine_args.disable_cascade_attn = True - self.llm_engine = vllm.LLMEngine.from_engine_args(engine_args) + init_complete = threading.Event() + self.loop = None + self.llm_engine = None + + async def _init_engine(): + running_loop = asyncio.get_running_loop() + assert running_loop == self.loop, f"Loop mismatch! running={running_loop}, actor.loop={self.loop}" + return vllm.AsyncLLMEngine.from_engine_args(engine_args, start_engine_loop=False) + + def _run_loop(): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + self.llm_engine = self.loop.run_until_complete(_init_engine()) + init_complete.set() + self.loop.run_forever() - def get_model_dims_dict(self): + self.loop_thread = threading.Thread(target=_run_loop, daemon=True) + self.loop_thread.start() + init_complete.wait() + + def get_model_dims_dict(self) -> Dict[str, int]: """Get only the model dimensions as a simple dict without loading weights.""" model_config = self.llm_engine.model_config parallel_config = self.llm_engine.vllm_config.parallel_config @@ -501,9 +626,9 @@ def get_model_dims_dict(self): } def _should_stop(self) -> bool: - if (time.perf_counter() - self._last_should_stop_update) > self._should_stop_timeout_s: + if (time.perf_counter() - self._last_should_stop_update) > SHOULD_STOP_TIMEOUT_S: should_stop_ref = self.actor_manager.should_stop.remote() - ready_refs, _ = ray.wait([should_stop_ref], timeout=0.1) + ready_refs, _ = ray.wait([should_stop_ref], timeout=SHOULD_STOP_TIMEOUT_S) if ready_refs: self._should_stop_value = ray.get(ready_refs[0]) self._last_should_stop_update = time.perf_counter() @@ -511,481 +636,140 @@ def _should_stop(self) -> bool: ray.cancel(should_stop_ref) return self._should_stop_value - def _prefetch_worker(self, sleep_length_s: int = 1): - """Background worker that prefetches requests until we have enough buffered.""" - self._threads_started.set() - while True: - if not self.inflight_updates and self._should_stop(): - time.sleep(sleep_length_s) - continue - current_unfinished = self.llm_engine.get_num_unfinished_requests() - if current_unfinished >= self.inference_batch_size: - time.sleep(sleep_length_s) - continue - try: - request = self.prompt_queue.get(timeout=0.1) - add_request( - request, - self.llm_engine, - self.tools, - request_metadata=self.request_metadata, - vllm_active_requests=self.vllm_active_requests, - ) - except queue.Empty: - continue - - def _insert_result_to_queue(self, result, is_eval: bool): - """Insert result into the appropriate queue with blocking put.""" - results_queue = self.eval_results_queue if is_eval else self.results_queue - results_queue.put(result) + def _accumulate_sub_request(self, sub_request: dict) -> None: + base_request_id = sub_request["base_request_id"] + expected_n = sub_request["expected_n"] - def _process_from_queue(self, timeout: float = 60.0): - """Run generation loop using LLMEngine directly, with optional tool support. + if base_request_id not in self.request_outputs: + self.request_outputs[base_request_id] = { + "outputs": [], + "expected_n": expected_n, + "tools": sub_request["tools"], + } - Runs continuously in a background thread, processing requests from the engine. 
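+        # Buffer this sub-request's output; the request is finalized once all
+        # expected_n samples have arrived.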
+ self.request_outputs[base_request_id]["outputs"].append(sub_request["request_output"]) - Returns: - int: Number of requests processed - """ - total_processed = 0 - iteration_count = 0 + is_complete = len(self.request_outputs[base_request_id]["outputs"]) == expected_n + if is_complete: + self._finalize_completed_request(base_request_id) - while True: - iteration_count += 1 - - # Health check: ensure prefetch worker is alive. This will raise if it has crashed. - if self._prefetch_future.done(): - self._prefetch_future.result() - - self._poll_tool_futures(self.tracking, self.llm_engine.tokenizer) - current_time = time.perf_counter() - if self.llm_engine.has_unfinished_requests(): - for output in [o for o in self.llm_engine.step() if o.finished]: - # Fix the index field for all sub-requests - # When we have n>1, we create sub-requests with IDs like - # train_3_12_0, train_3_12_1, etc. But vLLM creates CompletionOutputs with index=0 - # for all of them (since each sub-request has n=1). We need to fix this. - # Extract the actual index from the sub-request ID - parts = output.request_id.rsplit("_", 1) - assert len(parts) == 2 and parts[1].isdigit(), ( - f"Wrong request id format ({output.request_id}), should be request_id _ sub_request_index" - ) - - # Fix the index on the CompletionOutput - correct_index = int(parts[1]) - output.outputs = [dataclasses.replace(o, index=correct_index) for o in output.outputs] - base_req_id = _extract_base_request_id(output.request_id) - result = _handle_output( - output, - self.tools, - self.tracking, - self.request_metadata[base_req_id]["sampling_params"], - self.max_tool_calls, - self.executor, - ) - - # Result is None when we do more tool processing. - if result is None: - # Request went to tools - remove from vllm_active_requests since it's no longer in vLLM - self.vllm_active_requests.discard(output.request_id) - else: - # Sub-request is done (no more tool calls) - if output.request_id in self.tracking["concat_outputs"]: - complete_output = self.tracking["concat_outputs"][output.request_id].outputs[0] - else: - complete_output = result.outputs[0] - - # Remove from vllm_active_requests BEFORE calling _finalize_sub_request - # to avoid deadlock in _maybe_process_and_insert - self.vllm_active_requests.discard(output.request_id) - total_processed += self._finalize_sub_request( - output.request_id, output, complete_output, current_time - ) - if self.llm_engine.get_num_unfinished_requests() == 0: - time.sleep(1) - - return total_processed - - def _maybe_process_and_insert( - self, - request_id: str, - request_outputs: Dict[str, List[vllm.RequestOutput]], - tracking: Dict[str, Any], - current_time: float, - ) -> int: - """Check if we have N requests for request_id, process them, and insert results in queue. - - Returns: - int: Number of requests processed (0 or 1). - """ - expected_n = self.request_metadata[request_id]["original_sampling_params"].n - - # Check if we have the base request in request_outputs - if request_id not in request_outputs: - return 0 - - available_outputs = request_outputs[request_id].outputs - if len(available_outputs) < expected_n: - return 0 - - needed_ids = [f"{request_id}_{j}" for j in range(expected_n)] - active_sub_requests = [sub_id for sub_id in needed_ids if sub_id in self.vllm_active_requests] - if active_sub_requests: - return 0 - - has_pending_tools = any(sub_id in tracking.get("pending_tool_futures", {}) for sub_id in needed_ids) - if has_pending_tools: - return 0 - - # At this point we have all outputs ready. 
Build ordered outputs for processing. - # First organize available_outputs into a dictionary for O(1) lookup - outputs_by_index = {o.index: o for o in available_outputs if hasattr(o, "index")} - - # Verify we have all required outputs before proceeding - if len(outputs_by_index) != expected_n or any(j not in outputs_by_index for j in range(expected_n)): - logger.warning( - f"Incomplete or malformed outputs for {request_id}. " - f"Expected {expected_n} samples, got indices {sorted(outputs_by_index.keys())}. Skipping." - ) - return 0 - - ordered_outs: List[vllm.RequestOutput] = [] - for j in range(expected_n): - # Create a RequestOutput wrapper for each CompletionOutput - ordered_outs.append( - vllm.RequestOutput( - request_id=f"{request_id}_{j}", - prompt=request_outputs[request_id].prompt, - prompt_token_ids=request_outputs[request_id].prompt_token_ids, - prompt_logprobs=request_outputs[request_id].prompt_logprobs, - outputs=[outputs_by_index[j]], - finished=True, - ) - ) + def _finalize_completed_request(self, base_request_id: str) -> None: + outputs = self.request_outputs[base_request_id]["outputs"] + ordered_outs = sorted(outputs, key=lambda x: split_request_id(x.request_id)["request_index"]) - # Remove the base entry from request_outputs to prevent growth. - request_outputs.pop(request_id, None) + current_time = time.perf_counter() result, is_eval = process_completed_request( - request_id, ordered_outs, tracking, current_time, self.tools, self.request_metadata + base_request_id, + ordered_outs, + current_time, + self.request_outputs[base_request_id]["tools"], + self.request_metadata, ) - self._insert_result_to_queue(result, is_eval=is_eval) - self._cleanup_request_data(request_id, tracking) - return 1 - - def _has_pending_tool_futures_for_request(self, request_id: str, tracking: Dict[str, Any]) -> bool: - """Check if there are any pending tool futures for a given base request ID.""" - if not self.tools or not tracking["pending_tool_futures"]: - return False - - # Check if any pending tool futures belong to this base request - for req_id in tracking["pending_tool_futures"]: - if _extract_base_request_id(req_id) == request_id: - return True - return False - - def _has_active_sub_requests_for_base_id(self, base_request_id: str) -> bool: - """Check if there are any active sub-requests in vLLM for a given base request ID.""" - # Check if any active request IDs belong to our base request - for req_id in self.vllm_active_requests: - if _extract_base_request_id(req_id) == base_request_id: - return True - return False - - def _cleanup_request_data(self, request_id: str, tracking: Dict[str, Any]): - """Clean up metadata and tracking data for a completed request.""" - # Check if there are still pending tool futures for this request - if self._has_pending_tool_futures_for_request(request_id, tracking): - # Don't clean up metadata yet - tool futures still need it - return - - # Check if there are still active sub-requests in vLLM for this base request - if self._has_active_sub_requests_for_base_id(request_id): - # Don't clean up metadata yet - active requests still need it - return - - # Remove request metadata only after both conditions are met: - # 1. No pending tool futures for this request - # 2. 
No active sub-requests in vLLM for this base request - self.request_metadata.pop(request_id, None) - - # Clean up tracking data for all sub-requests of this request - if self.tools: - # Find all sub-request IDs that belong to this base request - sub_request_ids = [ - k for k in tracking["concat_outputs"].keys() if _extract_base_request_id(k) == request_id - ] - - for sub_req_id in sub_request_ids: - # Clean up tracking dictionaries - tracking["concat_outputs"].pop(sub_req_id, None) - tracking["masks"].pop(sub_req_id, None) - tracking["num_calls"].pop(sub_req_id, None) - tracking["timeout"].pop(sub_req_id, None) - tracking["tool_error"].pop(sub_req_id, None) - tracking["tool_output"].pop(sub_req_id, None) - tracking["tool_runtime"].pop(sub_req_id, None) - tracking["tool_called"].pop(sub_req_id, None) - # Note: pending_tool_futures should already be cleaned by _poll_tool_futures - - def _finalize_sub_request(self, sub_request_id, request_output_for_prompts, complete_output, current_time): - """ - Finalize a completed sub-request by moving it to request_outputs and processing if ready. - - Args: - sub_request_id: The sub-request ID (e.g., "train_1_43039_2") - request_output_for_prompts: RequestOutput containing prompt info - complete_output: The CompletionOutput to add - current_time: Current timestamp for processing - - Returns: - Number of processed requests (0 or 1) - """ - base_request_id = _extract_base_request_id(sub_request_id) - - # Extract the sub-request index from the sub_request_id and set it on the CompletionOutput - # This is needed to properly identify which sub-request each output belongs to. - # MUST be done BEFORE adding to request_outputs so that - # _maybe_process_and_insert can find the index field when checking completeness. - if "_" in sub_request_id: - # Extract index from sub_request_id like "train_1_43039_2" -> 2 - parts = sub_request_id.rsplit("_", 1) - if len(parts) == 2 and parts[1].isdigit(): - # Create new CompletionOutput with corrected index - complete_output = dataclasses.replace(complete_output, index=int(parts[1])) - - # If tools are enabled, attach tool metadata to the output - if self.tools: - # Set tool metadata attributes on the output - setattr( - complete_output, - "mask", - self.tracking["masks"].get(sub_request_id, [1] * len(complete_output.token_ids)), - ) - setattr(complete_output, "num_calls", self.tracking["num_calls"].get(sub_request_id, 0)) - setattr(complete_output, "timeout", self.tracking["timeout"].get(sub_request_id, False)) - setattr(complete_output, "tool_error", self.tracking["tool_error"].get(sub_request_id, "")) - setattr(complete_output, "tool_output", self.tracking["tool_output"].get(sub_request_id, "")) - setattr(complete_output, "tool_runtime", self.tracking["tool_runtime"].get(sub_request_id, 0.0)) - setattr(complete_output, "tool_called", self.tracking["tool_called"].get(sub_request_id, False)) - - # Initialize request_outputs entry if needed - if base_request_id not in self.request_outputs: - self.request_outputs[base_request_id] = vllm.RequestOutput( - request_id=base_request_id, - prompt=request_output_for_prompts.prompt, - prompt_token_ids=request_output_for_prompts.prompt_token_ids, - prompt_logprobs=request_output_for_prompts.prompt_logprobs, - outputs=[], - finished=True, - ) - - # Add the completion output (with index field already set if needed) - self.request_outputs[base_request_id].outputs.append(complete_output) - - # Try to process and insert if we have all expected outputs - processed = 
self._maybe_process_and_insert(base_request_id, self.request_outputs, self.tracking, current_time) - - return processed - - def _poll_tool_futures(self, tracking, tokenizer): - """Poll and handle completed tool executions.""" - if not self.tools or not tracking["pending_tool_futures"]: - return [] - - dict_keys_to_delete = [] - completed_outputs = [] - - for req_id, (future, last_o, last_output) in list(tracking["pending_tool_futures"].items()): - if not future.done(): - continue - # Tool future is done, process it - tool_result = future.result() # Get the tool result + self.request_outputs.pop(base_request_id) + self.request_metadata.pop(base_request_id, None) - # Get sampling params from request metadata for this request - base_req_id = _extract_base_request_id(req_id) - sampling_params = self.request_metadata[base_req_id]["sampling_params"] + results_queue = self.eval_results_queue if is_eval else self.results_queue + results_queue.put(result) - last_prompt_token_ids = last_output.prompt_token_ids - last_token_ids = last_o.token_ids - tool_output_token_ids = tokenizer.encode( - "\n" + tool_result.output + "\n", add_special_tokens=False - ) - tracking["timeout"][req_id] = tool_result.timeout - tracking["tool_error"][req_id] += "" if tool_result.error is None else tool_result.error - tracking["tool_output"][req_id] += tool_result.output - tracking["tool_runtime"][req_id] += tool_result.runtime - tracking["tool_called"][req_id] = True - - # Edge case 1: clip against model context length - prompt_and_tool_output_token = last_prompt_token_ids + last_token_ids + tool_output_token_ids - tracking["num_calls"][req_id] += 1 - excess = len(prompt_and_tool_output_token) - self.llm_engine.model_config.max_model_len - if excess > 0: - tool_output_token_ids = tool_output_token_ids[:-excess] - can_make_new_request = False - else: - can_make_new_request = True - - # Edge case 2: clip against per-request max_tokens - remaining = sampling_params.max_tokens - len(tracking["masks"][req_id]) - if remaining <= 0: - tool_output_token_ids = [] - elif len(tool_output_token_ids) > remaining: - tool_output_token_ids = tool_output_token_ids[:remaining] - - # Extend token_ids and logprobs for tool output tokens so lengths stay aligned - concat_out = tracking["concat_outputs"][req_id].outputs[0] - concat_out.token_ids.extend(tool_output_token_ids) - # use placeholder logprobs for new tokens - # TODO: can we do something fancier here, or allow it? 
- concat_out.logprobs.extend([{tid: types.SimpleNamespace(logprob=0.0)} for tid in tool_output_token_ids]) - tracking["masks"][req_id].extend([0] * len(tool_output_token_ids)) - new_sample_tokens = sampling_params.max_tokens - len(tracking["masks"][req_id]) - can_make_new_request = can_make_new_request and new_sample_tokens > 0 - - if can_make_new_request: - new_sampling_params = sampling_params.clone() - new_sampling_params.max_tokens = new_sample_tokens - - try: - self.llm_engine.add_request( - req_id, vllm.TokensPrompt(prompt_token_ids=prompt_and_tool_output_token), new_sampling_params - ) - # Track tool continuation request as active - base_req_id = _extract_base_request_id(req_id) - if base_req_id in self.request_metadata: - self.vllm_active_requests.add(req_id) - - except Exception as e: - # Match original ToolUseLLM behavior - just log and continue - logger.error(f"[_poll_tool_futures] Error adding request {req_id}: {e}") - else: - # Can't make a new request (hit limits), finalize this sub-request - base_req_id = _extract_base_request_id(req_id) - - # Log the state before finalizing - other_pending = [ - other_id - for other_id in tracking["pending_tool_futures"] - if _extract_base_request_id(other_id) == base_req_id and other_id != req_id - ] - logger.info( - f"[_poll_tool_futures] Finalizing {req_id} (can't continue). " - f"Other pending tools for {base_req_id}: {other_pending}" - ) - - # Remove from pending_tool_futures BEFORE finalization to ensure consistent state - # This prevents the cleanup logic from seeing this as a pending tool future - tracking["pending_tool_futures"].pop(req_id, None) - - complete_output = tracking["concat_outputs"][req_id].outputs[0] - current_time = time.perf_counter() - self._finalize_sub_request(req_id, last_output, complete_output, current_time) - # Don't add to dict_keys_to_delete since we already removed it - continue - dict_keys_to_delete.append(req_id) - - # Remove the futures we just processed; do NOT clean up metadata here. 
- for req_id in dict_keys_to_delete: - tracking["pending_tool_futures"].pop(req_id, None) - - return completed_outputs + def process_from_queue(self) -> None: + while True: + sub_request = self.completion_queue.get() + self._accumulate_sub_request(sub_request) def init_process_group( self, - master_address, - master_port, - rank_offset, - world_size, - group_name, - backend, - use_ray=False, - timeout_minutes=120, - ): - return self.llm_engine.collective_rpc( - "init_process_group", - args=(master_address, master_port, rank_offset, world_size, group_name, backend, use_ray, timeout_minutes), + master_address: str, + master_port: int, + rank_offset: int, + world_size: int, + group_name: str, + backend: str, + use_ray: bool = False, + timeout_minutes: int = 120, + ) -> None: + future = asyncio.run_coroutine_threadsafe( + self.llm_engine.collective_rpc( + "init_process_group", + args=( + master_address, + master_port, + rank_offset, + world_size, + group_name, + backend, + use_ray, + timeout_minutes, + ), + ), + self.loop, ) + return future.result(timeout=timeout_minutes * 60) + + def _run_async(self, coro: Awaitable[Any]) -> Any: + future = asyncio.run_coroutine_threadsafe(coro, self.loop) + return future.result() def _prepare_weight_update(self, name: str, dtype: str) -> None: - # First, drain all the requests when appropriate: - while not self.inflight_updates: - pending_tools = len(self.tracking["pending_tool_futures"]) - unfinished = self.llm_engine.get_num_unfinished_requests() - # if a background thread is dead, raise an error. + # Wait for all active requests to complete. + while not self.inflight_updates and len(self.active_tasks) > 0: self.check_background_threads() + time.sleep(DRAIN_ACTIVE_TASKS_SLEEP_S) - if pending_tools == 0 and unfinished == 0: - break - - time.sleep(WEIGHT_UPDATE_SLEEP_INTERVAL_S) - # Then, check that the dtypes match. 
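+        # Check that the incoming dtype string matches the engine's model dtype.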
expected_dtype = str(self.llm_engine.model_config.dtype) - assert str(dtype) == expected_dtype, ( - f"Mismatched dtype for {name}: received {dtype!r}, expected {expected_dtype!r}" - ) + assert dtype == expected_dtype, f"Mismatched dtype for {name}: received {dtype!r}, expected {expected_dtype!r}" def update_weight(self, name: str, dtype: str, shape: Tuple[int, ...], empty_cache: bool = False) -> None: self._prepare_weight_update(name, dtype) - return self.llm_engine.collective_rpc("update_weight", args=(name, dtype, shape, empty_cache)) + return self._run_async(self.llm_engine.collective_rpc("update_weight", args=(name, dtype, shape, empty_cache))) def update_weight_cuda_ipc( self, name: str, dtype: str, shape: Tuple[int, ...], ipc_handles: List[Any], empty_cache: bool = False ) -> None: self._prepare_weight_update(name, dtype) - return self.llm_engine.collective_rpc( - "update_weight_cuda_ipc", args=(name, dtype, shape, ipc_handles, empty_cache) + return self._run_async( + self.llm_engine.collective_rpc( + "update_weight_cuda_ipc", args=(name, dtype, shape, ipc_handles, empty_cache) + ) ) - def reset_prefix_cache(self): - self.llm_engine.reset_prefix_cache() - - def sleep(self, level=1): - self.llm_engine.sleep(level=level) + def reset_prefix_cache(self) -> None: + return self._run_async(self.llm_engine.reset_prefix_cache()) - def wake_up(self, tags: Optional[list[str]] = None): - self.llm_engine.wake_up(tags) - - def ready(self): - self._threads_started.wait(timeout=30) + def ready(self) -> bool: return True - def check_background_threads(self): + def check_background_threads(self) -> None: if self._prefetch_future.done(): self._prefetch_future.result() if self._process_future.done(): self._process_future.result() + active_tasks = list(self.active_tasks.items()) + for task_id, task in active_tasks: + if task.done(): + task.result() + if not self.loop_thread.is_alive(): + raise RuntimeError( + "vLLM engine loop thread has died. Check logs for errors in EngineCore or async engine." 
+ ) - def get_kv_cache_info(self): + def get_kv_cache_info(self) -> int: """Get KV cache max concurrency from the vLLM engine.""" - kv_cache_specs = self.llm_engine.model_executor.get_kv_cache_specs() - kv_cache_spec = kv_cache_specs[0] - - page_size = kv_cache_utils.get_uniform_page_size(kv_cache_spec) + kv_cache_specs = self._run_async(self.llm_engine.collective_rpc("get_kv_cache_spec")) vllm_config = self.llm_engine.vllm_config gpu_memory_utilization = vllm_config.cache_config.gpu_memory_utilization total_gpu_memory = torch.cuda.get_device_properties(0).total_memory available_memory = int(gpu_memory_utilization * total_gpu_memory) - num_blocks = kv_cache_utils.get_num_blocks(vllm_config, len(kv_cache_spec), available_memory, page_size) + kv_cache_groups = kv_cache_utils.get_kv_cache_groups(vllm_config, kv_cache_specs[0]) - per_layer_size = page_size * num_blocks - kv_cache_tensors = [ - kv_cache_interface.KVCacheTensor(size=per_layer_size, shared_by=[layer_name]) - for layer_name in kv_cache_spec - ] - - kv_cache_groups = kv_cache_utils.get_kv_cache_groups(vllm_config, kv_cache_spec) - - kv_cache_config = kv_cache_interface.KVCacheConfig( - num_blocks=num_blocks, kv_cache_tensors=kv_cache_tensors, kv_cache_groups=kv_cache_groups - ) - max_concurrency = kv_cache_utils.get_max_concurrency_for_kv_cache_config( - self.llm_engine.vllm_config, kv_cache_config + kv_cache_config = kv_cache_utils.get_kv_cache_config_from_groups( + vllm_config, kv_cache_groups, kv_cache_specs[0], available_memory ) + max_concurrency = kv_cache_utils.get_max_concurrency_for_kv_cache_config(vllm_config, kv_cache_config) + return int(max_concurrency) @@ -1045,7 +829,10 @@ def create_vllm_engines( max_tool_calls_dict = {} vllm_engines = [] - distributed_executor_backend = "uni" if tensor_parallel_size == 1 else "ray" + if tensor_parallel_size == 1: + distributed_executor_backend = "uni" + else: + distributed_executor_backend = "ray" use_hybrid_engine = pg is not None num_gpus = int(tensor_parallel_size == 1) if use_hybrid_engine and tensor_parallel_size == 1 and single_gpu_mode: @@ -1080,7 +867,6 @@ def create_vllm_engines( num_cpus=num_gpus, num_gpus=num_gpus, scheduling_strategy=scheduling_strategy, - # VLLM v1 multiprocessing is required due to https://github.com/vllm-project/vllm/issues/15349 runtime_env=ray.runtime_env.RuntimeEnv( env_vars={"VLLM_ENABLE_V1_MULTIPROCESSING": "0", "TORCH_CUDA_ARCH_LIST": get_cuda_arch_list()} ),