Commit a867f9e

update rpc, wip
1 parent 177eb88 commit a867f9e

5 files changed, +171 −75 lines changed


tensorrt_llm/bench/dataclasses/reporting.py

Lines changed: 5 additions & 1 deletion
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import json
+import os
 from collections import defaultdict
 from typing import Any, Dict, List, NamedTuple
 
@@ -507,9 +508,12 @@ def report_statistics(self) -> None:
                 f"Max Sequence Length:\t{build_cfg['max_seq_len']}\n"
                 f"\n")
         else:
+            # Check MPI vs RAY and RPC status
+            comm_backend = "RAY" if os.environ.get("TLLM_DISABLE_MPI") == "1" else "MPI"
+            ray_status = "[RPC]" if os.environ.get("TLLM_RAY_USE_RPC") == "1" else "[original]"
             backend_info = (
                 "\n\n===========================================================\n"
-                "= PYTORCH BACKEND\n"
+                f"= PYTORCH BACKEND [{comm_backend}] {ray_status}\n"
                 "===========================================================\n"
                 f"Model:\t\t\t{engine['model']}\n"
                 f"Model Path:\t\t{engine['model_path']}\n"

tensorrt_llm/executor/ray_executor.py

Lines changed: 26 additions & 6 deletions
@@ -89,7 +89,7 @@ def __init__(self,
         if self.use_rpc:
             self.rpc_addr = get_unique_ipc_addr()
             self.rpc_client = RPCClient(self.rpc_addr)
-            print(f"RPC client created at {self.rpc_addr}")
+            # print(f"RPC client created at {self.rpc_addr}")
 
         self._results = {}
         self._shutdown_event = threading.Event()
@@ -144,6 +144,13 @@ def use_ray_queue(self) -> bool:
     async def _generic_fetch_loop_async(self, fetch_method_name: str,
                                         handler_method, method_name: str):
         # TODO copied from GenerationExecutorRpcProxy, need refactoring.
+        """Generic method for fetching data in a loop from RPC worker.
+
+        Args:
+            fetch_method_name: Name of the RPC client method to call
+            handler_method: The handler method to call with the fetched data
+            method_name: Name of the method for logging
+        """
         try:
             fetch_method = getattr(self.rpc_client, fetch_method_name)
             async for data in fetch_method().remote_streaming():
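
The docstring added above describes a name-based dispatch pattern: resolve a streaming RPC method by name, consume its async stream, and hand each batch to a handler. A self-contained sketch of that pattern (MockRpcClient is illustrative; the real RPCClient chains .remote_streaming() onto the call, which this stub only approximates):

import asyncio

class MockRpcClient:
    # Illustrative stand-in for the RPC client: streams three batches, then stops.
    async def fetch_responses_loop_async(self):
        for batch in ([1, 2], [3], [4, 5]):
            yield batch

async def generic_fetch_loop(client, fetch_method_name: str, handler_method,
                             method_name: str):
    # Resolve the streaming method by name, as the real loop does with getattr.
    fetch_method = getattr(client, fetch_method_name)
    async for data in fetch_method():
        handler_method(data)
    print(f"{method_name} finished")

asyncio.run(
    generic_fetch_loop(MockRpcClient(), "fetch_responses_loop_async",
                       lambda batch: print("handled", batch),
                       "fetch_responses_loop"))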
@@ -169,6 +176,7 @@ async def main_loop_task():
             await self._fetch_responses_loop_async()
 
         def _run_main_loop_task():
+            """Local method to run the main loop task."""
             self.main_loop = asyncio.new_event_loop()
             asyncio.set_event_loop(self.main_loop)
 
@@ -177,18 +185,21 @@ def _run_main_loop_task():
             try:
                 self.main_loop.run_until_complete(self.main_loop_task_obj)
             except asyncio.CancelledError:
-                pass
+                pass  # Task cancellation is expected during shutdown
             finally:
                 self.main_loop.close()
 
         self.main_loop_thread = threading.Thread(target=_run_main_loop_task,
-                                                 daemon=True)
+                                                 daemon=True,
+                                                 name="ray_executor_main_loop")
         self.main_loop_thread.start()
         atexit.register(self.shutdown)
 
     def setup_engine_remote(self):
         return self.collective_rpc("setup_engine", non_block=False)
 
+    # TODO: use Ray RPC to shutdown RPC server, and then close client.
+
     def handle_responses(self, responses: list[GenerationResult]) -> bool:
         # TODO copied from GenerationExecutorRpcProxy, need refactoring.
         async_queues = []
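
This hunk also names the daemon thread that owns the executor's private event loop, which makes it identifiable in debuggers and thread dumps. A minimal sketch of the run-an-event-loop-on-a-background-thread arrangement used here (names are illustrative):

import asyncio
import threading

def start_background_loop(coro, name="ray_executor_main_loop"):
    # The thread creates and owns its event loop, mirroring _run_main_loop_task.
    def _run():
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(coro)
        except asyncio.CancelledError:
            pass  # expected when the task is cancelled during shutdown
        finally:
            loop.close()

    thread = threading.Thread(target=_run, daemon=True, name=name)
    thread.start()
    return thread

async def main_loop_task():
    await asyncio.sleep(0.1)  # stand-in for the response-fetching loop

start_background_loop(main_loop_task()).join()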
@@ -209,6 +220,7 @@ def process_res(res: list):
                 if isinstance(queue, _SyncQueue):
                     queue.put_nowait(r)
                     async_queues.append(queue)
+                    # all the loops are identical
                     event_loop = event_loop or queue.loop
                 else:
                     queue.put(r)
@@ -217,7 +229,9 @@ def process_res(res: list):
                         r, ErrorResponse):
                     self._results.pop(client_id)
 
+        # Handle the case where responses might not be a list of lists
         if responses and not isinstance(responses[0], list):
+            # If responses is a flat list, wrap it
             responses = [responses]
 
         for res in responses:
@@ -314,10 +328,9 @@ def submit(self, request: GenerationRequest) -> GenerationResult:
         if self.use_rpc:
             with nvtx_range_debug("rpc_submit"):
                 self.rpc_client.submit(request).remote(need_response=False)
-                print(
-                    f"[RPC] RayExecutor submit done for request {request.id}")
+                # print(
+                #     f"[RPC] RayExecutor submit done for request {request.id}")
 
-        # TODO. use the future return by BaseWorker submit
         result = GenerationResult(
             request,
             background_error_handler=self._handle_background_error,
@@ -355,6 +368,13 @@ def abort_request(self, request_id: int) -> None:
                                             request_id=request_id)
 
     def shutdown(self):
+        try:
+            self.shutdown_impl()
+        except Exception as e:
+            print(f"Error shutting down RayExecutor: {e}")
+            raise e
+
+    def shutdown_impl(self):
         if self.use_rpc:
             if self._shutdown_event.is_set():
                 return
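
Since shutdown is registered with atexit, an exception raised there can otherwise disappear silently; the new wrapper logs the failure before re-raising, while the implementation stays idempotent behind the shutdown event. A minimal sketch of that split (the class and names are illustrative):

import threading

class ExecutorSketch:
    def __init__(self):
        self._shutdown_event = threading.Event()

    def shutdown(self):
        # Thin wrapper: make atexit-time failures visible, then propagate.
        try:
            self.shutdown_impl()
        except Exception as e:
            print(f"Error shutting down ExecutorSketch: {e}")
            raise

    def shutdown_impl(self):
        if self._shutdown_event.is_set():
            return  # idempotent: repeated calls are no-ops
        self._shutdown_event.set()
        # ... close the RPC client, join worker threads, etc.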

tensorrt_llm/executor/ray_gpu_worker.py

Lines changed: 35 additions & 29 deletions
@@ -4,12 +4,12 @@
 from pathlib import Path
 from queue import Queue
 from threading import Event
-from typing import Any, Optional, Type, Union
+from typing import Any, AsyncGenerator, Optional, Type, Union
 
 import ray
 import torch
 
-from .._utils import mpi_rank, ray_use_rpc
+from .._utils import nvtx_range_debug, ray_use_rpc
 from ..bindings import executor as tllm
 from ..builder import Engine
 from ..llmapi.llm_args import BaseLlmArgs
@@ -236,53 +236,59 @@ def enqueue_request(self,
                         request: GenerationRequest,
                         result_wait_queue: Queue | None = None) -> int:
         # TODO. remove this. originally we didn't have to handle all the req id dict
+        # raise ValueError("enqueue_request should not be called.")
         return self._enqueue_request(request, result_wait_queue)
 
     def submit(self, request: GenerationRequest):
-        print(f"RayGPUWorker {self.rank} submitted request {request.id}")
         return super().submit(request)
 
-    async def fetch_responses_async(self,
-                                    timeout: Optional[float] = None) -> list:
+    def fetch_responses(self, timeout: Optional[float] = None) -> list:
         # TODO copied from RpcWorker, need refactoring.
-        logger_debug(f"RayGPUWorker {mpi_rank()} is fetching responses async",
+        logger_debug(f"RayGPUWorker {self.rank} is fetching responses",
                      color="yellow")
-
-        responses = await asyncio.to_thread(self.await_responses,
-                                            timeout=timeout)
-        if self._await_response_helper:
+        with nvtx_range_debug("RayGPUWorker.fetch_responses",
+                              color="orange",
+                              category="Worker"):
+            # NOTE: This is a blocking call, it will wait for the responses to be available.
+            responses = super().await_responses(timeout)
             self._await_response_helper.responses_handler(responses)
 
-        if hasattr(self,
-                   '_response_queue') and self._response_queue is not None:
-            qsize = self._response_queue.qsize()
-            logger_debug(f"RayGPUWorker returning {qsize} responses",
-                         color="yellow")
+            qsize = self._response_queue.qsize()
+            logger_debug(f"RayGPUWorker returning {qsize} responses", color="yellow")
+
+            all_responses = []
+            for _ in range(qsize):
+                # The queue contains batches of responses, so extend the list
+                all_responses.extend(self._response_queue.get())
+            return all_responses
 
-            all_responses = []
-            for _ in range(qsize):
-                all_responses.extend(self._response_queue.get())
-            return all_responses
+    async def fetch_responses_async(self,
+                                    timeout: Optional[float] = None) -> list:
+        # TODO copied from RpcWorker, need refactoring.
+        # A really async version of fetch_responses
+        logger_debug(f"RayGPUWorker {self.rank} is fetching responses async",
+                     color="yellow")
 
-        return responses if responses else []
+        # First, await any pending responses without blocking the event loop
+        responses = await asyncio.to_thread(self.fetch_responses,
+                                            timeout=timeout)
+        return responses
 
     # for streaming performance
-    async def fetch_responses_loop_async(self):
+    async def fetch_responses_loop_async(self) -> AsyncGenerator[list, None]:
         # TODO copied from RpcWorker, need refactoring.
-        shutdown_event = getattr(self, 'shutdown_event', Event())
-
-        while not shutdown_event.is_set():
+        while not self.shutdown_event.is_set():
             responses = await self.fetch_responses_async()
-            if responses:
+            if responses:  # Only yield if there are actual responses
                 logger_debug(
-                    f"RayGPUWorker {mpi_rank()} yielding responses: {responses}",
+                    f"RayGPUWorker {self.rank} is yielding responses: {responses}",
                     color="yellow")
-                yield responses
+                yield responses  # batching the responses to opt IPC performance
             else:
+                # Small delay to prevent busy waiting when no responses
                 await asyncio.sleep(0)
-
         logger_debug(
-            f"RayGPUWorker {mpi_rank()} quitting fetch_responses_loop_async",
+            f"RayGPUWorker {self.rank} quitting fetch_responses_loop_async",
             color="yellow")
 
     def shutdown(self):
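
The reworked worker splits fetching into a blocking fetch_responses, which drains the queue of response batches into one flat list, and an async wrapper that off-loads it via asyncio.to_thread so the event loop stays responsive. A condensed, self-contained sketch of that split (the module-level queue and event are illustrative):

import asyncio
import queue
import threading

response_queue: "queue.Queue[list]" = queue.Queue()
shutdown_event = threading.Event()

def fetch_responses() -> list:
    # Drain whatever batches are queued right now into one flat list.
    flat = []
    for _ in range(response_queue.qsize()):
        flat.extend(response_queue.get())
    return flat

async def fetch_responses_loop():
    # Off-load the blocking drain to a thread; yield only non-empty batches.
    while not shutdown_event.is_set():
        responses = await asyncio.to_thread(fetch_responses)
        if responses:
            yield responses  # one IPC message per batch, not per response
        else:
            await asyncio.sleep(0)  # avoid a tight busy-wait

async def demo():
    response_queue.put([1, 2])
    response_queue.put([3])
    stream = fetch_responses_loop()
    print(await anext(stream))  # [1, 2, 3]
    shutdown_event.set()
    await stream.aclose()

asyncio.run(demo())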
