
Commit 06353fb ("update")
Parent: 177eb88

7 files changed: +187 −82 lines

tensorrt_llm/bench/dataclasses/reporting.py
Lines changed: 7 additions & 1 deletion

@@ -1,6 +1,7 @@
 from __future__ import annotations

 import json
+import os
 from collections import defaultdict
 from typing import Any, Dict, List, NamedTuple

@@ -507,9 +508,14 @@ def report_statistics(self) -> None:
                 f"Max Sequence Length:\t{build_cfg['max_seq_len']}\n"
                 f"\n")
         else:
+            # Check MPI vs RAY and RPC status
+            comm_backend = "RAY" if os.environ.get(
+                "TLLM_DISABLE_MPI") == "1" else "MPI"
+            ray_status = "[RPC]" if os.environ.get(
+                "TLLM_RAY_USE_RPC") == "1" else "[original]"
             backend_info = (
                 "\n\n===========================================================\n"
-                "= PYTORCH BACKEND\n"
+                f"= PYTORCH BACKEND [{comm_backend}] {ray_status}\n"
                 "===========================================================\n"
                 f"Model:\t\t\t{engine['model']}\n"
                 f"Model Path:\t\t{engine['model_path']}\n"

tensorrt_llm/executor/base_worker.py
Lines changed: 3 additions & 4 deletions

@@ -13,7 +13,7 @@

 from .._torch.pyexecutor.llm_request import LlmResponse
 from .._utils import (global_mpi_rank, global_mpi_size, mpi_comm, mpi_rank,
-                      nvtx_range_debug)
+                      nvtx_range_debug, ray_use_rpc)
 from ..bindings import executor as tllm
 from ..builder import ConfigEncoder, Engine, EngineConfig
 from ..llmapi.llm_args import BaseLlmArgs, PybindMirror
@@ -523,9 +523,8 @@ def _deduce_max_tokens(request: GenerationRequest,

     def submit(self, request: GenerationRequest) -> GenerationResult:
         """ Low-level API to the executor. Return a "future" GenerationResult which can be waited. """
-        # TODO Need fix. This is a good way to catch the poor error propagation issue,
-        # e.g., RayGPUWorker doesn't define start(), but it won't error out and will appear as a hang.
-        # self.start()
+        # TODO Use this to test the error propagation issue with RayExecutor.
+        self.start()

         if self.rank != 0:
             raise RuntimeError(
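
The re-enabled self.start() call is what the removed comment describes: without it, a worker that never defines start() can hang silently instead of failing. A toy sketch of that fail-fast behavior follows, using our own stand-in classes rather than TensorRT-LLM code; the no-op start() methods added to RayExecutor and RayGPUWorker elsewhere in this commit are what satisfy this call.

# Toy classes (not TRT-LLM code) showing why calling start() inside submit()
# surfaces a missing start() as an immediate error rather than a silent hang.
class BaseWorkerSketch:
    def submit(self, request):
        self.start()  # raises AttributeError if the subclass never defined start()
        return f"submitted {request}"

class GoodWorker(BaseWorkerSketch):
    def start(self):
        pass

class ForgetfulWorker(BaseWorkerSketch):
    pass  # no start(): previously this showed up later as a hang

print(GoodWorker().submit("req-1"))       # submitted req-1
try:
    ForgetfulWorker().submit("req-2")
except AttributeError as e:
    print(f"caught early: {e}")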

tensorrt_llm/executor/ray_executor.py
Lines changed: 29 additions & 6 deletions

@@ -89,7 +89,7 @@ def __init__(self,
         if self.use_rpc:
             self.rpc_addr = get_unique_ipc_addr()
             self.rpc_client = RPCClient(self.rpc_addr)
-            print(f"RPC client created at {self.rpc_addr}")
+            print(f"====RPC client created at {self.rpc_addr}")

         self._results = {}
         self._shutdown_event = threading.Event()
@@ -144,6 +144,13 @@ def use_ray_queue(self) -> bool:
     async def _generic_fetch_loop_async(self, fetch_method_name: str,
                                         handler_method, method_name: str):
         # TODO copied from GenerationExecutorRpcProxy, need refactoring.
+        """Generic method for fetching data in a loop from the RPC worker.
+
+        Args:
+            fetch_method_name: Name of the RPC client method to call
+            handler_method: The handler method to call with the fetched data
+            method_name: Name of the method for logging
+        """
         try:
             fetch_method = getattr(self.rpc_client, fetch_method_name)
             async for data in fetch_method().remote_streaming():
@@ -169,6 +176,7 @@ async def main_loop_task():
             await self._fetch_responses_loop_async()

         def _run_main_loop_task():
+            """Local method to run the main loop task."""
             self.main_loop = asyncio.new_event_loop()
             asyncio.set_event_loop(self.main_loop)

@@ -177,12 +185,13 @@ def _run_main_loop_task():
             try:
                 self.main_loop.run_until_complete(self.main_loop_task_obj)
             except asyncio.CancelledError:
-                pass
+                pass  # Task cancellation is expected during shutdown
             finally:
                 self.main_loop.close()

         self.main_loop_thread = threading.Thread(target=_run_main_loop_task,
-                                                 daemon=True)
+                                                 daemon=True,
+                                                 name="ray_executor_main_loop")
         self.main_loop_thread.start()
         atexit.register(self.shutdown)

@@ -209,6 +218,7 @@ def process_res(res: list):
                 if isinstance(queue, _SyncQueue):
                     queue.put_nowait(r)
                     async_queues.append(queue)
+                    # all the loops are identical
                     event_loop = event_loop or queue.loop
                 else:
                     queue.put(r)
@@ -217,7 +227,9 @@ def process_res(res: list):
                         r, ErrorResponse):
                     self._results.pop(client_id)

+        # Handle the case where responses might not be a list of lists
         if responses and not isinstance(responses[0], list):
+            # If responses is a flat list, wrap it
             responses = [responses]

         for res in responses:
@@ -314,10 +326,9 @@ def submit(self, request: GenerationRequest) -> GenerationResult:
         if self.use_rpc:
             with nvtx_range_debug("rpc_submit"):
                 self.rpc_client.submit(request).remote(need_response=False)
-            print(
-                f"[RPC] RayExecutor submit done for request {request.id}")
+            # print(
+            #     f"[RPC] RayExecutor submit done for request {request.id}")

-        # TODO. use the future returned by BaseWorker.submit
         result = GenerationResult(
             request,
             background_error_handler=self._handle_background_error,
@@ -342,6 +353,9 @@ def submit(self, request: GenerationRequest) -> GenerationResult:

         return result

+    def start(self):
+        pass
+
     def report_device_ids(self) -> list[str]:
         gpu_ids = self.call_all_ray_workers("report_device_id",
                                             leader_only=False,
@@ -354,7 +368,16 @@ def abort_request(self, request_id: int) -> None:
                                   async_call=False,
                                   request_id=request_id)

+    # TODO: Use Ray RPC to shut down the RPC server, and then close the client
     def shutdown(self):
+        try:
+            self.shutdown_impl()
+        except Exception as e:
+            # TODO: clean up
+            print(f"Error shutting down RayExecutor: {e}")
+            raise e
+
+    def shutdown_impl(self):
         if self.use_rpc:
             if self._shutdown_event.is_set():
                 return
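
The newly documented _generic_fetch_loop_async follows a common pattern: look up an RPC client method by name, iterate its async stream, and hand each batch to a handler until cancellation. Here is a self-contained sketch of that pattern with a stub client; it omits the real RPCClient and its .remote_streaming() wrapper, which the actual code uses.

import asyncio

class StubRpcClient:
    # Stand-in for the real RPC client: an async generator of response batches.
    async def fetch_responses_loop_async(self):
        for batch in ([1, 2], [3], []):
            if batch:  # mirror the worker: only non-empty batches are yielded
                yield batch
            await asyncio.sleep(0)

async def generic_fetch_loop_async(client, fetch_method_name, handler, method_name):
    try:
        # Resolve the streaming method by name and dispatch each batch.
        fetch_method = getattr(client, fetch_method_name)
        async for data in fetch_method():
            handler(data)
    except asyncio.CancelledError:
        print(f"{method_name} cancelled during shutdown")  # expected on shutdown

asyncio.run(generic_fetch_loop_async(
    StubRpcClient(), "fetch_responses_loop_async",
    handler=lambda batch: print("handled", batch),
    method_name="fetch_responses_loop"))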

tensorrt_llm/executor/ray_gpu_worker.py
Lines changed: 39 additions & 29 deletions

@@ -4,12 +4,12 @@
 from pathlib import Path
 from queue import Queue
 from threading import Event
-from typing import Any, Optional, Type, Union
+from typing import Any, AsyncGenerator, Optional, Type, Union

 import ray
 import torch

-from .._utils import mpi_rank, ray_use_rpc
+from .._utils import nvtx_range_debug, ray_use_rpc
 from ..bindings import executor as tllm
 from ..builder import Engine
 from ..llmapi.llm_args import BaseLlmArgs
@@ -236,53 +236,63 @@ def enqueue_request(self,
                         request: GenerationRequest,
                         result_wait_queue: Queue | None = None) -> int:
         # TODO. remove this. Originally we didn't have to handle all the request id dicts.
+        # raise ValueError("enqueue_request should not be called.")
         return self._enqueue_request(request, result_wait_queue)

+    def start(self):
+        pass
+
     def submit(self, request: GenerationRequest):
-        print(f"RayGPUWorker {self.rank} submitted request {request.id}")
         return super().submit(request)

-    async def fetch_responses_async(self,
-                                    timeout: Optional[float] = None) -> list:
+    def fetch_responses(self, timeout: Optional[float] = None) -> list:
         # TODO copied from RpcWorker, need refactoring.
-        logger_debug(f"RayGPUWorker {mpi_rank()} is fetching responses async",
+        logger_debug(f"RayGPUWorker {self.rank} is fetching responses",
                      color="yellow")
-
-        responses = await asyncio.to_thread(self.await_responses,
-                                            timeout=timeout)
-        if self._await_response_helper:
+        with nvtx_range_debug("RayGPUWorker.fetch_responses",
+                              color="orange",
+                              category="Worker"):
+            # NOTE: This is a blocking call; it will wait for the responses to be available.
+            responses = super().await_responses(timeout)
             self._await_response_helper.responses_handler(responses)

-        if hasattr(self,
-                   '_response_queue') and self._response_queue is not None:
-            qsize = self._response_queue.qsize()
-            logger_debug(f"RayGPUWorker returning {qsize} responses",
-                         color="yellow")
+        qsize = self._response_queue.qsize()
+        logger_debug(f"RayGPUWorker returning {qsize} responses",
+                     color="yellow")
+
+        all_responses = []
+        for _ in range(qsize):
+            # The queue contains batches of responses, so extend the list
+            all_responses.extend(self._response_queue.get())
+        return all_responses

-            all_responses = []
-            for _ in range(qsize):
-                all_responses.extend(self._response_queue.get())
-            return all_responses
+    async def fetch_responses_async(self,
+                                    timeout: Optional[float] = None) -> list:
+        # TODO copied from RpcWorker, need refactoring.
+        # A truly async version of fetch_responses
+        logger_debug(f"RayGPUWorker {self.rank} is fetching responses async",
+                     color="yellow")

-        return responses if responses else []
+        # First, await any pending responses without blocking the event loop
+        responses = await asyncio.to_thread(self.fetch_responses,
+                                            timeout=timeout)
+        return responses

     # for streaming performance
-    async def fetch_responses_loop_async(self):
+    async def fetch_responses_loop_async(self) -> AsyncGenerator[list, None]:
         # TODO copied from RpcWorker, need refactoring.
-        shutdown_event = getattr(self, 'shutdown_event', Event())
-
-        while not shutdown_event.is_set():
+        while not self.shutdown_event.is_set():
             responses = await self.fetch_responses_async()
-            if responses:
+            if responses:  # Only yield if there are actual responses
                 logger_debug(
-                    f"RayGPUWorker {mpi_rank()} yielding responses: {responses}",
+                    f"RayGPUWorker {self.rank} is yielding responses: {responses}",
                     color="yellow")
-                yield responses
+                yield responses  # batch the responses to optimize IPC performance
             else:
+                # Small delay to prevent busy waiting when no responses
                 await asyncio.sleep(0)
-
         logger_debug(
-            f"RayGPUWorker {mpi_rank()} quitting fetch_responses_loop_async",
+            f"RayGPUWorker {self.rank} quitting fetch_responses_loop_async",
             color="yellow")

     def shutdown(self):
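
The reworked worker splits response fetching into a blocking fetch_responses that drains the response queue into one flat list, and an async wrapper that offloads it with asyncio.to_thread so the event loop stays responsive. A toy sketch of that split follows, using a stand-in worker rather than the real RayGPUWorker.

import asyncio
import queue

class ToyWorker:
    def __init__(self):
        self._response_queue = queue.Queue()
        for batch in ([1, 2], [3, 4, 5]):  # pretend two response batches arrived
            self._response_queue.put(batch)

    def fetch_responses(self) -> list:
        # Blocking path: drain whatever batches are queued right now into a flat list.
        all_responses = []
        for _ in range(self._response_queue.qsize()):
            all_responses.extend(self._response_queue.get())
        return all_responses

    async def fetch_responses_async(self) -> list:
        # Run the blocking drain in a worker thread so the event loop is never blocked.
        return await asyncio.to_thread(self.fetch_responses)

print(asyncio.run(ToyWorker().fetch_responses_async()))  # [1, 2, 3, 4, 5]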

tensorrt_llm/executor/result.py
Lines changed: 5 additions & 1 deletion

@@ -20,7 +20,7 @@
     from tensorrt_llm import ray_stub as ray

 from .._ray_utils import unwrap_ray_errors
-from .._utils import mpi_disabled, nvtx_range_debug
+from .._utils import mpi_disabled, nvtx_range_debug, ray_use_rpc
 from ..bindings import executor as tllm
 from ..disaggregated_params import DisaggregatedParams
 from ..llmapi.tracer import global_tracer
@@ -171,6 +171,8 @@ class RayAsyncQueue:
     """Ray actor for async response handling."""

     def __init__(self):
+        if ray_use_rpc():
+            raise ValueError("RayAsyncQueue should not be used with RPC mode")
         self.data = {}
         self.event_map = {}
         self.warmup_done = False
@@ -215,6 +217,8 @@ class RaySyncQueue:
     """Ray actor for sync response handling."""

     def __init__(self):
+        if ray_use_rpc():
+            raise ValueError("RaySyncQueue should not be used with RPC mode")
         self.data = {}
         self.event_map = {}
         self.semaphore = threading.Semaphore(SYNC_QUEUE_MAX_CONCURRENCY - 1)
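
Both queue actors now guard their constructors so a misconfigured setup fails at actor creation rather than later during streaming. A small sketch of the same guard with hypothetical names is shown below; the real ray_use_rpc() lives in tensorrt_llm._utils, and here it is assumed to read the TLLM_RAY_USE_RPC variable that the rest of this commit checks.

import os

def ray_use_rpc_assumed() -> bool:
    # Assumption: mirrors the TLLM_RAY_USE_RPC switch used elsewhere in this diff.
    return os.environ.get("TLLM_RAY_USE_RPC") == "1"

class ToyAsyncQueue:
    def __init__(self):
        # Refuse construction when the RPC path is selected.
        if ray_use_rpc_assumed():
            raise ValueError("ToyAsyncQueue should not be used with RPC mode")
        self.data = {}

os.environ["TLLM_RAY_USE_RPC"] = "1"
try:
    ToyAsyncQueue()
except ValueError as e:
    print(e)  # ToyAsyncQueue should not be used with RPC mode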
