
Commit 9cf1a39

adopt rpc client/server in ray path
Signed-off-by: Erin Ho <[email protected]>
1 parent 50c4863 commit 9cf1a39

File tree: 4 files changed, +290 −190 lines


tensorrt_llm/executor/ray_executor.py

Lines changed: 164 additions & 33 deletions
@@ -1,4 +1,7 @@
+import asyncio
+import atexit
 import os
+import threading
 from typing import Any, Dict, List, Optional, Tuple
 
 try:
@@ -17,11 +20,16 @@
 from tensorrt_llm.logger import logger
 
 from .._utils import nvtx_range_debug
+from ..llmapi.tracer import global_tracer
+from ..llmapi.utils import _SyncQueue, logger_debug
 from .executor import GenerationExecutor
 from .postproc_worker import PostprocWorkerConfig
 from .ray_gpu_worker import RayGPUWorker, RayWorkerWrapper
 from .request import GenerationRequest
-from .result import GenerationResult, RayAsyncQueue, RaySyncQueue
+from .result import GenerationResult
+from .rpc import RPCClient
+from .rpc.rpc_common import get_unique_ipc_addr
+from .utils import ErrorResponse, is_llm_response
 
 __all__ = [
     "RayExecutor",
@@ -76,28 +84,24 @@ def __init__(self,
             self.master_address = ray.util.get_node_ip_address()
             self.master_port = get_free_port()
 
-            self.response_queue = RayAsyncQueue.options(runtime_env={
-                "env_vars": {
-                    "TLLM_DISABLE_MPI": "1"
-                }
-            }).remote()
-            self.response_sync_queue = RaySyncQueue.options(runtime_env={
-                "env_vars": {
-                    "TLLM_DISABLE_MPI": "1"
-                }
-            }).remote()
-            self.async_response_queue_weakref = self.create_actor_weak_ref(
-                self.response_queue)
-            self.sync_response_queue_weakref = self.create_actor_weak_ref(
-                self.response_sync_queue)
-            self.response_queue.warmup.remote()
-            self.response_sync_queue.warmup.remote()
+            self.rpc_addr = get_unique_ipc_addr()
+            self.rpc_client = RPCClient(self.rpc_addr)
+
+            self._results = {}
+            self._shutdown_event = threading.Event()
+            self.main_loop_task_obj = None
+            self.main_loop = None
 
             worker_kwargs = dict(**worker_kwargs,
                                  postproc_worker_config=postproc_worker_config,
-                                 is_llm_executor=is_llm_executor)
+                                 is_llm_executor=is_llm_executor,
+                                 rpc_addr=self.rpc_addr)
 
             self.create_workers(RayGPUWorker, worker_kwargs)
+
+            logger.info("Setting up engine via RPC")
+            self.setup_engine_remote()
+            self.setup_mainloop()
         except Exception as e:
             # Clean up the Ray resources early during exception
             self.shutdown()
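Note: the two Ray queue actors (RayAsyncQueue/RaySyncQueue) are replaced by a single RPC channel. The executor builds a process-unique IPC address, opens an RPCClient on it, and forwards the same address to the workers via worker_kwargs, presumably so the worker-side RPC server can bind to it. The sketch below is only a hedged illustration of what a helper like get_unique_ipc_addr might produce; make_unique_ipc_addr and the ZeroMQ-style ipc:// scheme are assumptions, not the rpc_common implementation.

# Hypothetical sketch -- not the actual .rpc.rpc_common helper.
import tempfile
import uuid


def make_unique_ipc_addr() -> str:
    # One fresh IPC endpoint per executor instance, e.g.
    # "ipc:///tmp/rpc-3f2a9c....sock"; uuid4 avoids collisions between processes.
    return f"ipc://{tempfile.gettempdir()}/rpc-{uuid.uuid4().hex}.sock"


if __name__ == "__main__":
    print(make_unique_ipc_addr())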
@@ -110,8 +114,103 @@ def create_actor_weak_ref(actor_handle: ray.actor.ActorHandle):
         return ray.actor.ActorHandle._deserialization_helper(state,
                                                               weak_ref=True)
 
-    def use_ray_queue(self) -> bool:
-        return True
+    async def _generic_fetch_loop_async(self, fetch_method_name: str,
+                                        handler_method, method_name: str):
+        # TODO copied from GenerationExecutorRpcProxy, need refactoring.
+        """Generic method for fetching data in a loop from RPC worker.
+
+        Args:
+            fetch_method_name: Name of the RPC client method to call
+            handler_method: The handler method to call with the fetched data
+            method_name: Name of the method for logging
+        """
+        try:
+            fetch_method = getattr(self.rpc_client, fetch_method_name)
+            async for data in fetch_method().remote_streaming():
+                if self._shutdown_event.is_set():
+                    return
+                handler_method(data)
+        except asyncio.CancelledError:
+            logger.debug(f"{method_name} task cancelled")
+        except Exception as e:
+            logger.error(f"Error in {method_name}: {e}")
+            raise
+
+    async def _fetch_responses_loop_async(self):
+        # TODO copied from GenerationExecutorRpcProxy, need refactoring.
+        await self._generic_fetch_loop_async(
+            fetch_method_name="fetch_responses_loop_async",
+            handler_method=self.handle_responses,
+            method_name="_fetch_responses_loop_async")
+
+    def setup_mainloop(self):
+        # TODO copied from GenerationExecutorRpcProxy, need refactoring.
+        async def main_loop_task():
+            await self._fetch_responses_loop_async()
+
+        def _run_main_loop_task():
+            """Local method to run the main loop task."""
+            self.main_loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(self.main_loop)
+
+            self.main_loop_task_obj = self.main_loop.create_task(
+                main_loop_task())
+            try:
+                self.main_loop.run_until_complete(self.main_loop_task_obj)
+            except asyncio.CancelledError:
+                pass  # Task cancellation is expected during shutdown
+            finally:
+                self.main_loop.close()
+
+        self.main_loop_thread = threading.Thread(target=_run_main_loop_task,
+                                                 daemon=True,
+                                                 name="ray_executor_main_loop")
+        self.main_loop_thread.start()
+        atexit.register(self.shutdown)
+
+    def setup_engine_remote(self):
+        return self.collective_rpc("setup_engine", non_block=False)
+
+    def handle_responses(self, responses: list[GenerationResult]) -> bool:
+        # TODO copied from GenerationExecutorRpcProxy, need refactoring.
+        async_queues = []
+        event_loop = None
+
+        def process_res(res: list):
+            for r in res:
+                client_id = r.client_id
+                nonlocal event_loop
+                nonlocal async_queues
+
+                if client_id not in self._results:
+                    logger.warning(
+                        f"Received response for unknown client_id: {client_id}")
+                    continue
+
+                queue = self._results[client_id].queue
+                if isinstance(queue, _SyncQueue):
+                    queue.put_nowait(r)
+                    async_queues.append(queue)
+                    # all the loops are identical
+                    event_loop = event_loop or queue.loop
+                else:
+                    queue.put(r)
+
+                if (is_llm_response(r) and r.result.is_final) or isinstance(
+                        r, ErrorResponse):
+                    self._results.pop(client_id)
+
+        # Handle the case where responses might not be a list of lists
+        if responses and not isinstance(responses[0], list):
+            # If responses is a flat list, wrap it
+            responses = [responses]
+
+        for res in responses:
+            global_tracer().log_instant("RPC.get")
+            process_res(res)
+
+        if async_queues:
+            _SyncQueue.notify_many(event_loop, async_queues)
 
     def create_workers(self, worker_cls, worker_kwargs):
         # When set to be a fraction, it allows Ray to schedule
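Note: setup_mainloop() runs the response-fetching coroutine on a dedicated daemon thread with its own event loop, so callers of submit() are never blocked by the RPC streaming. The standalone sketch below reproduces just that pattern with a stand-in coroutine; start_background_loop and fetch_loop are illustrative names, not part of the executor.

# Minimal, self-contained sketch of the setup_mainloop() pattern (illustration only).
import asyncio
import threading


def start_background_loop(coro_factory):
    loop = asyncio.new_event_loop()

    def runner():
        # The loop lives entirely on this daemon thread.
        asyncio.set_event_loop(loop)
        task = loop.create_task(coro_factory())
        try:
            loop.run_until_complete(task)
        except asyncio.CancelledError:
            pass  # expected when the task is cancelled during shutdown
        finally:
            loop.close()

    thread = threading.Thread(target=runner, daemon=True, name="demo_main_loop")
    thread.start()
    return loop, thread


async def fetch_loop():
    # Stand-in for _fetch_responses_loop_async(): poll a few times, then exit.
    for _ in range(3):
        await asyncio.sleep(0.1)


if __name__ == "__main__":
    loop, thread = start_background_loop(fetch_loop)
    thread.join()  # the real executor instead cancels the task and joins in shutdown()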
@@ -192,27 +291,27 @@ def submit(self, request: GenerationRequest) -> GenerationResult:
         """
         Low-level API to the executor. Return a "future" GenerationResult
         which can be waited.
-        Forwards the request to the workers through the request queue.
+        Forwards the request to the workers through RPC.
         """
         request.set_id(self._get_next_client_id())
         logprob_params = self._get_logprob_params(request)
 
+        with nvtx_range_debug("rpc_submit"):
+            self.rpc_client.submit(request).remote(need_response=False)
+
         result = GenerationResult(
             request,
             background_error_handler=self._handle_background_error,
             executor=self,
             disaggregated_params=request.disaggregated_params,
             logprob_params=logprob_params)
-
-        with nvtx_range_debug("request_queue.put"):
-            self.call_all_ray_workers("enqueue_request",
-                                      leader_only=True,
-                                      request=request,
-                                      async_call=True,
-                                      result_wait_queue=result.queue)
+        self._results[request.id] = result
 
         return result
 
+    def start(self):
+        pass
+
     def report_device_ids(self) -> list[str]:
         gpu_ids = self.call_all_ray_workers("report_device_id",
                                             leader_only=False,
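Note: submit() now only pushes the request over RPC and registers the GenerationResult in self._results; responses come back through the streaming loop and are routed to the matching per-request queue in handle_responses(), which retires a request once its final (or error) response arrives. The toy sketch below mirrors that client_id-based routing with plain queues; Response, results, and route are illustrative names, not the TensorRT-LLM types.

# Toy illustration of the routing done by handle_responses() (illustration only).
import queue
from dataclasses import dataclass


@dataclass
class Response:
    client_id: int
    is_final: bool


results: dict[int, queue.Queue] = {1: queue.Queue(), 2: queue.Queue()}


def route(responses: list[Response]) -> None:
    for r in responses:
        q = results.get(r.client_id)
        if q is None:
            continue  # unknown client_id; the real code logs a warning instead
        q.put(r)
        if r.is_final:
            # A final (or error) response retires the request, as in the diff above.
            results.pop(r.client_id)


if __name__ == "__main__":
    route([Response(1, False), Response(1, True), Response(2, True)])
    print(list(results))  # [] -- both requests received their final response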
@@ -225,12 +324,44 @@ def abort_request(self, request_id: int) -> None:
                                  async_call=False,
                                  request_id=request_id)
 
+    # TODO: Use Ray RPC to shutdown RPC server, and then close client
     def shutdown(self):
-        # Release actors
-        self.response_queue = None
-        self.response_sync_queue = None
-        self.async_response_queue_weakref = None
-        self.sync_response_queue_weakref = None
+        if self._shutdown_event.is_set():
+            return
+        self._shutdown_event.set()
+        logger_debug(f"Shutting down RayExecutor (RPC mode)", color="yellow")
+
+        # First, cancel the main loop to stop fetching responses
+        if hasattr(self, 'main_loop') and self.main_loop and hasattr(
+                self, 'main_loop_task_obj') and self.main_loop_task_obj:
+            logger_debug("Cancelling main loop task.", color="yellow")
+            try:
+                self.main_loop.call_soon_threadsafe(
+                    self.main_loop_task_obj.cancel)
+            except Exception as e:
+                logger_debug(f"Error cancelling main loop task: {e}",
+                             color="yellow")
+
+        if hasattr(self, 'main_loop_thread'):
+            self.main_loop_thread.join()
+
+        # Then, shutdown the workers
+        if hasattr(self, 'workers') and self.workers is not None:
+            try:
+                logger_debug("Shutting down RPC remote", color="yellow")
+                shutdown_refs = [
+                    worker.shutdown.remote() for worker in self.workers
+                ]
+                # Add timeout to prevent indefinite hanging
+                ray.get(shutdown_refs, timeout=30.0)
+            except ray.exceptions.GetTimeoutError:
+                logger.warning(
+                    "Timeout waiting for workers to shutdown after 30 seconds")
+            except Exception as e:
+                logger.warning(f"Error shutting down RPC remote: {e}")
+
+        if hasattr(self, 'rpc_client') and self.rpc_client is not None:
+            self.rpc_client.close()
 
         self.workers = None
         if hasattr(self,
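Note: shutdown() first cancels the fetch loop, then shuts the Ray workers down with a bounded wait so a hung worker cannot stall the process, and finally closes the RPC client. The snippet below isolates only the bounded-wait step with a throwaway actor; Worker is a placeholder, not RayGPUWorker, and it assumes a local Ray runtime is available.

# Standalone sketch of timeout-protected actor shutdown (assumes Ray is installed).
import ray


@ray.remote
class Worker:  # placeholder actor, not the real RayGPUWorker
    def shutdown(self) -> bool:
        return True


if __name__ == "__main__":
    ray.init()
    workers = [Worker.remote() for _ in range(2)]
    refs = [w.shutdown.remote() for w in workers]
    try:
        ray.get(refs, timeout=30.0)  # bounded wait, mirroring RayExecutor.shutdown()
    except ray.exceptions.GetTimeoutError:
        print("Timed out waiting for workers to shut down")
    ray.shutdown()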
