Commit 096a82e

refactor

committed · 1 parent 848755f · commit 096a82e

File tree

4 files changed: +292 additions, −440 deletions


tensorrt_llm/executor/ray_executor.py

Lines changed: 9 additions & 150 deletions
@@ -1,7 +1,4 @@
-import asyncio
-import atexit
 import os
-import threading
 from typing import Any, Dict, List, Optional, Tuple
 
 try:
@@ -19,24 +16,18 @@
 from tensorrt_llm._utils import get_free_port
 from tensorrt_llm.logger import logger
 
-from .._utils import nvtx_range_debug
-from ..llmapi.tracer import global_tracer
-from ..llmapi.utils import _SyncQueue, logger_debug
+from ..llmapi.utils import logger_debug
 from .executor import GenerationExecutor
 from .postproc_worker import PostprocWorkerConfig
 from .ray_gpu_worker import RayGPUWorker, RayWorkerWrapper
-from .request import GenerationRequest
-from .result import GenerationResult
-from .rpc import RPCClient
-from .rpc.rpc_common import get_unique_ipc_addr
-from .utils import ErrorResponse, is_llm_response
+from .rpc_proxy import RpcExecutorMixin
 
 __all__ = [
     "RayExecutor",
 ]
 
 
-class RayExecutor(GenerationExecutor):
+class RayExecutor(RpcExecutorMixin, GenerationExecutor):
 
     def __init__(self,
                  worker_kwargs: Dict,
@@ -83,135 +74,22 @@ def __init__(self,
             self.tp_size = tp_size
             self.master_address = ray.util.get_node_ip_address()
             self.master_port = get_free_port()
-
-            self.rpc_addr = get_unique_ipc_addr()
-            self.rpc_client = RPCClient(self.rpc_addr)
-
-            self._results = {}
-            self._shutdown_event = threading.Event()
-            self.main_loop_task_obj = None
-            self.main_loop = None
+            self.init_rpc_executor()
 
             worker_kwargs = dict(**worker_kwargs,
                                  postproc_worker_config=postproc_worker_config,
                                  is_llm_executor=is_llm_executor,
                                  rpc_addr=self.rpc_addr)
-
             self.create_workers(RayGPUWorker, worker_kwargs)
-
-            logger.info("Setting up engine via RPC")
             self.setup_engine_remote()
-            self.setup_mainloop()
+            self.setup_mainloop(tasks=[self._fetch_responses_loop_async],
+                                thread_name="ray_executor_main_loop")
         except Exception as e:
             # Clean up the Ray resources early during exception
             self.shutdown()
             logger.error(f"Failed to initialize RayExecutor: {e}")
             raise e
 
-    @staticmethod
-    def create_actor_weak_ref(actor_handle: ray.actor.ActorHandle):
-        state, _, _ = actor_handle._serialization_helper()
-        return ray.actor.ActorHandle._deserialization_helper(state,
-                                                             weak_ref=True)
-
-    async def _generic_fetch_loop_async(self, fetch_method_name: str,
-                                        handler_method, method_name: str):
-        # TODO copied from GenerationExecutorRpcProxy, need refactoring.
-        """Generic method for fetching data in a loop from RPC worker.
-
-        Args:
-            fetch_method_name: Name of the RPC client method to call
-            handler_method: The handler method to call with the fetched data
-            method_name: Name of the method for logging
-        """
-        try:
-            fetch_method = getattr(self.rpc_client, fetch_method_name)
-            async for data in fetch_method().remote_streaming():
-                if self._shutdown_event.is_set():
-                    return
-                handler_method(data)
-        except asyncio.CancelledError:
-            logger.debug(f"{method_name} task cancelled")
-        except Exception as e:
-            logger.error(f"Error in {method_name}: {e}")
-            raise
-
-    async def _fetch_responses_loop_async(self):
-        # TODO copied from GenerationExecutorRpcProxy, need refactoring.
-        await self._generic_fetch_loop_async(
-            fetch_method_name="fetch_responses_loop_async",
-            handler_method=self.handle_responses,
-            method_name="_fetch_responses_loop_async")
-
-    def setup_mainloop(self):
-        # TODO copied from GenerationExecutorRpcProxy, need refactoring.
-        async def main_loop_task():
-            await self._fetch_responses_loop_async()
-
-        def _run_main_loop_task():
-            """Local method to run the main loop task."""
-            self.main_loop = asyncio.new_event_loop()
-            asyncio.set_event_loop(self.main_loop)
-
-            self.main_loop_task_obj = self.main_loop.create_task(
-                main_loop_task())
-            try:
-                self.main_loop.run_until_complete(self.main_loop_task_obj)
-            except asyncio.CancelledError:
-                pass  # Task cancellation is expected during shutdown
-            finally:
-                self.main_loop.close()
-
-        self.main_loop_thread = threading.Thread(target=_run_main_loop_task,
-                                                 daemon=True,
-                                                 name="ray_executor_main_loop")
-        self.main_loop_thread.start()
-        atexit.register(self.shutdown)
-
-    def setup_engine_remote(self):
-        return self.collective_rpc("setup_engine", non_block=False)
-
-    def handle_responses(self, responses: list[GenerationResult]) -> bool:
-        # TODO copied from GenerationExecutorRpcProxy, need refactoring.
-        async_queues = []
-        event_loop = None
-
-        def process_res(res: list):
-            for r in res:
-                client_id = r.client_id
-                nonlocal event_loop
-                nonlocal async_queues
-
-                if client_id not in self._results:
-                    logger.warning(
-                        f"Received response for unknown client_id: {client_id}")
-                    continue
-
-                queue = self._results[client_id].queue
-                if isinstance(queue, _SyncQueue):
-                    queue.put_nowait(r)
-                    async_queues.append(queue)
-                    # all the loops are identical
-                    event_loop = event_loop or queue.loop
-                else:
-                    queue.put(r)
-
-                if (is_llm_response(r) and r.result.is_final) or isinstance(
-                        r, ErrorResponse):
-                    self._results.pop(client_id)
-
-        # Handle the case where responses might not be a list of lists
-        if responses and not isinstance(responses[0], list):
-            # If responses is a flat list, wrap it
-            responses = [responses]
-
-        for res in responses:
-            global_tracer().log_instant("RPC.get")
-            process_res(res)
-
-        if async_queues:
-            _SyncQueue.notify_many(event_loop, async_queues)
-
     def create_workers(self, worker_cls, worker_kwargs):
         # When set to be a fraction, it allows Ray to schedule
         # multiple actors on a single GPU for colocate use cases.
@@ -287,31 +165,12 @@ def collective_rpc(self,
                 **kwargs))
         return refs if non_block else ray.get(refs)
 
-    def submit(self, request: GenerationRequest) -> GenerationResult:
-        """
-            Low-level API to the executor. Return a "future" GenerationResult
-            which can be waited.
-            Forwards the request to the workers through RPC.
-        """
-        request.set_id(self._get_next_client_id())
-        logprob_params = self._get_logprob_params(request)
-
-        with nvtx_range_debug("rpc_submit"):
-            self.rpc_client.submit(request).remote(need_response=False)
-
-        result = GenerationResult(
-            request,
-            background_error_handler=self._handle_background_error,
-            executor=self,
-            disaggregated_params=request.disaggregated_params,
-            logprob_params=logprob_params)
-        self._results[request.id] = result
-
-        return result
-
     def start(self):
         pass
 
+    def setup_engine_remote(self):
+        return self.collective_rpc("setup_engine", non_block=False)
+
     def report_device_ids(self) -> list[str]:
         gpu_ids = self.call_all_ray_workers("report_device_id",
                                             leader_only=False,
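The `RpcExecutorMixin` that `RayExecutor` now inherits from is defined in `rpc_proxy.py`, one of the four changed files but not shown in this excerpt. Below is a rough sketch of what the mixin presumably provides, reconstructed from the code removed above; the method names come from the calls in the diff, while the bodies and exact signatures are assumptions.

# Hypothetical sketch of RpcExecutorMixin (rpc_proxy.py is not shown here);
# reconstructed from the code removed from RayExecutor above.
import asyncio
import atexit
import threading

from .rpc import RPCClient
from .rpc.rpc_common import get_unique_ipc_addr


class RpcExecutorMixin:
    """Shared RPC-client plumbing for executors (sketch, not the real code)."""

    def init_rpc_executor(self):
        # State that RayExecutor.__init__ used to set up inline.
        self.rpc_addr = get_unique_ipc_addr()
        self.rpc_client = RPCClient(self.rpc_addr)
        self._results = {}
        self._shutdown_event = threading.Event()
        self.main_loop = None
        self.main_loop_task_obj = None

    def setup_mainloop(self, tasks, thread_name):
        # Run the given coroutines on a dedicated event loop in a daemon
        # thread, mirroring the removed setup_mainloop/_run_main_loop_task.
        async def main_loop_task():
            await asyncio.gather(*(task() for task in tasks))

        def _run_main_loop_task():
            self.main_loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self.main_loop)
            self.main_loop_task_obj = self.main_loop.create_task(
                main_loop_task())
            try:
                self.main_loop.run_until_complete(self.main_loop_task_obj)
            except asyncio.CancelledError:
                pass  # task cancellation is expected during shutdown
            finally:
                self.main_loop.close()

        self.main_loop_thread = threading.Thread(target=_run_main_loop_task,
                                                 daemon=True,
                                                 name=thread_name)
        self.main_loop_thread.start()
        atexit.register(self.shutdown)

The removed `_fetch_responses_loop_async`, `handle_responses`, and `submit` methods presumably move into the same mixin largely unchanged, which is why the diff can pass `self._fetch_responses_loop_async` as a task and why `submit` disappears from `ray_executor.py` without a local replacement.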

tensorrt_llm/executor/ray_gpu_worker.py

Lines changed: 5 additions & 85 deletions
@@ -1,10 +1,7 @@
-import asyncio
 import importlib
 import os
 from pathlib import Path
-from queue import Queue
-from threading import Event
-from typing import Any, AsyncGenerator, List, Optional, Type, Union
+from typing import Any, List, Optional, Type, Union
 
 import ray
 import torch
@@ -15,19 +12,16 @@
                               release_with_tag,
                               verify_sleep_wakeup_tags)
 
-from .._utils import nvtx_range_debug
 from ..bindings import executor as tllm
 from ..builder import Engine
 from ..llmapi.llm_args import BaseLlmArgs
 from ..llmapi.tokenizer import TokenizerBase
-from ..llmapi.utils import logger_debug
 from ..sampling_params import BatchedLogitsProcessor
 from .base_worker import BaseWorker
 from .postproc_worker import PostprocWorkerConfig
 from .request import GenerationRequest
 from .result import GenerationResult
-from .rpc import RPCServer
-from .rpc_worker import RpcWorker
+from .rpc_worker import RpcWorkerMixin
 
 __all__ = [
     "RayGPUWorker",
@@ -154,7 +148,7 @@ def _inject_worker_extension(
     return ExtendedWorker
 
 
-class RayGPUWorker(BaseWorker):
+class RayGPUWorker(RpcWorkerMixin, BaseWorker):
 
     def __init__(
             self,
@@ -183,33 +177,13 @@ def __init__(
             llm_args=llm_args,
         )
 
-        if not self._is_pytorch_backend:
-            raise ValueError(f"Ray GPU worker only supports PyTorch backend.")
-
         self.device_id = device_id
-
-        # Override rank attributes using torch
        self.global_rank = torch.distributed.get_rank()
         if self.global_rank > 1:
             logger.set_rank(self.global_rank)
 
-        if rpc_addr is None:
-            raise RuntimeError(
-                "RPC mode enabled but no rpc_addr provided to RayGPUWorker")
-
-        self.shutdown_event = Event()
-
-        self._response_queue = Queue()
-        self.set_result_queue(self._response_queue)
-
-        self.rpc_server = None
-        if self.global_rank == 0:
-            logger.info(f"[Rank {self.global_rank}] Creating RPC server")
-            self.rpc_server = RPCServer(self, num_workers=RpcWorker.NUM_WORKERS)
-            self.rpc_server.bind(rpc_addr)
-            self.rpc_server.start()
-            logger.info(
-                f"[Rank {self.global_rank}] RPC server started at {rpc_addr}")
+        self.init_rpc_worker(self.global_rank, rpc_addr)
+        self.start_rpc_server()
 
     def setup_engine(self):
         if torch.distributed.is_initialized(
@@ -231,60 +205,6 @@ def _get_comm_ranks_device_id(self):
     def start(self):
         pass
 
-    def submit(self, request: GenerationRequest):
-        return super().submit(request)
-
-    def fetch_responses(self, timeout: Optional[float] = None) -> list:
-        # TODO copied from RpcWorker, need refactoring.
-        logger_debug(f"RayGPUWorker {self.rank} is fetching responses",
-                     color="yellow")
-        with nvtx_range_debug("RayGPUWorker.fetch_responses",
-                              color="orange",
-                              category="Worker"):
-            # NOTE: This is a blocking call, it will wait for the responses to be available.
-            responses = super().await_responses(timeout)
-            self._await_response_helper.responses_handler(responses)
-
-        qsize = self._response_queue.qsize()
-        logger_debug(f"RayGPUWorker returning {qsize} responses",
-                     color="yellow")
-
-        all_responses = []
-        for _ in range(qsize):
-            # The queue contains batches of responses, so extend the list
-            all_responses.extend(self._response_queue.get())
-        return all_responses
-
-    async def fetch_responses_async(self,
-                                    timeout: Optional[float] = None) -> list:
-        # TODO copied from RpcWorker, need refactoring.
-        # A really async version of fetch_responses
-        logger_debug(f"RayGPUWorker {self.rank} is fetching responses async",
-                     color="yellow")
-
-        # First, await any pending responses without blocking the event loop
-        responses = await asyncio.to_thread(self.fetch_responses,
-                                            timeout=timeout)
-        return responses
-
-    # for streaming performance
-    async def fetch_responses_loop_async(self) -> AsyncGenerator[list, None]:
-        # TODO copied from RpcWorker, need refactoring.
-        while not self.shutdown_event.is_set():
-            # Use a short timeout to allow checking shutdown_event periodically
-            responses = await self.fetch_responses_async(timeout=0.1)
-            if responses:  # Only yield if there are actual responses
-                logger_debug(
-                    f"RayGPUWorker {self.rank} is yielding responses: {responses}",
-                    color="yellow")
-                yield responses  # batching the responses to opt IPC performance
-            else:
-                # Small delay to prevent busy waiting when no responses
-                await asyncio.sleep(0)
-        logger_debug(
-            f"RayGPUWorker {self.rank} quitting fetch_responses_loop_async",
-            color="yellow")
-
     def shutdown(self):
 
         if self.doing_shutdown:
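Similarly, `RpcWorkerMixin` comes from `rpc_worker.py`, another of the changed files not included in this excerpt. Below is a minimal sketch of what `init_rpc_worker` and `start_rpc_server` might look like, reconstructed from the code removed from `RayGPUWorker.__init__`; the split between the two methods and the `NUM_WORKERS` constant are assumptions.

# Hypothetical sketch of RpcWorkerMixin (rpc_worker.py is not shown here);
# reconstructed from the RPC-server setup removed from RayGPUWorker above.
from queue import Queue
from threading import Event

from tensorrt_llm.logger import logger

from .rpc import RPCServer


class RpcWorkerMixin:
    NUM_WORKERS = 4  # assumed; the old code read RpcWorker.NUM_WORKERS

    def init_rpc_worker(self, global_rank: int, rpc_addr: str):
        if rpc_addr is None:
            raise RuntimeError(
                "RPC mode enabled but no rpc_addr provided to RayGPUWorker")
        self.global_rank = global_rank
        self.rpc_addr = rpc_addr
        self.shutdown_event = Event()
        # Responses are staged in a local queue that the RPC fetch loop drains.
        self._response_queue = Queue()
        self.set_result_queue(self._response_queue)
        self.rpc_server = None

    def start_rpc_server(self):
        # Only the leader rank exposes an RPC server, as in the removed code.
        if self.global_rank != 0:
            return
        logger.info(f"[Rank {self.global_rank}] Creating RPC server")
        self.rpc_server = RPCServer(self, num_workers=self.NUM_WORKERS)
        self.rpc_server.bind(self.rpc_addr)
        self.rpc_server.start()
        logger.info(
            f"[Rank {self.global_rank}] RPC server started at {self.rpc_addr}")

The removed `fetch_responses`, `fetch_responses_async`, and `fetch_responses_loop_async` methods would likewise live on the mixin, so the streaming loop that `RpcExecutorMixin` drives over RPC keeps working against `RayGPUWorker`.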
