diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py
index 996c91cb2..3adf901b6 100644
--- a/open_instruct/grpo_fast.py
+++ b/open_instruct/grpo_fast.py
@@ -873,7 +873,7 @@ def broadcast_to_vllm(self):
shape = param.shape if self.args.deepspeed_stage != 3 else param.ds_shape
refs = [
engine.update_weight.remote(
- name, dtype=param.dtype, shape=shape, empty_cache=count == num_params
+ name, dtype=str(param.dtype), shape=shape, empty_cache=count == num_params
)
for engine in self.vllm_engines
]
@@ -887,7 +887,7 @@ def broadcast_to_vllm(self):
shape = param.shape if self.args.deepspeed_stage != 3 else param.ds_shape
refs = [
engine.update_weight.remote(
- name, dtype=param.dtype, shape=shape, empty_cache=count == num_params
+ name, dtype=str(param.dtype), shape=shape, empty_cache=count == num_params
)
for engine in self.vllm_engines
]
diff --git a/open_instruct/test_vllm_utils3.py b/open_instruct/test_vllm_utils3.py
index 93b4096cc..236ab05e8 100644
--- a/open_instruct/test_vllm_utils3.py
+++ b/open_instruct/test_vllm_utils3.py
@@ -63,7 +63,7 @@ def create_mock_logprobs(token_ids):
"dataset_index": 43039,
"epoch_number": 0,
"training_step": 1,
- "prompt_tokens": 10,
+ "prompt_token_ids": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
"start_time": 1000.0,
}
}
@@ -75,7 +75,6 @@ def create_mock_logprobs(token_ids):
result, is_eval = process_completed_request(
request_id="train_1_43039",
outs=[mock_request_output],
- tracking={}, # Not used for this test
current_time=1001.0,
tools=tools,
request_metadata=request_metadata,
@@ -134,7 +133,7 @@ def create_mock_logprobs(token_ids):
"dataset_index": 200,
"epoch_number": 0,
"training_step": 2,
- "prompt_tokens": 5,
+ "prompt_token_ids": [1, 2, 3, 4, 5],
"start_time": 2000.0,
}
}
@@ -143,7 +142,6 @@ def create_mock_logprobs(token_ids):
result, is_eval = process_completed_request(
request_id="eval_2_200",
outs=[mock_request_output],
- tracking={}, # Not used for this test
current_time=2000.5,
tools=None,
request_metadata=request_metadata,
diff --git a/open_instruct/vllm_utils3.py b/open_instruct/vllm_utils3.py
index 18b3726a1..c16e725c0 100644
--- a/open_instruct/vllm_utils3.py
+++ b/open_instruct/vllm_utils3.py
@@ -16,7 +16,6 @@
"""This file is copied from https://github.com/OpenRLHF/OpenRLHF"""
import asyncio
-import dataclasses
import os
import queue
import sys
@@ -26,7 +25,7 @@
from collections import defaultdict
from concurrent import futures
from datetime import timedelta
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union
import ray
import torch
@@ -45,7 +44,6 @@
default_pg_timeout,
rendezvous,
)
-from vllm.v1 import kv_cache_interface
from vllm.v1.core import kv_cache_utils
from open_instruct import logger_utils
@@ -55,7 +53,10 @@
logger = logger_utils.setup_logger(__name__)
-WEIGHT_UPDATE_SLEEP_INTERVAL_S = 0.1
+NUM_PREFETCH_WORKERS = 2
+NUM_TOOL_WORKERS = 20
+DRAIN_ACTIVE_TASKS_SLEEP_S = 1
+SHOULD_STOP_TIMEOUT_S = 0.1
def assert_threaded_actor(instance):
@@ -83,6 +84,159 @@ def assert_threaded_actor(instance):
return
+def _truncate_tool_output_tokens(
+ tool_output_token_ids: List[int],
+ current_prompt_token_ids: List[int],
+ accumulated_tokens: List[int],
+ max_model_len: int,
+ max_tokens: int,
+ current_mask_len: int,
+) -> Tuple[List[int], int, List[int]]:
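+    """Clip tool-output tokens against the model context window and the per-request max_tokens.
+
+    Returns the (possibly truncated) tool output tokens, the excess token count over
+    max_model_len (non-positive if everything fits), and the untruncated concatenation of
+    prompt, accumulated, and tool-output tokens.
+    """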
+ prompt_and_tool_output = current_prompt_token_ids + accumulated_tokens + tool_output_token_ids
+ excess = len(prompt_and_tool_output) - max_model_len
+ if excess > 0:
+ tool_output_token_ids = tool_output_token_ids[:-excess]
+
+ remaining = max_tokens - current_mask_len
+ if remaining <= 0:
+ return [], excess, prompt_and_tool_output
+ elif len(tool_output_token_ids) > remaining:
+ return tool_output_token_ids[:remaining], excess, prompt_and_tool_output
+
+ return tool_output_token_ids, excess, prompt_and_tool_output
+
+
+async def process_request_async(
+ actor: "LLMRayActor",
+ sub_request_id: str,
+ base_request_id: str,
+ prompt: vllm.TokensPrompt,
+ sampling_params: vllm.SamplingParams,
+):
+ """Process a single async request with tool support, awaiting tools inline."""
+ accumulated_tokens = []
+ accumulated_logprobs = []
+ masks = []
+ num_calls = 0
+ timeout = False
+ tool_error = ""
+ tool_output = ""
+ tool_runtime = 0.0
+ tool_called = False
+
+ current_prompt = prompt
+ current_prompt_token_ids = actor.request_metadata[base_request_id]["prompt_token_ids"]
+ current_sampling_params = sampling_params.clone()
+ final_prompt_token_ids = None
+ iteration = 0
+
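+    # Generation loop: sample, run any triggered tool, append its output to the context,
+    # and resample until no tool fires or the token budget is exhausted.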
+ while True:
+ iteration_request_id = f"{sub_request_id}_iter{iteration}"
+ outputs = [
+ o
+ async for o in actor.llm_engine.generate(current_prompt, current_sampling_params, iteration_request_id)
+ if o.finished
+ ]
+ assert len(outputs) == 1, f"Expected exactly 1 output, got {len(outputs)} for request {iteration_request_id}"
+ request_output = outputs[0]
+ iteration += 1
+ output = request_output.outputs[0]
+
+ if final_prompt_token_ids is None:
+ final_prompt_token_ids = request_output.prompt_token_ids
+
+ accumulated_tokens.extend(output.token_ids)
+ accumulated_logprobs.extend(output.logprobs)
+ masks.extend([1] * len(output.token_ids))
+
+ if not actor.tools or not actor.max_tool_calls:
+ break
+
+ triggered_tool, stop_str = get_triggered_tool(
+ output.text, actor.tools, actor.max_tool_calls, num_calls, sampling_params
+ )
+ if triggered_tool is None:
+ break
+
+ assert actor.executor is not None, f"executor is None for request {sub_request_id}"
+
+ loop = asyncio.get_running_loop()
+ tool_result = await loop.run_in_executor(actor.executor, triggered_tool, output.text)
+
+ tool_called = True
+ num_calls += 1
+ timeout = timeout or tool_result.timeout
+ tool_error += "" if tool_result.error is None else tool_result.error
+ tool_output += tool_result.output
+ tool_runtime += tool_result.runtime
+
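+        # Tokenize the tool output so it can be appended to the running context.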
+ tool_output_token_ids = actor.llm_engine.tokenizer.encode(
+            "<output>\n" + tool_result.output + "</output>\n", add_special_tokens=False
+ )
+
+ tool_output_token_ids, excess, prompt_and_tool_output = _truncate_tool_output_tokens(
+ tool_output_token_ids,
+ current_prompt_token_ids,
+ accumulated_tokens,
+ actor.llm_engine.model_config.max_model_len,
+ sampling_params.max_tokens,
+ len(masks),
+ )
+
+ accumulated_tokens.extend(tool_output_token_ids)
+ accumulated_logprobs.extend(
+ [{token_id: types.SimpleNamespace(logprob=0.0)} for token_id in tool_output_token_ids]
+ )
+ masks.extend([0] * len(tool_output_token_ids))
+
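+        # Stop if we clipped against the context window or have no sampling budget left.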
+ new_sample_tokens = sampling_params.max_tokens - len(masks)
+ if excess > 0 or new_sample_tokens <= 0:
+ break
+
+ current_prompt = vllm.TokensPrompt(prompt_token_ids=prompt_and_tool_output, cache_salt=base_request_id)
+ current_prompt_token_ids = prompt_and_tool_output
+ final_prompt_token_ids = prompt_and_tool_output
+ current_sampling_params = sampling_params.clone()
+ current_sampling_params.max_tokens = new_sample_tokens
+
+ complete_output = vllm.CompletionOutput(
+ index=split_request_id(sub_request_id)["request_index"],
+ text="",
+ token_ids=accumulated_tokens,
+ cumulative_logprob=output.cumulative_logprob,
+ logprobs=accumulated_logprobs,
+ finish_reason=output.finish_reason,
+ stop_reason=output.stop_reason,
+ )
+
+ if actor.tools:
+ setattr(complete_output, "mask", masks)
+ setattr(complete_output, "num_calls", num_calls)
+ setattr(complete_output, "timeout", timeout)
+ setattr(complete_output, "tool_error", tool_error)
+ setattr(complete_output, "tool_output", tool_output)
+ setattr(complete_output, "tool_runtime", tool_runtime)
+ setattr(complete_output, "tool_called", tool_called)
+
+ actor.active_tasks.pop(sub_request_id, None)
+
+ actor.completion_queue.put(
+ {
+ "base_request_id": base_request_id,
+ "expected_n": actor.request_metadata[base_request_id]["original_sampling_params"].n,
+ "request_output": vllm.RequestOutput(
+ request_id=sub_request_id,
+ prompt=request_output.prompt,
+ prompt_token_ids=final_prompt_token_ids,
+ prompt_logprobs=request_output.prompt_logprobs,
+ outputs=[complete_output],
+ finished=True,
+ ),
+ "tools": actor.tools,
+ }
+ )
+
+
# Edited from: https://github.com/OpenRLHF/OpenRLHF/pull/971/files
 # Turns out Ray doesn't necessarily place bundles together,
# so this function is used to get the bundle indices of a placement group
@@ -101,36 +255,22 @@ def get_bundle_indices_list(placement_group: ray.util.placement_group) -> List[i
return flattened_bundle_indices
-def _init_tool_tracking():
- """Initialize tracking variables for tool mode."""
- return {
- "num_calls": defaultdict(int),
- "timeout": defaultdict(bool),
- "tool_error": defaultdict(str),
- "tool_output": defaultdict(str),
- "tool_runtime": defaultdict(float),
- "tool_called": defaultdict(bool),
- "concat_outputs": {},
- "masks": defaultdict(list),
- "pending_tool_futures": {},
- }
-
-
def make_request_id(request: PromptRequest) -> str:
"""Generate a unique tracking key for a request."""
prefix = "eval" if request.is_eval else "train"
return f"{prefix}_{request.training_step}_{request.dataset_index}"
-def _extract_base_request_id(full_request_id: str) -> str:
- """Extract base request ID by removing the sample suffix.
+def split_request_id(full_request_id: str) -> dict:
+ """Split request ID into base ID and request index.
- >>> _extract_base_request_id("train_1_43039_0")
- 'train_1_43039'
- >>> _extract_base_request_id("eval_5_12345_2")
- 'eval_5_12345'
+ >>> split_request_id("train_1_43039_0")
+ {'base_id': 'train_1_43039', 'request_index': 0}
+ >>> split_request_id("eval_5_12345_2")
+ {'base_id': 'eval_5_12345', 'request_index': 2}
"""
- return "_".join(full_request_id.split("_")[:-1])
+ parts = full_request_id.split("_")
+ return {"base_id": "_".join(parts[:-1]), "request_index": int(parts[-1])}
def get_triggered_tool(
@@ -161,46 +301,12 @@ def get_triggered_tool(
return None, None
-def _handle_output(output, tools, tracking, sampling_params, max_tool_calls, executor):
- """
- Handle a finished output. Returns the output if it should be added to results,
- or None if it's being held for tool processing.
-
- This is a free function to keep the processing logic separate from the actor state.
- """
- if not tools:
- return output
-
- assert len(output.outputs) <= 1, f"{len(output.outputs)=}" # In tool mode, sampling_params.n == 1
- o = output.outputs[0]
-
- if output.request_id in tracking["concat_outputs"]:
- stored = tracking["concat_outputs"][output.request_id].outputs[0]
- stored.token_ids.extend(o.token_ids)
- stored.logprobs.extend(o.logprobs)
- else:
- tracking["concat_outputs"][output.request_id] = output
-
- tracking["masks"][output.request_id].extend([1] * len(o.token_ids))
-
- tool, stop_str = get_triggered_tool(
- o.text, tools, max_tool_calls, tracking["num_calls"][output.request_id], sampling_params
- )
- if tool is None:
- return output
-
- future = executor.submit(tool, o.text)
- tracking["pending_tool_futures"][output.request_id] = (future, o, output)
- return None
-
-
-def process_completed_request(request_id, outs, tracking, current_time, tools, request_metadata):
+def process_completed_request(request_id, outs, current_time, tools, request_metadata):
"""Process a completed request with all its samples and return the result.
Args:
request_id: The base request ID
outs: List of vllm.RequestOutput objects for all sub-requests
- tracking: Dictionary containing tool tracking information
current_time: Current timestamp for performance metrics
tools: Dictionary of available tools (may be None or empty)
request_metadata: Dictionary containing metadata for all requests
@@ -270,7 +376,7 @@ def process_completed_request(request_id, outs, tracking, current_time, tools, r
dataset_index=metadata["dataset_index"],
epoch_number=metadata["epoch_number"],
token_statistics=TokenStatistics(
- num_prompt_tokens=metadata["prompt_tokens"],
+ num_prompt_tokens=len(metadata["prompt_token_ids"]),
num_response_tokens=total_generation_tokens,
generation_time=current_time - metadata["start_time"],
),
@@ -359,38 +465,43 @@ def init_process_group(
return pg
-def add_request(
- request: PromptRequest,
- llm_engine: vllm.LLMEngine,
- tools: Dict[str, Tool],
- request_metadata: dict,
- vllm_active_requests: dict,
-) -> int:
- """Add a request to the LLM engine."""
+def _prefetch_worker(actor: "LLMRayActor") -> None:
+ while True:
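+        # Backpressure: hold off on new prompts while a stop is requested or the engine
+        # already has a full batch of active tasks.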
+ if actor._should_stop() or len(actor.active_tasks) >= actor.inference_batch_size:
+ time.sleep(DRAIN_ACTIVE_TASKS_SLEEP_S)
+ continue
+
+ request = actor.prompt_queue.get()
+ add_request(actor, request)
+
+
+def add_request(actor: "LLMRayActor", request: PromptRequest) -> None:
request_id = make_request_id(request)
+
sampling_params = request.generation_config.clone()
sampling_params.n = 1 # Use n=1 for tool processing
- request_metadata[request_id] = {
+
+ actor.request_metadata[request_id] = {
"is_eval": request.is_eval,
"dataset_index": request.dataset_index,
"epoch_number": request.epoch_number,
"training_step": request.training_step,
"sampling_params": sampling_params,
"original_sampling_params": request.generation_config,
- "prompt_tokens": len(request.prompt),
+ "prompt_token_ids": list(request.prompt),
"start_time": time.perf_counter(),
}
tokens_prompt = vllm.TokensPrompt(prompt_token_ids=request.prompt, cache_salt=request_id)
+
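+    # Fan out into n sub-requests, each handled as its own coroutine on the engine's event loop.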
for j in range(request.generation_config.n):
- sub_sampling_params = sampling_params.clone() # Already has n=1
+ sub_sampling_params = sampling_params.clone()
if request.generation_config.seed is not None:
sub_sampling_params.seed = request.generation_config.seed + j
sub_request_id = f"{request_id}_{j}"
- llm_engine.add_request(sub_request_id, tokens_prompt, sub_sampling_params)
- vllm_active_requests.add(sub_request_id)
-
- return request.generation_config.n
+ actor.active_tasks[sub_request_id] = asyncio.run_coroutine_threadsafe(
+ process_request_async(actor, sub_request_id, request_id, tokens_prompt, sub_sampling_params), actor.loop
+ )
class LLMRayActor:
@@ -413,21 +524,12 @@ def __init__(
assert_threaded_actor(self)
self._init_config(tools, max_tool_calls, inference_batch_size, inflight_updates)
self._init_queues(prompt_queue, results_queue, eval_results_queue, actor_manager)
+ self._init_executor()
noset_visible_devices = kwargs.pop("noset_visible_devices")
distributed_executor_backend = kwargs.get("distributed_executor_backend")
self._setup_gpu_visibility(noset_visible_devices, distributed_executor_backend)
-
- self._setup_engine_args(args, bundle_indices, kwargs)
-
- self.tracking = _init_tool_tracking()
- self.request_outputs = {}
- self._threads_started = threading.Event()
-
- max_workers = 22 if self.tools else 2 # 2 for background threads + 20 for tool execution if tools enabled
- self.executor = futures.ThreadPoolExecutor(max_workers=max_workers)
- self._prefetch_future = self.executor.submit(self._prefetch_worker)
- self._process_future = self.executor.submit(self._process_from_queue)
+ self._setup_and_start_async_engine(args, bundle_indices, kwargs)
def _init_config(
self,
@@ -436,23 +538,30 @@ def _init_config(
inference_batch_size: Optional[int],
inflight_updates: bool,
) -> None:
- self.logger = logger_utils.setup_logger(__name__)
self.tools = tools or {}
self.max_tool_calls = max_tool_calls or {}
self.inference_batch_size = inference_batch_size
self.inflight_updates = inflight_updates
self.request_metadata = {}
- self.vllm_active_requests = set()
+ self.active_tasks = {}
+ self.request_outputs = {}
def _init_queues(self, prompt_queue, results_queue, eval_results_queue, actor_manager) -> None:
+ self.completion_queue = queue.Queue()
self.prompt_queue = prompt_queue
self.results_queue = results_queue
self.eval_results_queue = eval_results_queue
self.actor_manager = actor_manager
+        # Cache the actor manager's should_stop flag so it is polled at most every SHOULD_STOP_TIMEOUT_S.
self._last_should_stop_update = float("-inf")
self._should_stop_value = False
- self._should_stop_timeout_s = 5
+
+ def _init_executor(self) -> None:
+ max_workers = NUM_PREFETCH_WORKERS + (NUM_TOOL_WORKERS if self.tools else 0)
+ self.executor = futures.ThreadPoolExecutor(max_workers=max_workers)
+ self._prefetch_future = self.executor.submit(_prefetch_worker, self)
+ self._process_future = self.executor.submit(self.process_from_queue)
def _setup_gpu_visibility(self, noset_visible_devices: bool, distributed_executor_backend: str) -> None:
# a hack to make the script work.
@@ -467,22 +576,38 @@ def _setup_gpu_visibility(self, noset_visible_devices: bool, distributed_executo
# RAY_EXPERIMENTAL_NOSET_*_VISIBLE_DEVICES is set.
os.environ["CUDA_VISIBLE_DEVICES"] = str(ray.get_gpu_ids()[0])
- def _setup_engine_args(self, args, bundle_indices, kwargs) -> None:
+ def _setup_and_start_async_engine(self, args, bundle_indices, kwargs) -> None:
num_gpus = kwargs.pop("num_gpus")
if bundle_indices is not None:
os.environ["VLLM_RAY_PER_WORKER_GPUS"] = str(num_gpus)
os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join(map(str, bundle_indices))
logger.debug(f"creating LLM with bundle_indices={bundle_indices}")
- engine_args = vllm.EngineArgs(*args, **kwargs)
- # Log stats causes a crash in the engine at assert outputs.scheduler_stats is not None when we call step() and there is nothing to step.
+ engine_args = vllm.AsyncEngineArgs(*args, **kwargs)
engine_args.disable_log_stats = True
- # Cascade attention has known performance issues: https://github.com/vllm-project/vllm/issues/17652
engine_args.disable_cascade_attn = True
- self.llm_engine = vllm.LLMEngine.from_engine_args(engine_args)
+ init_complete = threading.Event()
+ self.loop = None
+ self.llm_engine = None
+
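+        # Run the AsyncLLMEngine on a dedicated event loop in a background thread; synchronous
+        # actor methods hand coroutines to it via asyncio.run_coroutine_threadsafe.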
+ async def _init_engine():
+ running_loop = asyncio.get_running_loop()
+ assert running_loop == self.loop, f"Loop mismatch! running={running_loop}, actor.loop={self.loop}"
+ return vllm.AsyncLLMEngine.from_engine_args(engine_args, start_engine_loop=False)
+
+ def _run_loop():
+ self.loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(self.loop)
+ self.llm_engine = self.loop.run_until_complete(_init_engine())
+ init_complete.set()
+ self.loop.run_forever()
- def get_model_dims_dict(self):
+ self.loop_thread = threading.Thread(target=_run_loop, daemon=True)
+ self.loop_thread.start()
+ init_complete.wait()
+
+ def get_model_dims_dict(self) -> Dict[str, int]:
"""Get only the model dimensions as a simple dict without loading weights."""
model_config = self.llm_engine.model_config
parallel_config = self.llm_engine.vllm_config.parallel_config
@@ -501,9 +626,9 @@ def get_model_dims_dict(self):
}
def _should_stop(self) -> bool:
- if (time.perf_counter() - self._last_should_stop_update) > self._should_stop_timeout_s:
+ if (time.perf_counter() - self._last_should_stop_update) > SHOULD_STOP_TIMEOUT_S:
should_stop_ref = self.actor_manager.should_stop.remote()
- ready_refs, _ = ray.wait([should_stop_ref], timeout=0.1)
+ ready_refs, _ = ray.wait([should_stop_ref], timeout=SHOULD_STOP_TIMEOUT_S)
if ready_refs:
self._should_stop_value = ray.get(ready_refs[0])
self._last_should_stop_update = time.perf_counter()
@@ -511,481 +636,140 @@ def _should_stop(self) -> bool:
ray.cancel(should_stop_ref)
return self._should_stop_value
- def _prefetch_worker(self, sleep_length_s: int = 1):
- """Background worker that prefetches requests until we have enough buffered."""
- self._threads_started.set()
- while True:
- if not self.inflight_updates and self._should_stop():
- time.sleep(sleep_length_s)
- continue
- current_unfinished = self.llm_engine.get_num_unfinished_requests()
- if current_unfinished >= self.inference_batch_size:
- time.sleep(sleep_length_s)
- continue
- try:
- request = self.prompt_queue.get(timeout=0.1)
- add_request(
- request,
- self.llm_engine,
- self.tools,
- request_metadata=self.request_metadata,
- vllm_active_requests=self.vllm_active_requests,
- )
- except queue.Empty:
- continue
-
- def _insert_result_to_queue(self, result, is_eval: bool):
- """Insert result into the appropriate queue with blocking put."""
- results_queue = self.eval_results_queue if is_eval else self.results_queue
- results_queue.put(result)
+ def _accumulate_sub_request(self, sub_request: dict) -> None:
+ base_request_id = sub_request["base_request_id"]
+ expected_n = sub_request["expected_n"]
- def _process_from_queue(self, timeout: float = 60.0):
- """Run generation loop using LLMEngine directly, with optional tool support.
+ if base_request_id not in self.request_outputs:
+ self.request_outputs[base_request_id] = {
+ "outputs": [],
+ "expected_n": expected_n,
+ "tools": sub_request["tools"],
+ }
- Runs continuously in a background thread, processing requests from the engine.
+ self.request_outputs[base_request_id]["outputs"].append(sub_request["request_output"])
- Returns:
- int: Number of requests processed
- """
- total_processed = 0
- iteration_count = 0
+ is_complete = len(self.request_outputs[base_request_id]["outputs"]) == expected_n
+ if is_complete:
+ self._finalize_completed_request(base_request_id)
- while True:
- iteration_count += 1
-
- # Health check: ensure prefetch worker is alive. This will raise if it has crashed.
- if self._prefetch_future.done():
- self._prefetch_future.result()
-
- self._poll_tool_futures(self.tracking, self.llm_engine.tokenizer)
- current_time = time.perf_counter()
- if self.llm_engine.has_unfinished_requests():
- for output in [o for o in self.llm_engine.step() if o.finished]:
- # Fix the index field for all sub-requests
- # When we have n>1, we create sub-requests with IDs like
- # train_3_12_0, train_3_12_1, etc. But vLLM creates CompletionOutputs with index=0
- # for all of them (since each sub-request has n=1). We need to fix this.
- # Extract the actual index from the sub-request ID
- parts = output.request_id.rsplit("_", 1)
- assert len(parts) == 2 and parts[1].isdigit(), (
- f"Wrong request id format ({output.request_id}), should be request_id _ sub_request_index"
- )
-
- # Fix the index on the CompletionOutput
- correct_index = int(parts[1])
- output.outputs = [dataclasses.replace(o, index=correct_index) for o in output.outputs]
- base_req_id = _extract_base_request_id(output.request_id)
- result = _handle_output(
- output,
- self.tools,
- self.tracking,
- self.request_metadata[base_req_id]["sampling_params"],
- self.max_tool_calls,
- self.executor,
- )
-
- # Result is None when we do more tool processing.
- if result is None:
- # Request went to tools - remove from vllm_active_requests since it's no longer in vLLM
- self.vllm_active_requests.discard(output.request_id)
- else:
- # Sub-request is done (no more tool calls)
- if output.request_id in self.tracking["concat_outputs"]:
- complete_output = self.tracking["concat_outputs"][output.request_id].outputs[0]
- else:
- complete_output = result.outputs[0]
-
- # Remove from vllm_active_requests BEFORE calling _finalize_sub_request
- # to avoid deadlock in _maybe_process_and_insert
- self.vllm_active_requests.discard(output.request_id)
- total_processed += self._finalize_sub_request(
- output.request_id, output, complete_output, current_time
- )
- if self.llm_engine.get_num_unfinished_requests() == 0:
- time.sleep(1)
-
- return total_processed
-
- def _maybe_process_and_insert(
- self,
- request_id: str,
- request_outputs: Dict[str, List[vllm.RequestOutput]],
- tracking: Dict[str, Any],
- current_time: float,
- ) -> int:
- """Check if we have N requests for request_id, process them, and insert results in queue.
-
- Returns:
- int: Number of requests processed (0 or 1).
- """
- expected_n = self.request_metadata[request_id]["original_sampling_params"].n
-
- # Check if we have the base request in request_outputs
- if request_id not in request_outputs:
- return 0
-
- available_outputs = request_outputs[request_id].outputs
- if len(available_outputs) < expected_n:
- return 0
-
- needed_ids = [f"{request_id}_{j}" for j in range(expected_n)]
- active_sub_requests = [sub_id for sub_id in needed_ids if sub_id in self.vllm_active_requests]
- if active_sub_requests:
- return 0
-
- has_pending_tools = any(sub_id in tracking.get("pending_tool_futures", {}) for sub_id in needed_ids)
- if has_pending_tools:
- return 0
-
- # At this point we have all outputs ready. Build ordered outputs for processing.
- # First organize available_outputs into a dictionary for O(1) lookup
- outputs_by_index = {o.index: o for o in available_outputs if hasattr(o, "index")}
-
- # Verify we have all required outputs before proceeding
- if len(outputs_by_index) != expected_n or any(j not in outputs_by_index for j in range(expected_n)):
- logger.warning(
- f"Incomplete or malformed outputs for {request_id}. "
- f"Expected {expected_n} samples, got indices {sorted(outputs_by_index.keys())}. Skipping."
- )
- return 0
-
- ordered_outs: List[vllm.RequestOutput] = []
- for j in range(expected_n):
- # Create a RequestOutput wrapper for each CompletionOutput
- ordered_outs.append(
- vllm.RequestOutput(
- request_id=f"{request_id}_{j}",
- prompt=request_outputs[request_id].prompt,
- prompt_token_ids=request_outputs[request_id].prompt_token_ids,
- prompt_logprobs=request_outputs[request_id].prompt_logprobs,
- outputs=[outputs_by_index[j]],
- finished=True,
- )
- )
+ def _finalize_completed_request(self, base_request_id: str) -> None:
+ outputs = self.request_outputs[base_request_id]["outputs"]
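+        # Order sub-request outputs by their index so results line up with the original n samples.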
+ ordered_outs = sorted(outputs, key=lambda x: split_request_id(x.request_id)["request_index"])
- # Remove the base entry from request_outputs to prevent growth.
- request_outputs.pop(request_id, None)
+ current_time = time.perf_counter()
result, is_eval = process_completed_request(
- request_id, ordered_outs, tracking, current_time, self.tools, self.request_metadata
+ base_request_id,
+ ordered_outs,
+ current_time,
+ self.request_outputs[base_request_id]["tools"],
+ self.request_metadata,
)
- self._insert_result_to_queue(result, is_eval=is_eval)
- self._cleanup_request_data(request_id, tracking)
- return 1
-
- def _has_pending_tool_futures_for_request(self, request_id: str, tracking: Dict[str, Any]) -> bool:
- """Check if there are any pending tool futures for a given base request ID."""
- if not self.tools or not tracking["pending_tool_futures"]:
- return False
-
- # Check if any pending tool futures belong to this base request
- for req_id in tracking["pending_tool_futures"]:
- if _extract_base_request_id(req_id) == request_id:
- return True
- return False
-
- def _has_active_sub_requests_for_base_id(self, base_request_id: str) -> bool:
- """Check if there are any active sub-requests in vLLM for a given base request ID."""
- # Check if any active request IDs belong to our base request
- for req_id in self.vllm_active_requests:
- if _extract_base_request_id(req_id) == base_request_id:
- return True
- return False
-
- def _cleanup_request_data(self, request_id: str, tracking: Dict[str, Any]):
- """Clean up metadata and tracking data for a completed request."""
- # Check if there are still pending tool futures for this request
- if self._has_pending_tool_futures_for_request(request_id, tracking):
- # Don't clean up metadata yet - tool futures still need it
- return
-
- # Check if there are still active sub-requests in vLLM for this base request
- if self._has_active_sub_requests_for_base_id(request_id):
- # Don't clean up metadata yet - active requests still need it
- return
-
- # Remove request metadata only after both conditions are met:
- # 1. No pending tool futures for this request
- # 2. No active sub-requests in vLLM for this base request
- self.request_metadata.pop(request_id, None)
-
- # Clean up tracking data for all sub-requests of this request
- if self.tools:
- # Find all sub-request IDs that belong to this base request
- sub_request_ids = [
- k for k in tracking["concat_outputs"].keys() if _extract_base_request_id(k) == request_id
- ]
-
- for sub_req_id in sub_request_ids:
- # Clean up tracking dictionaries
- tracking["concat_outputs"].pop(sub_req_id, None)
- tracking["masks"].pop(sub_req_id, None)
- tracking["num_calls"].pop(sub_req_id, None)
- tracking["timeout"].pop(sub_req_id, None)
- tracking["tool_error"].pop(sub_req_id, None)
- tracking["tool_output"].pop(sub_req_id, None)
- tracking["tool_runtime"].pop(sub_req_id, None)
- tracking["tool_called"].pop(sub_req_id, None)
- # Note: pending_tool_futures should already be cleaned by _poll_tool_futures
-
- def _finalize_sub_request(self, sub_request_id, request_output_for_prompts, complete_output, current_time):
- """
- Finalize a completed sub-request by moving it to request_outputs and processing if ready.
-
- Args:
- sub_request_id: The sub-request ID (e.g., "train_1_43039_2")
- request_output_for_prompts: RequestOutput containing prompt info
- complete_output: The CompletionOutput to add
- current_time: Current timestamp for processing
-
- Returns:
- Number of processed requests (0 or 1)
- """
- base_request_id = _extract_base_request_id(sub_request_id)
-
- # Extract the sub-request index from the sub_request_id and set it on the CompletionOutput
- # This is needed to properly identify which sub-request each output belongs to.
- # MUST be done BEFORE adding to request_outputs so that
- # _maybe_process_and_insert can find the index field when checking completeness.
- if "_" in sub_request_id:
- # Extract index from sub_request_id like "train_1_43039_2" -> 2
- parts = sub_request_id.rsplit("_", 1)
- if len(parts) == 2 and parts[1].isdigit():
- # Create new CompletionOutput with corrected index
- complete_output = dataclasses.replace(complete_output, index=int(parts[1]))
-
- # If tools are enabled, attach tool metadata to the output
- if self.tools:
- # Set tool metadata attributes on the output
- setattr(
- complete_output,
- "mask",
- self.tracking["masks"].get(sub_request_id, [1] * len(complete_output.token_ids)),
- )
- setattr(complete_output, "num_calls", self.tracking["num_calls"].get(sub_request_id, 0))
- setattr(complete_output, "timeout", self.tracking["timeout"].get(sub_request_id, False))
- setattr(complete_output, "tool_error", self.tracking["tool_error"].get(sub_request_id, ""))
- setattr(complete_output, "tool_output", self.tracking["tool_output"].get(sub_request_id, ""))
- setattr(complete_output, "tool_runtime", self.tracking["tool_runtime"].get(sub_request_id, 0.0))
- setattr(complete_output, "tool_called", self.tracking["tool_called"].get(sub_request_id, False))
-
- # Initialize request_outputs entry if needed
- if base_request_id not in self.request_outputs:
- self.request_outputs[base_request_id] = vllm.RequestOutput(
- request_id=base_request_id,
- prompt=request_output_for_prompts.prompt,
- prompt_token_ids=request_output_for_prompts.prompt_token_ids,
- prompt_logprobs=request_output_for_prompts.prompt_logprobs,
- outputs=[],
- finished=True,
- )
-
- # Add the completion output (with index field already set if needed)
- self.request_outputs[base_request_id].outputs.append(complete_output)
-
- # Try to process and insert if we have all expected outputs
- processed = self._maybe_process_and_insert(base_request_id, self.request_outputs, self.tracking, current_time)
-
- return processed
-
- def _poll_tool_futures(self, tracking, tokenizer):
- """Poll and handle completed tool executions."""
- if not self.tools or not tracking["pending_tool_futures"]:
- return []
-
- dict_keys_to_delete = []
- completed_outputs = []
-
- for req_id, (future, last_o, last_output) in list(tracking["pending_tool_futures"].items()):
- if not future.done():
- continue
- # Tool future is done, process it
- tool_result = future.result() # Get the tool result
+ self.request_outputs.pop(base_request_id)
+ self.request_metadata.pop(base_request_id, None)
- # Get sampling params from request metadata for this request
- base_req_id = _extract_base_request_id(req_id)
- sampling_params = self.request_metadata[base_req_id]["sampling_params"]
+ results_queue = self.eval_results_queue if is_eval else self.results_queue
+ results_queue.put(result)
- last_prompt_token_ids = last_output.prompt_token_ids
- last_token_ids = last_o.token_ids
- tool_output_token_ids = tokenizer.encode(
-                "<output>\n" + tool_result.output + "</output>\n", add_special_tokens=False
- )
- tracking["timeout"][req_id] = tool_result.timeout
- tracking["tool_error"][req_id] += "" if tool_result.error is None else tool_result.error
- tracking["tool_output"][req_id] += tool_result.output
- tracking["tool_runtime"][req_id] += tool_result.runtime
- tracking["tool_called"][req_id] = True
-
- # Edge case 1: clip against model context length
- prompt_and_tool_output_token = last_prompt_token_ids + last_token_ids + tool_output_token_ids
- tracking["num_calls"][req_id] += 1
- excess = len(prompt_and_tool_output_token) - self.llm_engine.model_config.max_model_len
- if excess > 0:
- tool_output_token_ids = tool_output_token_ids[:-excess]
- can_make_new_request = False
- else:
- can_make_new_request = True
-
- # Edge case 2: clip against per-request max_tokens
- remaining = sampling_params.max_tokens - len(tracking["masks"][req_id])
- if remaining <= 0:
- tool_output_token_ids = []
- elif len(tool_output_token_ids) > remaining:
- tool_output_token_ids = tool_output_token_ids[:remaining]
-
- # Extend token_ids and logprobs for tool output tokens so lengths stay aligned
- concat_out = tracking["concat_outputs"][req_id].outputs[0]
- concat_out.token_ids.extend(tool_output_token_ids)
- # use placeholder logprobs for new tokens
- # TODO: can we do something fancier here, or allow it?
- concat_out.logprobs.extend([{tid: types.SimpleNamespace(logprob=0.0)} for tid in tool_output_token_ids])
- tracking["masks"][req_id].extend([0] * len(tool_output_token_ids))
- new_sample_tokens = sampling_params.max_tokens - len(tracking["masks"][req_id])
- can_make_new_request = can_make_new_request and new_sample_tokens > 0
-
- if can_make_new_request:
- new_sampling_params = sampling_params.clone()
- new_sampling_params.max_tokens = new_sample_tokens
-
- try:
- self.llm_engine.add_request(
- req_id, vllm.TokensPrompt(prompt_token_ids=prompt_and_tool_output_token), new_sampling_params
- )
- # Track tool continuation request as active
- base_req_id = _extract_base_request_id(req_id)
- if base_req_id in self.request_metadata:
- self.vllm_active_requests.add(req_id)
-
- except Exception as e:
- # Match original ToolUseLLM behavior - just log and continue
- logger.error(f"[_poll_tool_futures] Error adding request {req_id}: {e}")
- else:
- # Can't make a new request (hit limits), finalize this sub-request
- base_req_id = _extract_base_request_id(req_id)
-
- # Log the state before finalizing
- other_pending = [
- other_id
- for other_id in tracking["pending_tool_futures"]
- if _extract_base_request_id(other_id) == base_req_id and other_id != req_id
- ]
- logger.info(
- f"[_poll_tool_futures] Finalizing {req_id} (can't continue). "
- f"Other pending tools for {base_req_id}: {other_pending}"
- )
-
- # Remove from pending_tool_futures BEFORE finalization to ensure consistent state
- # This prevents the cleanup logic from seeing this as a pending tool future
- tracking["pending_tool_futures"].pop(req_id, None)
-
- complete_output = tracking["concat_outputs"][req_id].outputs[0]
- current_time = time.perf_counter()
- self._finalize_sub_request(req_id, last_output, complete_output, current_time)
- # Don't add to dict_keys_to_delete since we already removed it
- continue
- dict_keys_to_delete.append(req_id)
-
- # Remove the futures we just processed; do NOT clean up metadata here.
- for req_id in dict_keys_to_delete:
- tracking["pending_tool_futures"].pop(req_id, None)
-
- return completed_outputs
+ def process_from_queue(self) -> None:
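+        # Collect finished sub-requests, grouping them by base request until all n samples arrive.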
+ while True:
+ sub_request = self.completion_queue.get()
+ self._accumulate_sub_request(sub_request)
def init_process_group(
self,
- master_address,
- master_port,
- rank_offset,
- world_size,
- group_name,
- backend,
- use_ray=False,
- timeout_minutes=120,
- ):
- return self.llm_engine.collective_rpc(
- "init_process_group",
- args=(master_address, master_port, rank_offset, world_size, group_name, backend, use_ray, timeout_minutes),
+ master_address: str,
+ master_port: int,
+ rank_offset: int,
+ world_size: int,
+ group_name: str,
+ backend: str,
+ use_ray: bool = False,
+ timeout_minutes: int = 120,
+ ) -> None:
+ future = asyncio.run_coroutine_threadsafe(
+ self.llm_engine.collective_rpc(
+ "init_process_group",
+ args=(
+ master_address,
+ master_port,
+ rank_offset,
+ world_size,
+ group_name,
+ backend,
+ use_ray,
+ timeout_minutes,
+ ),
+ ),
+ self.loop,
)
+ return future.result(timeout=timeout_minutes * 60)
+
+ def _run_async(self, coro: Awaitable[Any]) -> Any:
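+        # Bridge from synchronous actor methods into the engine's event loop.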
+ future = asyncio.run_coroutine_threadsafe(coro, self.loop)
+ return future.result()
def _prepare_weight_update(self, name: str, dtype: str) -> None:
- # First, drain all the requests when appropriate:
- while not self.inflight_updates:
- pending_tools = len(self.tracking["pending_tool_futures"])
- unfinished = self.llm_engine.get_num_unfinished_requests()
- # if a background thread is dead, raise an error.
+        # Unless in-flight updates are allowed, wait for all active requests to finish before updating weights.
+ while not self.inflight_updates and len(self.active_tasks) > 0:
self.check_background_threads()
+ time.sleep(DRAIN_ACTIVE_TASKS_SLEEP_S)
- if pending_tools == 0 and unfinished == 0:
- break
-
- time.sleep(WEIGHT_UPDATE_SLEEP_INTERVAL_S)
- # Then, check that the dtypes match.
expected_dtype = str(self.llm_engine.model_config.dtype)
- assert str(dtype) == expected_dtype, (
- f"Mismatched dtype for {name}: received {dtype!r}, expected {expected_dtype!r}"
- )
+ assert dtype == expected_dtype, f"Mismatched dtype for {name}: received {dtype!r}, expected {expected_dtype!r}"
def update_weight(self, name: str, dtype: str, shape: Tuple[int, ...], empty_cache: bool = False) -> None:
self._prepare_weight_update(name, dtype)
- return self.llm_engine.collective_rpc("update_weight", args=(name, dtype, shape, empty_cache))
+ return self._run_async(self.llm_engine.collective_rpc("update_weight", args=(name, dtype, shape, empty_cache)))
def update_weight_cuda_ipc(
self, name: str, dtype: str, shape: Tuple[int, ...], ipc_handles: List[Any], empty_cache: bool = False
) -> None:
self._prepare_weight_update(name, dtype)
- return self.llm_engine.collective_rpc(
- "update_weight_cuda_ipc", args=(name, dtype, shape, ipc_handles, empty_cache)
+ return self._run_async(
+ self.llm_engine.collective_rpc(
+ "update_weight_cuda_ipc", args=(name, dtype, shape, ipc_handles, empty_cache)
+ )
)
- def reset_prefix_cache(self):
- self.llm_engine.reset_prefix_cache()
-
- def sleep(self, level=1):
- self.llm_engine.sleep(level=level)
+ def reset_prefix_cache(self) -> None:
+ return self._run_async(self.llm_engine.reset_prefix_cache())
- def wake_up(self, tags: Optional[list[str]] = None):
- self.llm_engine.wake_up(tags)
-
- def ready(self):
- self._threads_started.wait(timeout=30)
+ def ready(self) -> bool:
return True
- def check_background_threads(self):
+ def check_background_threads(self) -> None:
if self._prefetch_future.done():
self._prefetch_future.result()
if self._process_future.done():
self._process_future.result()
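+        # Surface exceptions from any per-request tasks that have already finished.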
+ active_tasks = list(self.active_tasks.items())
+ for task_id, task in active_tasks:
+ if task.done():
+ task.result()
+ if not self.loop_thread.is_alive():
+ raise RuntimeError(
+ "vLLM engine loop thread has died. Check logs for errors in EngineCore or async engine."
+ )
- def get_kv_cache_info(self):
+ def get_kv_cache_info(self) -> int:
"""Get KV cache max concurrency from the vLLM engine."""
- kv_cache_specs = self.llm_engine.model_executor.get_kv_cache_specs()
- kv_cache_spec = kv_cache_specs[0]
-
- page_size = kv_cache_utils.get_uniform_page_size(kv_cache_spec)
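+        # Fetch per-worker KV cache specs via collective RPC; worker 0's spec is used below.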
+ kv_cache_specs = self._run_async(self.llm_engine.collective_rpc("get_kv_cache_spec"))
vllm_config = self.llm_engine.vllm_config
gpu_memory_utilization = vllm_config.cache_config.gpu_memory_utilization
total_gpu_memory = torch.cuda.get_device_properties(0).total_memory
available_memory = int(gpu_memory_utilization * total_gpu_memory)
- num_blocks = kv_cache_utils.get_num_blocks(vllm_config, len(kv_cache_spec), available_memory, page_size)
+ kv_cache_groups = kv_cache_utils.get_kv_cache_groups(vllm_config, kv_cache_specs[0])
- per_layer_size = page_size * num_blocks
- kv_cache_tensors = [
- kv_cache_interface.KVCacheTensor(size=per_layer_size, shared_by=[layer_name])
- for layer_name in kv_cache_spec
- ]
-
- kv_cache_groups = kv_cache_utils.get_kv_cache_groups(vllm_config, kv_cache_spec)
-
- kv_cache_config = kv_cache_interface.KVCacheConfig(
- num_blocks=num_blocks, kv_cache_tensors=kv_cache_tensors, kv_cache_groups=kv_cache_groups
- )
- max_concurrency = kv_cache_utils.get_max_concurrency_for_kv_cache_config(
- self.llm_engine.vllm_config, kv_cache_config
+ kv_cache_config = kv_cache_utils.get_kv_cache_config_from_groups(
+ vllm_config, kv_cache_groups, kv_cache_specs[0], available_memory
)
+ max_concurrency = kv_cache_utils.get_max_concurrency_for_kv_cache_config(vllm_config, kv_cache_config)
+
return int(max_concurrency)
@@ -1045,7 +829,10 @@ def create_vllm_engines(
max_tool_calls_dict = {}
vllm_engines = []
- distributed_executor_backend = "uni" if tensor_parallel_size == 1 else "ray"
+ if tensor_parallel_size == 1:
+ distributed_executor_backend = "uni"
+ else:
+ distributed_executor_backend = "ray"
use_hybrid_engine = pg is not None
num_gpus = int(tensor_parallel_size == 1)
if use_hybrid_engine and tensor_parallel_size == 1 and single_gpu_mode:
@@ -1080,7 +867,6 @@ def create_vllm_engines(
num_cpus=num_gpus,
num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy,
- # VLLM v1 multiprocessing is required due to https://github.com/vllm-project/vllm/issues/15349
runtime_env=ray.runtime_env.RuntimeEnv(
env_vars={"VLLM_ENABLE_V1_MULTIPROCESSING": "0", "TORCH_CUDA_ARCH_LIST": get_cuda_arch_list()}
),