Merged
7 changes: 7 additions & 0 deletions tensorrt_llm/_utils.py
@@ -524,6 +524,13 @@ def mpi_disabled() -> bool:
return os.environ.get("TLLM_DISABLE_MPI") == "1"


def ray_use_rpc() -> bool:
"""True if TLLM_RAY_USE_RPC is set to "1", False otherwise.
# TODO: deprecate this once Ray is fully moved to use RPC client/server.
"""
return os.environ.get("TLLM_RAY_USE_RPC") == "1"


def mpi_rank():
if mpi_disabled():
try:
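For illustration, a minimal usage sketch of the new helper (a hypothetical snippet, not part of the diff), showing how the environment variable toggles the result:

import os

from tensorrt_llm._utils import ray_use_rpc

# Opt in to the RPC client/server path for the Ray executor.
os.environ["TLLM_RAY_USE_RPC"] = "1"
assert ray_use_rpc() is True

# Any other value (or an unset variable) keeps the default Ray-queue path.
os.environ.pop("TLLM_RAY_USE_RPC", None)
assert ray_use_rpc() is False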
194 changes: 135 additions & 59 deletions tensorrt_llm/executor/ray_executor.py
@@ -13,22 +13,23 @@
placement_group)

from tensorrt_llm._ray_utils import unwrap_ray_errors
from tensorrt_llm._utils import get_free_port
from tensorrt_llm._utils import get_free_port, nvtx_range_debug, ray_use_rpc
from tensorrt_llm.logger import logger

from .._utils import nvtx_range_debug
from ..llmapi.utils import logger_debug
from .executor import GenerationExecutor
from .postproc_worker import PostprocWorkerConfig
from .ray_gpu_worker import RayGPUWorker, RayWorkerWrapper
from .request import GenerationRequest
from .result import GenerationResult, RayAsyncQueue, RaySyncQueue
from .rpc_proxy import RpcExecutorMixin

__all__ = [
"RayExecutor",
]


class RayExecutor(GenerationExecutor):
class RayExecutor(RpcExecutorMixin, GenerationExecutor):

def __init__(self,
worker_kwargs: Dict,
@@ -75,44 +76,44 @@ def __init__(self,
self.tp_size = tp_size
self.master_address = ray.util.get_node_ip_address()
self.master_port = get_free_port()

self.response_queue = RayAsyncQueue.options(runtime_env={
"env_vars": {
"TLLM_DISABLE_MPI": "1"
}
}).remote()
self.response_sync_queue = RaySyncQueue.options(runtime_env={
"env_vars": {
"TLLM_DISABLE_MPI": "1"
}
}).remote()
self.async_response_queue_weakref = self.create_actor_weak_ref(
self.response_queue)
self.sync_response_queue_weakref = self.create_actor_weak_ref(
self.response_sync_queue)
self.response_queue.warmup.remote()
self.response_sync_queue.warmup.remote()
self.use_rpc = ray_use_rpc()

worker_kwargs = dict(**worker_kwargs,
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor)

self.create_workers(RayGPUWorker, worker_kwargs)
if self.use_rpc:
self.init_rpc_executor()
worker_kwargs['rpc_addr'] = self.rpc_addr
self.create_workers(RayGPUWorker, worker_kwargs)
self.setup_engine_remote()
self.setup_mainloop(tasks=[self._fetch_responses_loop_async],
thread_name="ray_executor_main_loop")
logger.info(f"Connecting to RPC server at {self.rpc_addr}")
else:
self.response_queue = RayAsyncQueue.options(runtime_env={
"env_vars": {
"TLLM_DISABLE_MPI": "1"
}
}).remote()
self.response_sync_queue = RaySyncQueue.options(runtime_env={
"env_vars": {
"TLLM_DISABLE_MPI": "1"
}
}).remote()
self.async_response_queue_weakref = self.create_actor_weak_ref(
self.response_queue)
self.sync_response_queue_weakref = self.create_actor_weak_ref(
self.response_sync_queue)
self.response_queue.warmup.remote()
self.response_sync_queue.warmup.remote()
self.create_workers(RayGPUWorker, worker_kwargs)

except Exception as e:
# Clean up the Ray resources early during exception
self.shutdown()
logger.error(f"Failed to initialize RayExecutor: {e}")
raise e

@staticmethod
def create_actor_weak_ref(actor_handle: ray.actor.ActorHandle):
state, _, _ = actor_handle._serialization_helper()
return ray.actor.ActorHandle._deserialization_helper(state,
weak_ref=True)

def use_ray_queue(self) -> bool:
return True

def create_workers(self, worker_cls, worker_kwargs):
# When set to be a fraction, it allows Ray to schedule
# multiple actors on a single GPU for colocate use cases.
@@ -188,49 +189,118 @@ def collective_rpc(self,
**kwargs))
return refs if non_block else ray.get(refs)

def submit(self, request: GenerationRequest) -> GenerationResult:
def submit(self, request: "GenerationRequest") -> "GenerationResult":
"""
Low-level API to the executor. Return a "future" GenerationResult
which can be waited.
Forwards the request to the workers through the request queue.
Low-level API to the executor. Return a "future" GenerationResult
which can be waited.
Forwards the request to the workers through RPC or Ray queues depending on mode.
"""
request.set_id(self._get_next_client_id())
logprob_params = self._get_logprob_params(request)

result = GenerationResult(
request,
background_error_handler=self._handle_background_error,
executor=self,
disaggregated_params=request.disaggregated_params,
logprob_params=logprob_params)

with nvtx_range_debug("request_queue.put"):
self.call_all_ray_workers("enqueue_request",
leader_only=True,
request=request,
async_call=True,
result_wait_queue=result.queue)
if self.use_rpc:
with nvtx_range_debug("rpc_submit"):
self.rpc_client.submit(request).remote(need_response=False)

result = GenerationResult(
request,
background_error_handler=self._handle_background_error,
executor=self,
disaggregated_params=request.disaggregated_params,
logprob_params=logprob_params)
self._results[request.id] = result
else:
result = GenerationResult(
request,
background_error_handler=self._handle_background_error,
executor=self,
disaggregated_params=request.disaggregated_params,
logprob_params=logprob_params)

with nvtx_range_debug("request_queue.put"):
self.call_all_ray_workers("enqueue_request",
leader_only=True,
request=request,
async_call=True,
result_wait_queue=result.queue)

return result
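A hedged caller-side sketch of the contract described in the docstring above (hypothetical; the construction of the GenerationRequest and the exact waiting API on GenerationResult are assumptions and therefore elided):

# Hypothetical usage sketch, not part of this diff.
# `executor` is an initialized RayExecutor and `request` a pre-built GenerationRequest.
result = executor.submit(request)
# submit() returns immediately with a future-like GenerationResult. In RPC mode the
# request travels over self.rpc_client; otherwise it is enqueued on the leader worker
# with result.queue attached so responses can be routed back.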

def start(self):
pass

def setup_engine_remote(self):
return self.collective_rpc("setup_engine", non_block=False)

def report_device_ids(self) -> list[str]:
gpu_ids = self.call_all_ray_workers("report_device_id",
leader_only=False,
async_call=False)
return sorted(gpu_ids)

def use_ray_queue(self) -> bool:
return not self.use_rpc

def abort_request(self, request_id: int) -> None:
self.call_all_ray_workers("abort_request",
leader_only=True,
async_call=False,
request_id=request_id)

def shutdown(self):
# Release actors
self.response_queue = None
self.response_sync_queue = None
self.async_response_queue_weakref = None
self.sync_response_queue_weakref = None
if hasattr(self, '_shutdown_event') and self._shutdown_event.is_set():
return
if hasattr(self, '_shutdown_event'):
self._shutdown_event.set()

mode_str = "RPC mode" if self.use_rpc else "Ray queue mode"
logger_debug(f"Shutting down RayExecutor ({mode_str})", color="yellow")

if self.use_rpc:
if hasattr(self, 'main_loop') and self.main_loop and hasattr(
self, 'main_loop_task_obj') and self.main_loop_task_obj:
logger_debug("Cancelling main loop task.", color="yellow")
try:
self.main_loop.call_soon_threadsafe(
self.main_loop_task_obj.cancel)
except Exception as e:
logger_debug(f"Error cancelling main loop task: {e}",
color="yellow")

if hasattr(self, 'main_loop_thread'):
self.main_loop_thread.join()

# Then, shutdown the workers
if hasattr(self, 'workers') and self.workers is not None:
try:
logger_debug("Shutting down RPC remote", color="yellow")
shutdown_refs = [
worker.shutdown.remote() for worker in self.workers
]
# Add timeout to prevent indefinite hanging
ray.get(shutdown_refs, timeout=30.0)
except ray.exceptions.GetTimeoutError:
logger.warning(
"Timeout waiting for workers to shutdown after 30 seconds"
)
except Exception as e:
logger.warning(f"Error shutting down RPC remote: {e}")

if hasattr(self, 'rpc_client') and self.rpc_client is not None:
try:
self.rpc_client.close()
except Exception as e:
# Suppress errors during RPC client shutdown
# These can occur if the client is already closed or if there are
# pending operations that get cancelled during cleanup
logger_debug(
f"Suppressed error during RPC client close: {e}")
else:
# Release actors
self.response_queue = None
self.response_sync_queue = None
self.async_response_queue_weakref = None
self.sync_response_queue_weakref = None

self.workers = None
if hasattr(self,
@@ -246,12 +316,6 @@ def shutdown(self):
logger.debug("Shutting down Ray cluster")
ray.shutdown()
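As a standalone illustration of the timeout-guarded worker shutdown used above, a small self-contained sketch (hypothetical actor name; assumes only that Ray is installed):

import ray

@ray.remote
class DemoWorker:
    def shutdown(self):
        return "done"

workers = [DemoWorker.remote() for _ in range(2)]
try:
    # Bound the wait so a stuck worker cannot hang the caller indefinitely.
    ray.get([w.shutdown.remote() for w in workers], timeout=30.0)
except ray.exceptions.GetTimeoutError:
    print("Timed out waiting for workers to shut down")
ray.shutdown()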

@property
def enable_postprocess_parallel(self) -> bool:
ret = super().enable_postprocess_parallel
assert ret == False, "Postprocess parallel is not supported in RayExecutor"
return ret

def _get_placement_group(self,
tp_size: int) -> Tuple[PlacementGroup, List[int]]:
"""
@@ -317,3 +381,15 @@ def _get_placement_group(self,
pg = placement_group(bundles, strategy=strategy)

return pg, bundle_indices

@property
def enable_postprocess_parallel(self) -> bool:
ret = super().enable_postprocess_parallel
assert ret == False, "Postprocess parallel is not supported in RayExecutor"
return ret

@staticmethod
def create_actor_weak_ref(actor_handle: ray.actor.ActorHandle):
state, _, _ = actor_handle._serialization_helper()
return ray.actor.ActorHandle._deserialization_helper(state,
weak_ref=True)
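Finally, a short sketch of how the weak actor handle above might be produced (hypothetical usage; assumes a running Ray session and the import paths shown in this file):

import ray

from tensorrt_llm.executor.ray_executor import RayExecutor
from tensorrt_llm.executor.result import RayAsyncQueue

ray.init(ignore_reinit_error=True)

# Create the response-queue actor, then derive a weak handle that can be
# passed around without extending the actor's lifetime.
queue = RayAsyncQueue.options(
    runtime_env={"env_vars": {"TLLM_DISABLE_MPI": "1"}}).remote()
weak_queue = RayExecutor.create_actor_weak_ref(queue)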