1+ import asyncio
2+ import atexit
13import os
4+ import threading
25from typing import Any , Dict , List , Optional , Tuple
36
47try :
1720from tensorrt_llm .logger import logger
1821
1922from .._utils import nvtx_range_debug
23+ from ..llmapi .tracer import global_tracer
24+ from ..llmapi .utils import _SyncQueue , logger_debug
2025from .executor import GenerationExecutor
2126from .postproc_worker import PostprocWorkerConfig
2227from .ray_gpu_worker import RayGPUWorker , RayWorkerWrapper
2328from .request import GenerationRequest
24- from .result import GenerationResult , RayAsyncQueue , RaySyncQueue
29+ from .result import GenerationResult
30+ from .rpc import RPCClient
31+ from .rpc .rpc_common import get_unique_ipc_addr
32+ from .utils import ErrorResponse , is_llm_response
2533
2634__all__ = [
2735 "RayExecutor" ,
@@ -76,28 +84,24 @@ def __init__(self,
7684 self .master_address = ray .util .get_node_ip_address ()
7785 self .master_port = get_free_port ()
7886
79- self .response_queue = RayAsyncQueue .options (runtime_env = {
80- "env_vars" : {
81- "TLLM_DISABLE_MPI" : "1"
82- }
83- }).remote ()
84- self .response_sync_queue = RaySyncQueue .options (runtime_env = {
85- "env_vars" : {
86- "TLLM_DISABLE_MPI" : "1"
87- }
88- }).remote ()
89- self .async_response_queue_weakref = self .create_actor_weak_ref (
90- self .response_queue )
91- self .sync_response_queue_weakref = self .create_actor_weak_ref (
92- self .response_sync_queue )
93- self .response_queue .warmup .remote ()
94- self .response_sync_queue .warmup .remote ()
87+ self .rpc_addr = get_unique_ipc_addr ()
88+ self .rpc_client = RPCClient (self .rpc_addr )
89+
90+ self ._results = {}
91+ self ._shutdown_event = threading .Event ()
92+ self .main_loop_task_obj = None
93+ self .main_loop = None
9594
9695 worker_kwargs = dict (** worker_kwargs ,
9796 postproc_worker_config = postproc_worker_config ,
98- is_llm_executor = is_llm_executor )
97+ is_llm_executor = is_llm_executor ,
98+ rpc_addr = self .rpc_addr )
9999
100100 self .create_workers (RayGPUWorker , worker_kwargs )
101+
102+ logger .info ("Setting up engine via RPC" )
103+ self .setup_engine_remote ()
104+ self .setup_mainloop ()
101105 except Exception as e :
102106 # Clean up the Ray resources early during exception
103107 self .shutdown ()
@@ -110,8 +114,103 @@ def create_actor_weak_ref(actor_handle: ray.actor.ActorHandle):
110114 return ray .actor .ActorHandle ._deserialization_helper (state ,
111115 weak_ref = True )
112116
113- def use_ray_queue (self ) -> bool :
114- return True
117+ async def _generic_fetch_loop_async (self , fetch_method_name : str ,
118+ handler_method , method_name : str ):
119+ # TODO copied from GenerationExecutorRpcProxy, need refactoring.
120+ """Generic method for fetching data in a loop from RPC worker.
121+
122+ Args:
123+ fetch_method_name: Name of the RPC client method to call
124+ handler_method: The handler method to call with the fetched data
125+ method_name: Name of the method for logging
126+ """
127+ try :
128+ fetch_method = getattr (self .rpc_client , fetch_method_name )
129+ async for data in fetch_method ().remote_streaming ():
130+ if self ._shutdown_event .is_set ():
131+ return
132+ handler_method (data )
133+ except asyncio .CancelledError :
134+ logger .debug (f"{ method_name } task cancelled" )
135+ except Exception as e :
136+ logger .error (f"Error in { method_name } : { e } " )
137+ raise
138+
139+ async def _fetch_responses_loop_async (self ):
140+ # TODO copied from GenerationExecutorRpcProxy, need refactoring.
141+ await self ._generic_fetch_loop_async (
142+ fetch_method_name = "fetch_responses_loop_async" ,
143+ handler_method = self .handle_responses ,
144+ method_name = "_fetch_responses_loop_async" )
145+
146+ def setup_mainloop (self ):
147+ # TODO copied from GenerationExecutorRpcProxy, need refactoring.
148+ async def main_loop_task ():
149+ await self ._fetch_responses_loop_async ()
150+
151+ def _run_main_loop_task ():
152+ """Local method to run the main loop task."""
153+ self .main_loop = asyncio .new_event_loop ()
154+ asyncio .set_event_loop (self .main_loop )
155+
156+ self .main_loop_task_obj = self .main_loop .create_task (
157+ main_loop_task ())
158+ try :
159+ self .main_loop .run_until_complete (self .main_loop_task_obj )
160+ except asyncio .CancelledError :
161+ pass # Task cancellation is expected during shutdown
162+ finally :
163+ self .main_loop .close ()
164+
165+ self .main_loop_thread = threading .Thread (target = _run_main_loop_task ,
166+ daemon = True ,
167+ name = "ray_executor_main_loop" )
168+ self .main_loop_thread .start ()
169+ atexit .register (self .shutdown )
170+
    def setup_engine_remote(self):
        """Set up the engine on the remote workers via a blocking collective RPC.

        Returns:
            The result of ``collective_rpc("setup_engine", non_block=False)``.
        """
        return self.collective_rpc("setup_engine", non_block=False)
173+
174+ def handle_responses (self , responses : list [GenerationResult ]) -> bool :
175+ # TODO copied from GenerationExecutorRpcProxy, need refactoring.
176+ async_queues = []
177+ event_loop = None
178+
179+ def process_res (res : list ):
180+ for r in res :
181+ client_id = r .client_id
182+ nonlocal event_loop
183+ nonlocal async_queues
184+
185+ if client_id not in self ._results :
186+ logger .warning (
187+ f"Received response for unknown client_id: { client_id } " )
188+ continue
189+
190+ queue = self ._results [client_id ].queue
191+ if isinstance (queue , _SyncQueue ):
192+ queue .put_nowait (r )
193+ async_queues .append (queue )
194+ # all the loops are identical
195+ event_loop = event_loop or queue .loop
196+ else :
197+ queue .put (r )
198+
199+ if (is_llm_response (r ) and r .result .is_final ) or isinstance (
200+ r , ErrorResponse ):
201+ self ._results .pop (client_id )
202+
203+ # Handle the case where responses might not be a list of lists
204+ if responses and not isinstance (responses [0 ], list ):
205+ # If responses is a flat list, wrap it
206+ responses = [responses ]
207+
208+ for res in responses :
209+ global_tracer ().log_instant ("RPC.get" )
210+ process_res (res )
211+
212+ if async_queues :
213+ _SyncQueue .notify_many (event_loop , async_queues )
115214
116215 def create_workers (self , worker_cls , worker_kwargs ):
117216 # When set to be a fraction, it allows Ray to schedule
@@ -192,27 +291,27 @@ def submit(self, request: GenerationRequest) -> GenerationResult:
192291 """
193292 Low-level API to the executor. Return a "future" GenerationResult
194293 which can be waited.
195- Forwards the request to the workers through the request queue .
294+ Forwards the request to the workers through RPC .
196295 """
197296 request .set_id (self ._get_next_client_id ())
198297 logprob_params = self ._get_logprob_params (request )
199298
299+ with nvtx_range_debug ("rpc_submit" ):
300+ self .rpc_client .submit (request ).remote (need_response = False )
301+
200302 result = GenerationResult (
201303 request ,
202304 background_error_handler = self ._handle_background_error ,
203305 executor = self ,
204306 disaggregated_params = request .disaggregated_params ,
205307 logprob_params = logprob_params )
206-
207- with nvtx_range_debug ("request_queue.put" ):
208- self .call_all_ray_workers ("enqueue_request" ,
209- leader_only = True ,
210- request = request ,
211- async_call = True ,
212- result_wait_queue = result .queue )
308+ self ._results [request .id ] = result
213309
214310 return result
215311
    def start(self):
        # Intentionally a no-op: the response-fetching main loop is already
        # launched during __init__ (setup_engine_remote + setup_mainloop),
        # so there is nothing left to start here.
        pass
216315 def report_device_ids (self ) -> list [str ]:
217316 gpu_ids = self .call_all_ray_workers ("report_device_id" ,
218317 leader_only = False ,
@@ -225,12 +324,44 @@ def abort_request(self, request_id: int) -> None:
225324 async_call = False ,
226325 request_id = request_id )
227326
327+ # TODO: Use Ray RPC to shutdown RPC server, and then close client
228328 def shutdown (self ):
229- # Release actors
230- self .response_queue = None
231- self .response_sync_queue = None
232- self .async_response_queue_weakref = None
233- self .sync_response_queue_weakref = None
329+ if self ._shutdown_event .is_set ():
330+ return
331+ self ._shutdown_event .set ()
332+ logger_debug (f"Shutting down RayExecutor (RPC mode)" , color = "yellow" )
333+
334+ # First, cancel the main loop to stop fetching responses
335+ if hasattr (self , 'main_loop' ) and self .main_loop and hasattr (
336+ self , 'main_loop_task_obj' ) and self .main_loop_task_obj :
337+ logger_debug ("Cancelling main loop task." , color = "yellow" )
338+ try :
339+ self .main_loop .call_soon_threadsafe (
340+ self .main_loop_task_obj .cancel )
341+ except Exception as e :
342+ logger_debug (f"Error cancelling main loop task: { e } " ,
343+ color = "yellow" )
344+
345+ if hasattr (self , 'main_loop_thread' ):
346+ self .main_loop_thread .join ()
347+
348+ # Then, shutdown the workers
349+ if hasattr (self , 'workers' ) and self .workers is not None :
350+ try :
351+ logger_debug ("Shutting down RPC remote" , color = "yellow" )
352+ shutdown_refs = [
353+ worker .shutdown .remote () for worker in self .workers
354+ ]
355+ # Add timeout to prevent indefinite hanging
356+ ray .get (shutdown_refs , timeout = 30.0 )
357+ except ray .exceptions .GetTimeoutError :
358+ logger .warning (
359+ "Timeout waiting for workers to shutdown after 30 seconds" )
360+ except Exception as e :
361+ logger .warning (f"Error shutting down RPC remote: { e } " )
362+
363+ if hasattr (self , 'rpc_client' ) and self .rpc_client is not None :
364+ self .rpc_client .close ()
234365
235366 self .workers = None
236367 if hasattr (self ,
0 commit comments