1313 placement_group )
1414
1515from tensorrt_llm ._ray_utils import unwrap_ray_errors
16- from tensorrt_llm ._utils import get_free_port
16+ from tensorrt_llm ._utils import get_free_port , nvtx_range_debug , ray_use_rpc
1717from tensorrt_llm .logger import logger
1818
1919from ..llmapi .utils import logger_debug
2020from .executor import GenerationExecutor
2121from .postproc_worker import PostprocWorkerConfig
2222from .ray_gpu_worker import RayGPUWorker , RayWorkerWrapper
23+ from .request import GenerationRequest
24+ from .result import GenerationResult , RayAsyncQueue , RaySyncQueue
2325from .rpc_proxy import RpcExecutorMixin
2426
2527__all__ = [
@@ -74,18 +76,40 @@ def __init__(self,
7476 self .tp_size = tp_size
7577 self .master_address = ray .util .get_node_ip_address ()
7678 self .master_port = get_free_port ()
77- self .init_rpc_executor ()
79+ self .use_rpc = ray_use_rpc ()
7880
7981 worker_kwargs = dict (** worker_kwargs ,
8082 postproc_worker_config = postproc_worker_config ,
81- is_llm_executor = is_llm_executor ,
82- rpc_addr = self .rpc_addr )
83- self .create_workers (RayGPUWorker , worker_kwargs )
84- self .setup_engine_remote ()
85- self .setup_mainloop (tasks = [self ._fetch_responses_loop_async ],
86- thread_name = "ray_executor_main_loop" )
83+ is_llm_executor = is_llm_executor )
84+
85+ if self .use_rpc :
86+ self .init_rpc_executor ()
87+ worker_kwargs ['rpc_addr' ] = self .rpc_addr
88+ self .create_workers (RayGPUWorker , worker_kwargs )
89+ self .setup_engine_remote ()
90+ self .setup_mainloop (tasks = [self ._fetch_responses_loop_async ],
91+ thread_name = "ray_executor_main_loop" )
92+ logger .info (f"Connecting to RPC server at { self .rpc_addr } " )
93+ else :
94+ self .response_queue = RayAsyncQueue .options (runtime_env = {
95+ "env_vars" : {
96+ "TLLM_DISABLE_MPI" : "1"
97+ }
98+ }).remote ()
99+ self .response_sync_queue = RaySyncQueue .options (runtime_env = {
100+ "env_vars" : {
101+ "TLLM_DISABLE_MPI" : "1"
102+ }
103+ }).remote ()
104+ self .async_response_queue_weakref = self .create_actor_weak_ref (
105+ self .response_queue )
106+ self .sync_response_queue_weakref = self .create_actor_weak_ref (
107+ self .response_sync_queue )
108+ self .response_queue .warmup .remote ()
109+ self .response_sync_queue .warmup .remote ()
110+ self .create_workers (RayGPUWorker , worker_kwargs )
111+
87112 except Exception as e :
88- # Clean up the Ray resources early during exception
89113 self .shutdown ()
90114 logger .error (f"Failed to initialize RayExecutor: { e } " )
91115 raise e
@@ -165,6 +189,43 @@ def collective_rpc(self,
165189 ** kwargs ))
166190 return refs if non_block else ray .get (refs )
167191
def submit(self, request: "GenerationRequest") -> "GenerationResult":
    """
    Low-level API to the executor. Return a "future" GenerationResult
    which can be waited.
    Forwards the request to the workers through RPC or Ray queues depending on mode.
    """
    request.set_id(self._get_next_client_id())
    logprob_params = self._get_logprob_params(request)

    # The result future is identical in both transport modes; build it once
    # instead of duplicating the constructor call in each branch.
    result = GenerationResult(
        request,
        background_error_handler=self._handle_background_error,
        executor=self,
        disaggregated_params=request.disaggregated_params,
        logprob_params=logprob_params)

    if self.use_rpc:
        with nvtx_range_debug("rpc_submit"):
            # Fire-and-forget: responses come back via the fetch-responses loop.
            self.rpc_client.submit(request).remote(need_response=False)
        # Register the pending result so the response loop can route replies to it.
        self._results[request.id] = result
    else:
        # Ray-queue mode: hand the request to the leader worker together with
        # the queue that streams responses directly into this result.
        with nvtx_range_debug("request_queue.put"):
            self.call_all_ray_workers("enqueue_request",
                                      leader_only=True,
                                      request=request,
                                      async_call=True,
                                      result_wait_queue=result.queue)

    return result
228+
def start(self):
    """No-op: Ray workers begin serving as soon as they are created, so there is nothing to start here."""
170231
@@ -177,50 +238,69 @@ def report_device_ids(self) -> list[str]:
177238 async_call = False )
178239 return sorted (gpu_ids )
179240
def use_ray_queue(self) -> bool:
    """Return True when responses are transported via Ray queues (i.e. RPC mode is off)."""
    return False if self.use_rpc else True
243+
def abort_request(self, request_id: int) -> None:
    """Synchronously ask the leader worker to abort the generation request *request_id*."""
    self.call_all_ray_workers(
        "abort_request",
        leader_only=True,
        async_call=False,
        request_id=request_id,
    )
185249
186- # TODO: Use Ray RPC to shutdown RPC server, and then close client
187250 def shutdown (self ):
188- if self ._shutdown_event .is_set ():
251+ if hasattr ( self , '_shutdown_event' ) and self ._shutdown_event .is_set ():
189252 return
190- self . _shutdown_event . set ()
191- logger_debug ( f"Shutting down RayExecutor (RPC mode)" , color = "yellow" )
253+ if hasattr ( self , ' _shutdown_event' ):
254+ self . _shutdown_event . set ( )
192255
193- # First, cancel the main loop to stop fetching responses
194- if hasattr (self , 'main_loop' ) and self .main_loop and hasattr (
195- self , 'main_loop_task_obj' ) and self .main_loop_task_obj :
196- logger_debug ("Cancelling main loop task." , color = "yellow" )
197- try :
198- self .main_loop .call_soon_threadsafe (
199- self .main_loop_task_obj .cancel )
200- except Exception as e :
201- logger_debug (f"Error cancelling main loop task: { e } " ,
202- color = "yellow" )
256+ mode_str = "RPC mode" if self .use_rpc else "Ray queue mode"
257+ logger_debug (f"Shutting down RayExecutor ({ mode_str } )" , color = "yellow" )
203258
204- if hasattr (self , 'main_loop_thread' ):
205- self .main_loop_thread .join ()
259+ if self .use_rpc :
260+ if hasattr (self , 'main_loop' ) and self .main_loop and hasattr (
261+ self , 'main_loop_task_obj' ) and self .main_loop_task_obj :
262+ logger_debug ("Cancelling main loop task." , color = "yellow" )
263+ try :
264+ self .main_loop .call_soon_threadsafe (
265+ self .main_loop_task_obj .cancel )
266+ except Exception as e :
267+ logger_debug (f"Error cancelling main loop task: { e } " ,
268+ color = "yellow" )
206269
207- # Then, shutdown the workers
208- if hasattr (self , 'workers' ) and self .workers is not None :
209- try :
210- logger_debug ("Shutting down RPC remote" , color = "yellow" )
211- shutdown_refs = [
212- worker .shutdown .remote () for worker in self .workers
213- ]
214- # Add timeout to prevent indefinite hanging
215- ray .get (shutdown_refs , timeout = 30.0 )
216- except ray .exceptions .GetTimeoutError :
217- logger .warning (
218- "Timeout waiting for workers to shutdown after 30 seconds" )
219- except Exception as e :
220- logger .warning (f"Error shutting down RPC remote: { e } " )
270+ if hasattr (self , 'main_loop_thread' ):
271+ self .main_loop_thread .join ()
221272
222- if hasattr (self , 'rpc_client' ) and self .rpc_client is not None :
223- self .rpc_client .close ()
273+ # Then, shutdown the workers
274+ if hasattr (self , 'workers' ) and self .workers is not None :
275+ try :
276+ logger_debug ("Shutting down RPC remote" , color = "yellow" )
277+ shutdown_refs = [
278+ worker .shutdown .remote () for worker in self .workers
279+ ]
280+ # Add timeout to prevent indefinite hanging
281+ ray .get (shutdown_refs , timeout = 30.0 )
282+ except ray .exceptions .GetTimeoutError :
283+ logger .warning (
284+ "Timeout waiting for workers to shutdown after 30 seconds"
285+ )
286+ except Exception as e :
287+ logger .warning (f"Error shutting down RPC remote: { e } " )
288+
289+ if hasattr (self , 'rpc_client' ) and self .rpc_client is not None :
290+ try :
291+ self .rpc_client .close ()
292+ except Exception as e :
293+ # Suppress errors during RPC client shutdown
294+ # These can occur if the client is already closed or if there are
295+ # pending operations that get cancelled during cleanup
296+ logger_debug (
297+ f"Suppressed error during RPC client close: { e } " )
298+ else :
299+ # Release actors
300+ self .response_queue = None
301+ self .response_sync_queue = None
302+ self .async_response_queue_weakref = None
303+ self .sync_response_queue_weakref = None
224304
225305 self .workers = None
226306 if hasattr (self ,
@@ -236,12 +316,6 @@ def shutdown(self):
236316 logger .debug ("Shutting down Ray cluster" )
237317 ray .shutdown ()
238318
239- @property
240- def enable_postprocess_parallel (self ) -> bool :
241- ret = super ().enable_postprocess_parallel
242- assert ret == False , "Postprocess parallel is not supported in RayExecutor"
243- return ret
244-
245319 def _get_placement_group (self ,
246320 tp_size : int ) -> Tuple [PlacementGroup , List [int ]]:
247321 """
@@ -307,3 +381,15 @@ def _get_placement_group(self,
307381 pg = placement_group (bundles , strategy = strategy )
308382
309383 return pg , bundle_indices
384+
@property
def enable_postprocess_parallel(self) -> bool:
    """Postprocess parallelism is unsupported in this executor; the inherited value must be False."""
    value = super().enable_postprocess_parallel
    # Guard the invariant: the base class should never report True for a RayExecutor.
    assert value == False, "Postprocess parallel is not supported in RayExecutor"
    return value
390+
@staticmethod
def create_actor_weak_ref(actor_handle: ray.actor.ActorHandle):
    """Build a weak-reference handle to *actor_handle*, so holding it does not keep the actor alive."""
    # NOTE(review): relies on Ray's private (de)serialization helpers to set
    # weak_ref=True — confirm these remain stable across Ray upgrades.
    serialized_state, _, _ = actor_handle._serialization_helper()
    return ray.actor.ActorHandle._deserialization_helper(
        serialized_state, weak_ref=True)
0 commit comments