 import dataclasses
 import os
 import queue
+import threading
 import time
 from collections import defaultdict
 from concurrent import futures
@@ -395,10 +396,15 @@ def __init__(
         self._should_stop_value = False
         self._should_stop_timeout_s = 5

-        self._executor = futures.ThreadPoolExecutor(max_workers=1)
-        self._prefetch_future = self._executor.submit(self._prefetch_worker)
+        # Initialize instance variables before starting threads
         self.tracking = _init_tool_tracking()
         self.request_outputs = {}
+        self._threads_started = threading.Event()
+
+        # Start background threads
+        self._executor = futures.ThreadPoolExecutor(max_workers=2)
+        self._prefetch_future = self._executor.submit(self._prefetch_worker)
+        self._process_future = self._executor.submit(self._process_from_queue)

     def get_model_dims_dict(self):
         """Get only the model dimensions as a simple dict without loading weights."""
@@ -431,8 +437,9 @@ def _should_stop(self) -> bool:

     def _prefetch_worker(self, sleep_length_s: int = 1):
         """Background worker that prefetches requests until we have enough buffered."""
+        self._threads_started.set()
         while True:
-            if self._should_stop():
+            if not self.inflight_updates and self._should_stop():
                 time.sleep(sleep_length_s)
                 continue
             current_unfinished = self.llm_engine.get_num_unfinished_requests()
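
The new guard only pauses prefetching when in-flight updates are disabled; with `inflight_updates=True` the worker keeps feeding the engine straight through a weight sync. The decision reduced to a sketch, with `should_stop` standing in for the actor's `_should_stop()`:

```python
def prefetch_should_pause(inflight_updates: bool, should_stop: bool) -> bool:
    # inflight_updates=True  -> never pause: generation overlaps weight updates.
    # inflight_updates=False -> pause while a stop is requested, so no new
    #                           requests enter the engine during a sync.
    return not inflight_updates and should_stop
```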
@@ -456,58 +463,18 @@ def _insert_result_to_queue(self, result, is_eval: bool):
         results_queue = self.eval_results_queue if is_eval else self.results_queue
         results_queue.put(result)

-    def _should_exit(self) -> bool:
-        """Determine if the processing loop should exit.
-
-        Returns:
-            bool: True if the loop should exit, False otherwise.
-        """
-        # Check stop condition first (cheapest check)
-        stop_requested = self._should_stop()
-
-        # Case 1: inflight_updates enabled and stop requested - exit immediately
-        if self.inflight_updates and stop_requested:
-            return True
-
-        # Now check for pending work (only if needed)
-        if stop_requested:
-            # Need to check if we have pending work
-            pending_tools = len(self.tracking["pending_tool_futures"])
-            unfinished = self.llm_engine.get_num_unfinished_requests()
-
-            # Case 2: stop requested and no pending work - exit
-            if pending_tools == 0 and unfinished == 0:
-                return True
-            # Otherwise, we have pending work and should continue
-            return False
-
-        # No stop requested - check if there's any work to do
-        pending_tools = len(self.tracking["pending_tool_futures"])
-        unfinished = self.llm_engine.get_num_unfinished_requests()
-
-        # Case 3: no work left at all - exit
-        if pending_tools == 0 and unfinished == 0:
-            return True
-
-        # Otherwise, continue processing
-        return False
-
-    def process_from_queue(self, timeout: float = 60.0):
+    def _process_from_queue(self, timeout: float = 60.0):
         """Run generation loop using LLMEngine directly, with optional tool support.

-        Runs continuously until should_stop is set, periodically adding new requests
-        and yielding control to allow weight synchronization.
+        Runs continuously in a background thread, processing requests from the engine.

         Returns:
             int: Number of requests processed
         """
-
-        # Use persistent instance variables for tracking and outputs
-        # This ensures state is maintained across multiple calls
         total_processed = 0
         iteration_count = 0

-        while not self._should_exit():
+        while True:
             iteration_count += 1

             # Health check: ensure prefetch worker is alive. This will raise if it has crashed.
@@ -558,17 +525,7 @@ def process_from_queue(self, timeout: float = 60.0):
                     total_processed += self._finalize_sub_request(
                         output.request_id, output, complete_output, current_time
                     )
-
-            if self.verbose and iteration_count % 100 == 0:
-                final_unfinished = self.llm_engine.get_num_unfinished_requests()
-                pending_tools = len(self.tracking["pending_tool_futures"])
-                self.logger.info(
-                    f"process_from_queue iteration {iteration_count}: unfinished={final_unfinished}, pending_tools={pending_tools}"
-                )
-
-            # If we have only pending tools but no unfinished requests, sleep briefly
-            # to let pending tools complete before the next iteration
-            if self.llm_engine.get_num_unfinished_requests() == 0 and len(self.tracking["pending_tool_futures"]) > 0:
+            if self.llm_engine.get_num_unfinished_requests() == 0:
                 time.sleep(1)

         return total_processed
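
With the exit logic gone, the loop runs for the actor's lifetime and liveness is observed from outside rather than from a return value. A caller-side sketch (the function and its parameters are illustrative; only `check_background_threads` comes from this PR):

```python
import ray

def poll_engines(vllm_engines, results_queue, timeout_s=600):
    # Sketch: the generation loop never returns, so a trainer polls results
    # and surfaces background-thread crashes explicitly on each step.
    for engine in vllm_engines:
        ray.get(engine.check_background_threads.remote())  # re-raises crashes
    return results_queue.get(timeout=timeout_s)  # filled by _process_from_queue
```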
@@ -870,10 +827,22 @@ def init_process_group(
            args=(master_address, master_port, rank_offset, world_size, group_name, backend, use_ray, timeout_minutes),
        )

+    def _maybe_drain_requests(self, sleep_s: float = 0.1):
+        while not self.inflight_updates:
+            pending_tools = len(self.tracking["pending_tool_futures"])
+            unfinished = self.llm_engine.get_num_unfinished_requests()
+
+            if pending_tools == 0 and unfinished == 0:
+                break
+
+            time.sleep(sleep_s)
+
     def update_weight(self, name, dtype, shape, empty_cache=False):
+        self._maybe_drain_requests()
         return self.llm_engine.collective_rpc("update_weight", args=(name, dtype, shape, empty_cache))

     def update_weight_cuda_ipc(self, name, dtype, shape, ipc_handles, empty_cache=False):
+        self._maybe_drain_requests()
         return self.llm_engine.collective_rpc(
             "update_weight_cuda_ipc", args=(name, dtype, shape, ipc_handles, empty_cache)
         )
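
`_maybe_drain_requests` makes the synchronization contract explicit: with `inflight_updates=False`, a weight update blocks until the engine has no unfinished requests and no pending tool calls; with `inflight_updates=True` it returns immediately and weights swap under live requests. Condensed control flow (a sketch; `pending_tool_count` and `unfinished_request_count` are illustrative stand-ins for the tracking dict and engine queries above):

```python
import time

def update_weight_flow(actor, *rpc_args):
    if not actor.inflight_updates:
        # Block until idle so no request sees weights change mid-generation.
        while actor.pending_tool_count() or actor.unfinished_request_count():
            time.sleep(0.1)
    return actor.collective_rpc("update_weight", args=rpc_args)
```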
@@ -888,8 +857,15 @@ def wake_up(self, tags: Optional[list[str]] = None):
         self.llm_engine.wake_up(tags)

     def ready(self):
+        self._threads_started.wait(timeout=30)
         return True

+    def check_background_threads(self):
+        if self._prefetch_future.done():
+            self._prefetch_future.result()
+        if self._process_future.done():
+            self._process_future.result()
+
     def get_kv_cache_info(self):
         """Get KV cache max concurrency from the vLLM engine."""
         kv_cache_specs = self.llm_engine.model_executor.get_kv_cache_specs()
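
`Future.done()` is non-blocking, and `Future.result()` on a finished future re-raises whatever exception killed the worker, so a silently dead loop becomes a visible error in whichever actor call runs this check. The `concurrent.futures` behavior in isolation:

```python
from concurrent import futures

def boom():
    raise RuntimeError("worker crashed")

pool = futures.ThreadPoolExecutor(max_workers=1)
future = pool.submit(boom)
futures.wait([future])  # demo only; the actor polls done() without blocking

if future.done():
    future.result()  # re-raises RuntimeError("worker crashed") in the caller
```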
@@ -954,7 +930,6 @@ def create_vllm_engines(
     vllm_gpu_memory_utilization: float = 0.9,
     single_gpu_mode: bool = False,
     pg: Optional[ray.util.placement_group] = None,
-    vllm_enable_sleep=False,
     tools: Optional[Dict[str, Tool]] = None,
     max_tool_calls: List[int] = [5],
     prompt_queue=None,
@@ -1037,7 +1012,6 @@ def create_vllm_engines(
                 gpu_memory_utilization=vllm_gpu_memory_utilization,
                 bundle_indices=bundle_indices,
                 num_gpus=0.2 if use_hybrid_engine else 1,
-                enable_sleep_mode=vllm_enable_sleep,
                 noset_visible_devices=ray_noset_visible_devices(),
                 prompt_queue=prompt_queue,
                 results_queue=results_queue,
@@ -1053,83 +1027,6 @@ def create_vllm_engines(
             )
         )

-    # Verify engines initialized successfully
-    try:
-        ray_get_with_progress(
-            [engine.ready.remote() for engine in vllm_engines], "Initializing vLLM engines", timeout=300
-        )
-    except TimeoutError as e:
-        logger.error(f"vLLM engines failed to initialize: {e}")
-        # Kill partially initialized actors before raising
-        for engine in vllm_engines:
-            ray.kill(engine)
-        raise RuntimeError(f"vLLM engine initialization timed out: {e}")
-
-    if vllm_enable_sleep:
-        batch_vllm_engine_call(vllm_engines, "sleep", rank_0_only=False)
+    ray_get_with_progress([engine.ready.remote() for engine in vllm_engines], "Initializing vLLM engines", timeout=300)

     return vllm_engines
-
-
-def batch_vllm_engine_call(engines: List[Any], method_name: str, *args, rank_0_only: bool = True, **kwargs):
-    """
-    Batch call a method on multiple vLLM engines.
-    Args:
-        engines: List of vLLM engine instances
-        method_name: Name of the method to call
-        rank_0_only: Only execute on rank 0 if True
-        *args: Positional arguments to pass to the method
-        **kwargs: Keyword arguments to pass to the method
-    Returns:
-        List of results from ray.get() if on rank 0, None otherwise
-    """
-    import torch
-
-    if rank_0_only and torch.distributed.get_rank() != 0:
-        return None
-
-    refs = []
-    for engine in engines:
-        method = getattr(engine, method_name)
-        refs.append(method.remote(*args, **kwargs))
-
-    return ray.get(refs)
-
-
-if __name__ == "__main__":
-    num_engines = 1
-    tensor_parallel_size = 1
-    world_size = num_engines * tensor_parallel_size + 1
-    vllm_engines = create_vllm_engines(
-        num_engines=num_engines,
-        tensor_parallel_size=tensor_parallel_size,
-        enforce_eager=True,
-        pretrain="facebook/opt-125m",
-        revision="main",
-        seed=42,
-        enable_prefix_caching=False,
-        max_model_len=1024,
-    )
-    llm = vllm_engines[0]
-    from vllm.utils import get_ip, get_open_port
-
-    master_address = get_ip()
-    master_port = get_open_port()
-    backend = "gloo"
-
-    refs = [
-        engine.init_process_group.remote(
-            master_address, master_port, i * tensor_parallel_size + 1, world_size, "openrlhf", backend=backend
-        )
-        for i, engine in enumerate(vllm_engines)
-    ]
-    model_update_group = init_process_group(
-        backend=backend,
-        init_method=f"tcp://{master_address}:{master_port}",
-        world_size=world_size,
-        rank=0,
-        group_name="openrlhf",
-    )
-    ray.get(refs)
-    output = ray.get(llm.generate.remote("San Franciso is a"))
-    logger.info(f"output: {output}")
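
One more note on the `ready()` handshake the simplified initialization call relies on: `threading.Event.wait(timeout=30)` returns `False` on timeout, and the PR discards that return value, so `ready()` reports success even if the background threads never started. A stricter variant (a sketch, not what the PR does) would propagate the failure:

```python
def ready(self):
    # Event.wait() returns False when the timeout elapses; raising here turns
    # a hung thread startup into a visible initialization error at the caller.
    if not self._threads_started.wait(timeout=30):
        raise RuntimeError("background threads failed to start within 30s")
    return True
```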