
Commit d884d2d

tentatively add back ray queues
Signed-off-by: Erin Ho <[email protected]>
1 parent 7ceaf5a commit d884d2d

5 files changed: +354 additions, -109 deletions


tensorrt_llm/_utils.py

Lines changed: 7 additions & 0 deletions
@@ -524,6 +524,13 @@ def mpi_disabled() -> bool:
     return os.environ.get("TLLM_DISABLE_MPI") == "1"


+def ray_use_rpc() -> bool:
+    """True if TLLM_RAY_USE_RPC is set to "1", False otherwise.
+    # TODO: deprecate this once Ray is fully moved to use RPC client/server.
+    """
+    return os.environ.get("TLLM_RAY_USE_RPC") == "1"
+
+
 def mpi_rank():
     if mpi_disabled():
         try:
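A note on the new helper: the transport mode is fixed by a single environment variable that is read when the executor is constructed. A minimal sketch of toggling it follows; the variable name comes from this commit, while setting it in-process before building the executor is an assumed usage pattern, not something shown in the diff:

    import os

    # Opt in to the RPC client/server path; leaving the variable unset (or set
    # to anything other than "1") keeps the Ray-queue path restored by this commit.
    os.environ["TLLM_RAY_USE_RPC"] = "1"

    from tensorrt_llm._utils import ray_use_rpc

    assert ray_use_rpc() is True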

tensorrt_llm/executor/ray_executor.py

Lines changed: 133 additions & 47 deletions
@@ -13,13 +13,15 @@
                                    placement_group)

 from tensorrt_llm._ray_utils import unwrap_ray_errors
-from tensorrt_llm._utils import get_free_port
+from tensorrt_llm._utils import get_free_port, nvtx_range_debug, ray_use_rpc
 from tensorrt_llm.logger import logger

 from ..llmapi.utils import logger_debug
 from .executor import GenerationExecutor
 from .postproc_worker import PostprocWorkerConfig
 from .ray_gpu_worker import RayGPUWorker, RayWorkerWrapper
+from .request import GenerationRequest
+from .result import GenerationResult, RayAsyncQueue, RaySyncQueue
 from .rpc_proxy import RpcExecutorMixin

 __all__ = [
@@ -74,18 +76,40 @@ def __init__(self,
             self.tp_size = tp_size
             self.master_address = ray.util.get_node_ip_address()
             self.master_port = get_free_port()
-            self.init_rpc_executor()
+            self.use_rpc = ray_use_rpc()

             worker_kwargs = dict(**worker_kwargs,
                                  postproc_worker_config=postproc_worker_config,
-                                 is_llm_executor=is_llm_executor,
-                                 rpc_addr=self.rpc_addr)
-            self.create_workers(RayGPUWorker, worker_kwargs)
-            self.setup_engine_remote()
-            self.setup_mainloop(tasks=[self._fetch_responses_loop_async],
-                                thread_name="ray_executor_main_loop")
+                                 is_llm_executor=is_llm_executor)
+
+            if self.use_rpc:
+                self.init_rpc_executor()
+                worker_kwargs['rpc_addr'] = self.rpc_addr
+                self.create_workers(RayGPUWorker, worker_kwargs)
+                self.setup_engine_remote()
+                self.setup_mainloop(tasks=[self._fetch_responses_loop_async],
+                                    thread_name="ray_executor_main_loop")
+                logger.info(f"Connecting to RPC server at {self.rpc_addr}")
+            else:
+                self.response_queue = RayAsyncQueue.options(runtime_env={
+                    "env_vars": {
+                        "TLLM_DISABLE_MPI": "1"
+                    }
+                }).remote()
+                self.response_sync_queue = RaySyncQueue.options(runtime_env={
+                    "env_vars": {
+                        "TLLM_DISABLE_MPI": "1"
+                    }
+                }).remote()
+                self.async_response_queue_weakref = self.create_actor_weak_ref(
+                    self.response_queue)
+                self.sync_response_queue_weakref = self.create_actor_weak_ref(
+                    self.response_sync_queue)
+                self.response_queue.warmup.remote()
+                self.response_sync_queue.warmup.remote()
+                self.create_workers(RayGPUWorker, worker_kwargs)
+
         except Exception as e:
-            # Clean up the Ray resources early during exception
             self.shutdown()
             logger.error(f"Failed to initialize RayExecutor: {e}")
             raise e
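For context, the non-RPC branch above follows a standard Ray actor pattern: the response-queue actors are created with TLLM_DISABLE_MPI=1 in their runtime env and warmed up before any request is submitted. A minimal, self-contained sketch of that pattern; the simplified ResponseQueue below is an illustrative stand-in for the real RayAsyncQueue / RaySyncQueue actors:

    import ray

    @ray.remote
    class ResponseQueue:  # stand-in for RayAsyncQueue / RaySyncQueue
        def __init__(self):
            self._items = []

        def warmup(self):
            # Cheap call that forces the actor process to start before use.
            return True

        def put(self, item):
            self._items.append(item)

        def drain(self):
            items, self._items = self._items, []
            return items

    queue = ResponseQueue.options(
        runtime_env={"env_vars": {"TLLM_DISABLE_MPI": "1"}}).remote()
    ray.get(queue.warmup.remote())  # block until the actor is live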
@@ -165,6 +189,43 @@ def collective_rpc(self,
                                  **kwargs))
         return refs if non_block else ray.get(refs)

+    def submit(self, request: "GenerationRequest") -> "GenerationResult":
+        """
+        Low-level API to the executor. Return a "future" GenerationResult
+        which can be waited.
+        Forwards the request to the workers through RPC or Ray queues depending on mode.
+        """
+        request.set_id(self._get_next_client_id())
+        logprob_params = self._get_logprob_params(request)
+
+        if self.use_rpc:
+            with nvtx_range_debug("rpc_submit"):
+                self.rpc_client.submit(request).remote(need_response=False)
+
+            result = GenerationResult(
+                request,
+                background_error_handler=self._handle_background_error,
+                executor=self,
+                disaggregated_params=request.disaggregated_params,
+                logprob_params=logprob_params)
+            self._results[request.id] = result
+        else:
+            result = GenerationResult(
+                request,
+                background_error_handler=self._handle_background_error,
+                executor=self,
+                disaggregated_params=request.disaggregated_params,
+                logprob_params=logprob_params)
+
+            with nvtx_range_debug("request_queue.put"):
+                self.call_all_ray_workers("enqueue_request",
+                                          leader_only=True,
+                                          request=request,
+                                          async_call=True,
+                                          result_wait_queue=result.queue)
+
+        return result
+
     def start(self):
         pass
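The restored submit() returns a "future" GenerationResult in both modes; in the Ray-queue mode the worker completes it by writing into result.queue (passed as result_wait_queue above). A small stand-alone sketch of that future-plus-wait-queue flow; DemoResult and demo_worker are illustrative stand-ins, not APIs from this diff:

    import queue
    import threading

    class DemoResult:
        # Plays the role of GenerationResult: holds a per-request queue that
        # the worker pushes the finished response into.
        def __init__(self, request_id):
            self.request_id = request_id
            self.queue = queue.Queue()

        def wait(self, timeout=None):
            return self.queue.get(timeout=timeout)

    def demo_worker(request_id, result_wait_queue):
        # Stand-in for the leader worker handling an enqueued request.
        result_wait_queue.put({"request_id": request_id, "text": "hello"})

    result = DemoResult(request_id=0)
    threading.Thread(target=demo_worker, args=(0, result.queue)).start()
    print(result.wait(timeout=5))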

@@ -177,50 +238,69 @@ def report_device_ids(self) -> list[str]:
                                             async_call=False)
         return sorted(gpu_ids)

+    def use_ray_queue(self) -> bool:
+        return not self.use_rpc
+
     def abort_request(self, request_id: int) -> None:
         self.call_all_ray_workers("abort_request",
                                   leader_only=True,
                                   async_call=False,
                                   request_id=request_id)

-    # TODO: Use Ray RPC to shutdown RPC server, and then close client
     def shutdown(self):
-        if self._shutdown_event.is_set():
+        if hasattr(self, '_shutdown_event') and self._shutdown_event.is_set():
             return
-        self._shutdown_event.set()
-        logger_debug(f"Shutting down RayExecutor (RPC mode)", color="yellow")
+        if hasattr(self, '_shutdown_event'):
+            self._shutdown_event.set()

-        # First, cancel the main loop to stop fetching responses
-        if hasattr(self, 'main_loop') and self.main_loop and hasattr(
-                self, 'main_loop_task_obj') and self.main_loop_task_obj:
-            logger_debug("Cancelling main loop task.", color="yellow")
-            try:
-                self.main_loop.call_soon_threadsafe(
-                    self.main_loop_task_obj.cancel)
-            except Exception as e:
-                logger_debug(f"Error cancelling main loop task: {e}",
-                             color="yellow")
+        mode_str = "RPC mode" if self.use_rpc else "Ray queue mode"
+        logger_debug(f"Shutting down RayExecutor ({mode_str})", color="yellow")

-        if hasattr(self, 'main_loop_thread'):
-            self.main_loop_thread.join()
+        if self.use_rpc:
+            if hasattr(self, 'main_loop') and self.main_loop and hasattr(
+                    self, 'main_loop_task_obj') and self.main_loop_task_obj:
+                logger_debug("Cancelling main loop task.", color="yellow")
+                try:
+                    self.main_loop.call_soon_threadsafe(
+                        self.main_loop_task_obj.cancel)
+                except Exception as e:
+                    logger_debug(f"Error cancelling main loop task: {e}",
+                                 color="yellow")

-        # Then, shutdown the workers
-        if hasattr(self, 'workers') and self.workers is not None:
-            try:
-                logger_debug("Shutting down RPC remote", color="yellow")
-                shutdown_refs = [
-                    worker.shutdown.remote() for worker in self.workers
-                ]
-                # Add timeout to prevent indefinite hanging
-                ray.get(shutdown_refs, timeout=30.0)
-            except ray.exceptions.GetTimeoutError:
-                logger.warning(
-                    "Timeout waiting for workers to shutdown after 30 seconds")
-            except Exception as e:
-                logger.warning(f"Error shutting down RPC remote: {e}")
+            if hasattr(self, 'main_loop_thread'):
+                self.main_loop_thread.join()

-        if hasattr(self, 'rpc_client') and self.rpc_client is not None:
-            self.rpc_client.close()
+            # Then, shutdown the workers
+            if hasattr(self, 'workers') and self.workers is not None:
+                try:
+                    logger_debug("Shutting down RPC remote", color="yellow")
+                    shutdown_refs = [
+                        worker.shutdown.remote() for worker in self.workers
+                    ]
+                    # Add timeout to prevent indefinite hanging
+                    ray.get(shutdown_refs, timeout=30.0)
+                except ray.exceptions.GetTimeoutError:
+                    logger.warning(
+                        "Timeout waiting for workers to shutdown after 30 seconds"
+                    )
+                except Exception as e:
+                    logger.warning(f"Error shutting down RPC remote: {e}")
+
+            if hasattr(self, 'rpc_client') and self.rpc_client is not None:
+                try:
+                    self.rpc_client.close()
+                except Exception as e:
+                    # Suppress errors during RPC client shutdown
+                    # These can occur if the client is already closed or if there are
+                    # pending operations that get cancelled during cleanup
+                    logger_debug(
+                        f"Suppressed error during RPC client close: {e}")
+        else:
+            # Release actors
+            self.response_queue = None
+            self.response_sync_queue = None
+            self.async_response_queue_weakref = None
+            self.sync_response_queue_weakref = None

         self.workers = None
         if hasattr(self,
@@ -236,12 +316,6 @@ def shutdown(self):
             logger.debug("Shutting down Ray cluster")
             ray.shutdown()

-    @property
-    def enable_postprocess_parallel(self) -> bool:
-        ret = super().enable_postprocess_parallel
-        assert ret == False, "Postprocess parallel is not supported in RayExecutor"
-        return ret
-
     def _get_placement_group(self,
                              tp_size: int) -> Tuple[PlacementGroup, List[int]]:
         """
@@ -307,3 +381,15 @@ def _get_placement_group(self,
         pg = placement_group(bundles, strategy=strategy)

         return pg, bundle_indices
+
+    @property
+    def enable_postprocess_parallel(self) -> bool:
+        ret = super().enable_postprocess_parallel
+        assert ret == False, "Postprocess parallel is not supported in RayExecutor"
+        return ret
+
+    @staticmethod
+    def create_actor_weak_ref(actor_handle: ray.actor.ActorHandle):
+        state, _, _ = actor_handle._serialization_helper()
+        return ray.actor.ActorHandle._deserialization_helper(state,
+                                                              weak_ref=True)
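create_actor_weak_ref() rebuilds an actor handle through Ray's private serialization helpers with weak_ref=True, so the handle can still call the actor but does not keep it alive; that is what lets shutdown() in Ray-queue mode release the queue actors by simply dropping the strong handles. A small runnable sketch of the same pattern (Dummy is an illustrative actor; the private helpers are assumed to behave as they are used in this commit):

    import ray

    @ray.remote
    class Dummy:
        def ping(self):
            return "pong"

    strong = Dummy.remote()

    # Mirror create_actor_weak_ref(): serialize the strong handle, then rebuild
    # it as a weak handle that does not pin the actor's lifetime.
    state, _, _ = strong._serialization_helper()
    weak = ray.actor.ActorHandle._deserialization_helper(state, weak_ref=True)

    print(ray.get(weak.ping.remote()))  # the weak handle can still reach the actor
    # Dropping `strong` afterwards lets Ray reclaim the actor even though
    # `weak` is still referenced.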
