diff --git a/tensorrt_llm/_utils.py b/tensorrt_llm/_utils.py index 68229e4150d..189b96d8d66 100644 --- a/tensorrt_llm/_utils.py +++ b/tensorrt_llm/_utils.py @@ -524,6 +524,13 @@ def mpi_disabled() -> bool: return os.environ.get("TLLM_DISABLE_MPI") == "1" +def ray_use_rpc() -> bool: + """True if TLLM_RAY_USE_RPC is set to "1", False otherwise. + # TODO: deprecate this once Ray is fully moved to use RPC client/server. + """ + return os.environ.get("TLLM_RAY_USE_RPC") == "1" + + def mpi_rank(): if mpi_disabled(): try: diff --git a/tensorrt_llm/executor/ray_executor.py b/tensorrt_llm/executor/ray_executor.py index e0c810d7565..ad8b838217e 100644 --- a/tensorrt_llm/executor/ray_executor.py +++ b/tensorrt_llm/executor/ray_executor.py @@ -13,22 +13,23 @@ placement_group) from tensorrt_llm._ray_utils import unwrap_ray_errors -from tensorrt_llm._utils import get_free_port +from tensorrt_llm._utils import get_free_port, nvtx_range_debug, ray_use_rpc from tensorrt_llm.logger import logger -from .._utils import nvtx_range_debug +from ..llmapi.utils import logger_debug from .executor import GenerationExecutor from .postproc_worker import PostprocWorkerConfig from .ray_gpu_worker import RayGPUWorker, RayWorkerWrapper from .request import GenerationRequest from .result import GenerationResult, RayAsyncQueue, RaySyncQueue +from .rpc_proxy import RpcExecutorMixin __all__ = [ "RayExecutor", ] -class RayExecutor(GenerationExecutor): +class RayExecutor(RpcExecutorMixin, GenerationExecutor): def __init__(self, worker_kwargs: Dict, @@ -75,44 +76,44 @@ def __init__(self, self.tp_size = tp_size self.master_address = ray.util.get_node_ip_address() self.master_port = get_free_port() - - self.response_queue = RayAsyncQueue.options(runtime_env={ - "env_vars": { - "TLLM_DISABLE_MPI": "1" - } - }).remote() - self.response_sync_queue = RaySyncQueue.options(runtime_env={ - "env_vars": { - "TLLM_DISABLE_MPI": "1" - } - }).remote() - self.async_response_queue_weakref = self.create_actor_weak_ref( - self.response_queue) - self.sync_response_queue_weakref = self.create_actor_weak_ref( - self.response_sync_queue) - self.response_queue.warmup.remote() - self.response_sync_queue.warmup.remote() + self.use_rpc = ray_use_rpc() worker_kwargs = dict(**worker_kwargs, postproc_worker_config=postproc_worker_config, is_llm_executor=is_llm_executor) - self.create_workers(RayGPUWorker, worker_kwargs) + if self.use_rpc: + self.init_rpc_executor() + worker_kwargs['rpc_addr'] = self.rpc_addr + self.create_workers(RayGPUWorker, worker_kwargs) + self.setup_engine_remote() + self.setup_mainloop(tasks=[self._fetch_responses_loop_async], + thread_name="ray_executor_main_loop") + logger.info(f"Connecting to RPC server at {self.rpc_addr}") + else: + self.response_queue = RayAsyncQueue.options(runtime_env={ + "env_vars": { + "TLLM_DISABLE_MPI": "1" + } + }).remote() + self.response_sync_queue = RaySyncQueue.options(runtime_env={ + "env_vars": { + "TLLM_DISABLE_MPI": "1" + } + }).remote() + self.async_response_queue_weakref = self.create_actor_weak_ref( + self.response_queue) + self.sync_response_queue_weakref = self.create_actor_weak_ref( + self.response_sync_queue) + self.response_queue.warmup.remote() + self.response_sync_queue.warmup.remote() + self.create_workers(RayGPUWorker, worker_kwargs) + except Exception as e: - # Clean up the Ray resources early during exception self.shutdown() logger.error(f"Failed to initialize RayExecutor: {e}") raise e - @staticmethod - def create_actor_weak_ref(actor_handle: ray.actor.ActorHandle): - state, _, _ = 
actor_handle._serialization_helper() - return ray.actor.ActorHandle._deserialization_helper(state, - weak_ref=True) - - def use_ray_queue(self) -> bool: - return True - def create_workers(self, worker_cls, worker_kwargs): # When set to be a fraction, it allows Ray to schedule # multiple actors on a single GPU for colocate use cases. @@ -188,37 +189,58 @@ def collective_rpc(self, **kwargs)) return refs if non_block else ray.get(refs) - def submit(self, request: GenerationRequest) -> GenerationResult: + def submit(self, request: "GenerationRequest") -> "GenerationResult": """ - Low-level API to the executor. Return a "future" GenerationResult - which can be waited. - Forwards the request to the workers through the request queue. + Low-level API to the executor. Return a "future" GenerationResult + which can be waited. + Forwards the request to the workers through RPC or Ray queues depending on mode. """ request.set_id(self._get_next_client_id()) logprob_params = self._get_logprob_params(request) - result = GenerationResult( - request, - background_error_handler=self._handle_background_error, - executor=self, - disaggregated_params=request.disaggregated_params, - logprob_params=logprob_params) - - with nvtx_range_debug("request_queue.put"): - self.call_all_ray_workers("enqueue_request", - leader_only=True, - request=request, - async_call=True, - result_wait_queue=result.queue) + if self.use_rpc: + with nvtx_range_debug("rpc_submit"): + self.rpc_client.submit(request).remote(need_response=False) + + result = GenerationResult( + request, + background_error_handler=self._handle_background_error, + executor=self, + disaggregated_params=request.disaggregated_params, + logprob_params=logprob_params) + self._results[request.id] = result + else: + result = GenerationResult( + request, + background_error_handler=self._handle_background_error, + executor=self, + disaggregated_params=request.disaggregated_params, + logprob_params=logprob_params) + + with nvtx_range_debug("request_queue.put"): + self.call_all_ray_workers("enqueue_request", + leader_only=True, + request=request, + async_call=True, + result_wait_queue=result.queue) return result + def start(self): + pass + + def setup_engine_remote(self): + return self.collective_rpc("setup_engine", non_block=False) + def report_device_ids(self) -> list[str]: gpu_ids = self.call_all_ray_workers("report_device_id", leader_only=False, async_call=False) return sorted(gpu_ids) + def use_ray_queue(self) -> bool: + return not self.use_rpc + def abort_request(self, request_id: int) -> None: self.call_all_ray_workers("abort_request", leader_only=True, @@ -226,11 +248,59 @@ def abort_request(self, request_id: int) -> None: request_id=request_id) def shutdown(self): - # Release actors - self.response_queue = None - self.response_sync_queue = None - self.async_response_queue_weakref = None - self.sync_response_queue_weakref = None + if hasattr(self, '_shutdown_event') and self._shutdown_event.is_set(): + return + if hasattr(self, '_shutdown_event'): + self._shutdown_event.set() + + mode_str = "RPC mode" if self.use_rpc else "Ray queue mode" + logger_debug(f"Shutting down RayExecutor ({mode_str})", color="yellow") + + if self.use_rpc: + if hasattr(self, 'main_loop') and self.main_loop and hasattr( + self, 'main_loop_task_obj') and self.main_loop_task_obj: + logger_debug("Cancelling main loop task.", color="yellow") + try: + self.main_loop.call_soon_threadsafe( + self.main_loop_task_obj.cancel) + except Exception as e: + logger_debug(f"Error cancelling main loop task: 
{e}", + color="yellow") + + if hasattr(self, 'main_loop_thread'): + self.main_loop_thread.join() + + # Then, shutdown the workers + if hasattr(self, 'workers') and self.workers is not None: + try: + logger_debug("Shutting down RPC remote", color="yellow") + shutdown_refs = [ + worker.shutdown.remote() for worker in self.workers + ] + # Add timeout to prevent indefinite hanging + ray.get(shutdown_refs, timeout=30.0) + except ray.exceptions.GetTimeoutError: + logger.warning( + "Timeout waiting for workers to shutdown after 30 seconds" + ) + except Exception as e: + logger.warning(f"Error shutting down RPC remote: {e}") + + if hasattr(self, 'rpc_client') and self.rpc_client is not None: + try: + self.rpc_client.close() + except Exception as e: + # Suppress errors during RPC client shutdown + # These can occur if the client is already closed or if there are + # pending operations that get cancelled during cleanup + logger_debug( + f"Suppressed error during RPC client close: {e}") + else: + # Release actors + self.response_queue = None + self.response_sync_queue = None + self.async_response_queue_weakref = None + self.sync_response_queue_weakref = None self.workers = None if hasattr(self, @@ -246,12 +316,6 @@ def shutdown(self): logger.debug("Shutting down Ray cluster") ray.shutdown() - @property - def enable_postprocess_parallel(self) -> bool: - ret = super().enable_postprocess_parallel - assert ret == False, "Postprocess parallel is not supported in RayExecutor" - return ret - def _get_placement_group(self, tp_size: int) -> Tuple[PlacementGroup, List[int]]: """ @@ -317,3 +381,15 @@ def _get_placement_group(self, pg = placement_group(bundles, strategy=strategy) return pg, bundle_indices + + @property + def enable_postprocess_parallel(self) -> bool: + ret = super().enable_postprocess_parallel + assert ret == False, "Postprocess parallel is not supported in RayExecutor" + return ret + + @staticmethod + def create_actor_weak_ref(actor_handle: ray.actor.ActorHandle): + state, _, _ = actor_handle._serialization_helper() + return ray.actor.ActorHandle._deserialization_helper(state, + weak_ref=True) diff --git a/tensorrt_llm/executor/ray_gpu_worker.py b/tensorrt_llm/executor/ray_gpu_worker.py index 08f33cd0bca..00dc1025f4d 100644 --- a/tensorrt_llm/executor/ray_gpu_worker.py +++ b/tensorrt_llm/executor/ray_gpu_worker.py @@ -8,9 +8,11 @@ import torch from tensorrt_llm._ray_utils import control_action_decorator +from tensorrt_llm._torch.utils import get_device_uuid from tensorrt_llm._torch.virtual_memory import (materialize_with_tag, release_with_tag, verify_sleep_wakeup_tags) +from tensorrt_llm._utils import ray_use_rpc from ..bindings import executor as tllm from ..builder import Engine @@ -21,6 +23,7 @@ from .postproc_worker import PostprocWorkerConfig from .request import GenerationRequest from .result import GenerationResult +from .rpc_worker import RpcWorkerMixin __all__ = [ "RayGPUWorker", @@ -83,7 +86,6 @@ def abort_request(self, request_id: int) -> None: self.worker.abort_request(request_id) def report_device_id(self) -> str: - from tensorrt_llm._torch.utils import get_device_uuid local_id = self.physical_to_local_id(self.gpu) return get_device_uuid(local_id) @@ -101,6 +103,9 @@ def call_worker_method(self, method_name: str, *args, **kwargs): raise AttributeError( f"The RayGPUWorker has no method called '{method_name}'.") + def shutdown(self): + return self.worker.shutdown() + def __repr__(self) -> str: """Customizes the actor's prefix in the Ray logs. 
@@ -150,7 +155,7 @@ def _inject_worker_extension( return ExtendedWorker -class RayGPUWorker(BaseWorker): +class RayGPUWorker(RpcWorkerMixin, BaseWorker): def __init__( self, @@ -163,6 +168,7 @@ def __init__( hf_model_dir: Optional[Path] = None, tokenizer: Optional[TokenizerBase] = None, llm_args: Optional[BaseLlmArgs] = None, + rpc_addr: Optional[str] = None, ) -> None: global logger from tensorrt_llm.logger import logger @@ -178,66 +184,31 @@ def __init__( llm_args=llm_args, ) - if not self._is_pytorch_backend: - raise ValueError(f"Ray GPU worker only supports PyTorch backend.") - self.device_id = device_id - - # Override rank attributes using torch self.global_rank = torch.distributed.get_rank() if self.global_rank > 1: logger.set_rank(self.global_rank) - self.setup_engine() - - def _get_comm_ranks_device_id(self): - # Make sure C++ executor would use same devices/ranks as py_executor - global_rank = torch.distributed.get_rank() - world_size = torch.distributed.get_world_size() - comm_ranks = [None] * world_size - device_ids = [None] * world_size - - torch.distributed.all_gather_object(comm_ranks, global_rank) - torch.distributed.all_gather_object(device_ids, self.device_id) + if ray_use_rpc(): + if rpc_addr is None: + raise RuntimeError( + "RPC mode enabled but no rpc_addr provided to RayGPUWorker") + self.init_rpc_worker(self.global_rank, rpc_addr) + self.start_rpc_server() + else: + self.setup_engine() - self._configure_affinity(self.device_id) - - return comm_ranks, device_ids + def setup_engine(self): + if torch.distributed.is_initialized( + ) and torch.distributed.get_world_size() > 1: + torch.distributed.barrier() + super().setup_engine() def enqueue_request(self, request: GenerationRequest, result_wait_queue: Queue | None = None) -> int: return self._enqueue_request(request, result_wait_queue) - def submit(self, request: GenerationRequest): - raise NotImplementedError( - "Ray GPU worker does not support submit() yet.") - - def shutdown(self): - - if self.doing_shutdown: - return - else: - self.doing_shutdown = True - - logger.debug(f'Worker {self.rank} shutting down...') - - if self.engine is not None: - self.engine.shutdown() - self.engine = None - - assert self._executor_config is None, "An empty executor_config is expected in shutdown when LLM arguments are defined." - if (self.llm_args.backend == "pytorch" - and hasattr(self, "checkpoint_loader") - and self.checkpoint_loader is not None): - self.checkpoint_loader.cleanup() - self.checkpoint_loader = None - - # Check if there are any errors from the threads before shutdown. 
- self._handle_background_error() - - logger.debug(f"Worker {self.rank} shutdown done.") - @control_action_decorator def sleep(self, sleep_tags: List[str]): if not self.llm_args.enable_sleep: @@ -270,6 +241,63 @@ def wakeup(self, wakeup_tags: List[str]): logger.error(f"Encountered an error in wakeup") raise e + def start(self): + pass + + def shutdown(self): + + if self.doing_shutdown: + return + else: + self.doing_shutdown = True + + logger.debug(f'Worker {self.rank} shutting down...') + + if hasattr(self, 'shutdown_event'): + self.shutdown_event.set() + + if hasattr(self, 'rpc_server') and self.rpc_server is not None: + logger.info(f"[Rank {self.global_rank}] Shutting down RPC server") + try: + self.rpc_server.shutdown() + except Exception as e: + # Suppress errors during RPC server shutdown + # These can occur if the server is already closed or during cleanup + logger.debug( + f"[Rank {self.global_rank}] Suppressed error during RPC server shutdown: {e}" + ) + self.rpc_server = None + + if self.engine is not None: + self.engine.shutdown() + self.engine = None + + assert self._executor_config is None, "An empty executor_config is expected in shutdown when LLM arguments are defined." + if (self.llm_args.backend == "pytorch" + and hasattr(self, "checkpoint_loader") + and self.checkpoint_loader is not None): + self.checkpoint_loader.cleanup() + self.checkpoint_loader = None + + # Check if there are any errors from the threads before shutdown. + self._handle_background_error() + + logger.debug(f"Worker {self.rank} shutdown done.") + + def _get_comm_ranks_device_id(self): + # Make sure C++ executor would use same devices/ranks as py_executor + global_rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + comm_ranks = [None] * world_size + device_ids = [None] * world_size + + torch.distributed.all_gather_object(comm_ranks, global_rank) + torch.distributed.all_gather_object(device_ids, self.device_id) + + self._configure_affinity(self.device_id) + + return comm_ranks, device_ids + def __enter__(self): return self diff --git a/tensorrt_llm/executor/result.py b/tensorrt_llm/executor/result.py index 5da1dd51a0c..d47743cf8f0 100644 --- a/tensorrt_llm/executor/result.py +++ b/tensorrt_llm/executor/result.py @@ -20,7 +20,7 @@ from tensorrt_llm import ray_stub as ray from .._ray_utils import unwrap_ray_errors -from .._utils import mpi_disabled, nvtx_range_debug +from .._utils import mpi_disabled, nvtx_range_debug, ray_use_rpc from ..bindings import executor as tllm from ..disaggregated_params import DisaggregatedParams from ..llmapi.tracer import global_tracer @@ -275,7 +275,7 @@ def __init__(self, # torch backend will use trtllm sampler in beam search mode, but it does not support return logprobs incrementally self.use_trtllm_sampler = sampling_params.use_beam_search and sampling_params.best_of > 1 - if ray_queue is not None: + if ray_queue is not None and not ray_use_rpc(): if has_event_loop(): self.aqueue = ray_queue self.queue = self.aqueue @@ -557,7 +557,7 @@ def _handle_response(self, else: raise ValueError(f"Unknown response type: {response}") - if self._done and mpi_disabled(): + if self._done and mpi_disabled() and not ray_use_rpc(): assert hasattr( self.queue, "unregister" ), "Ray path should be activated for unregistering the Ray queue." 
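A note on the class composition this PR relies on: both new bases are listed before the original ones (`RayExecutor(RpcExecutorMixin, GenerationExecutor)` and `RayGPUWorker(RpcWorkerMixin, BaseWorker)`), which is the usual Python mixin convention so that, wherever the mixin and the base define the same attribute, the mixin wins in the MRO without editing the base class. A toy illustration of that ordering, with purely illustrative class names rather than the real TensorRT-LLM types:

class BaseWorkerToy:

    def submit(self, request: str) -> str:
        return f"queue-submit:{request}"


class RpcWorkerMixinToy:

    def submit(self, request: str) -> str:
        return f"rpc-submit:{request}"


class RayGPUWorkerToy(RpcWorkerMixinToy, BaseWorkerToy):
    """Mixin first, so its methods shadow the base's when names collide."""


if __name__ == "__main__":
    # MRO: RayGPUWorkerToy -> RpcWorkerMixinToy -> BaseWorkerToy -> object
    print([cls.__name__ for cls in RayGPUWorkerToy.__mro__])
    assert RayGPUWorkerToy().submit("req-1") == "rpc-submit:req-1"

In the diff itself the subclasses still override the hot-path methods explicitly (e.g. RayExecutor.submit branches on self.use_rpc), so the ordering mainly keeps the mixin-provided plumbing such as init_rpc_worker() and the fetch loops ahead of any base-class defaults.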
@@ -790,7 +790,7 @@ def __init__( ) -> None: use_async_queue = has_event_loop() shared_queue = None - if executor and executor.use_ray_queue(): + if executor and executor.use_ray_queue() and not ray_use_rpc(): shared_queue = executor.async_response_queue_weakref if use_async_queue else executor.sync_response_queue_weakref super().__init__( @@ -855,7 +855,7 @@ def _handle_ray_response(self, response: Any): return response def _result_step(self, timeout: Optional[float] = None): - if mpi_disabled(): + if mpi_disabled() and not ray_use_rpc(): with unwrap_ray_errors(): response = ray.get(self.queue.get.remote(self.request_id)) response = self._handle_ray_response(response) @@ -866,7 +866,7 @@ def _result_step(self, timeout: Optional[float] = None): async def _aresult_step(self): assert self.aqueue is not None, "The asyncio event loop was not present during initialization, so async operations are not available." - if mpi_disabled(): + if mpi_disabled() and not ray_use_rpc(): response = await self.aqueue.get_async.remote(self.request_id) response = self._handle_ray_response(response) else: diff --git a/tensorrt_llm/executor/rpc/rpc_client.py b/tensorrt_llm/executor/rpc/rpc_client.py index 07657786277..d7f8636b679 100644 --- a/tensorrt_llm/executor/rpc/rpc_client.py +++ b/tensorrt_llm/executor/rpc/rpc_client.py @@ -1,6 +1,7 @@ import asyncio import concurrent.futures import threading +import time import uuid from typing import Any, AsyncIterator, Dict, Optional @@ -406,6 +407,20 @@ def _ensure_event_loop(self): if self._loop is None or not self._loop.is_running(): self._loop = asyncio.new_event_loop() + # TODO: WAR. Remove after RPC shutdown is fixed. + def custom_exception_handler(loop, context): + exception = context.get('exception') + message = context.get('message', '') + + if isinstance(exception, + asyncio.CancelledError) or "pending" in message: + logger.debug(f"Suppressed error during shutdown: {message}") + return + + loop.default_exception_handler(context) + + self._loop.set_exception_handler(custom_exception_handler) + def run_loop(): asyncio.set_event_loop(self._loop) self._loop.run_forever() @@ -416,7 +431,6 @@ def run_loop(): self._loop_thread.start() # Give the loop a moment to start - import time time.sleep(0.1) def _call_sync(self, method_name, *args, **kwargs): diff --git a/tensorrt_llm/executor/rpc_proxy.py b/tensorrt_llm/executor/rpc_proxy.py index 0fb67d5baaa..61576b45bd2 100644 --- a/tensorrt_llm/executor/rpc_proxy.py +++ b/tensorrt_llm/executor/rpc_proxy.py @@ -2,7 +2,7 @@ import atexit import json import threading -from typing import Optional +from typing import Callable, List, Optional from .._utils import nvtx_range_debug from ..llmapi.mpi_session import MpiPoolSession, MpiSession @@ -20,120 +20,48 @@ get_spawn_proxy_process_env, is_llm_response) -class GenerationExecutorRpcProxy(GenerationExecutor): - # NOTE: this is a global counter for the number of instances of this class - INSTANCE_COUNTER = 0 +class RpcExecutorMixin: + """Mixin for executors that use RPC client for hot path communication. 
- def __init__( - self, - worker_kwargs: dict, - model_world_size: int = 1, - mpi_session: Optional[MpiSession] = None, - *, - postproc_worker_config: Optional[PostprocWorkerConfig] = None, - is_llm_executor: Optional[bool] = None, - ): - """ - Args: - worker_kwargs: kwargs for the rpc worker - model_world_size: the world size of the model - mpi_session: the mpi session to use - postproc_worker_config: the postproc worker config - is_llm_executor: whether this is an llm executor - """ - GenerationExecutorRpcProxy.INSTANCE_COUNTER += 1 - self.rpc_addr = get_unique_ipc_addr() - self.rpc_client = RPCClient(self.rpc_addr) + Provides: + - RPC client initialization + - Response handling loop + - Main loop thread management + - Shutdown logic for RPC components - postproc_worker_config = postproc_worker_config or PostprocWorkerConfig( - ) + The inheriting class should call init_rpc_executor() to set up RPC client. + """ - super().__init__( - num_postprocess_workers=postproc_worker_config. - num_postprocess_workers, - postprocess_tokenizer_dir=postproc_worker_config. - postprocess_tokenizer_dir, - is_llm_executor=is_llm_executor, - ) + def init_rpc_executor(self): + self.rpc_addr = get_unique_ipc_addr() + self.rpc_client = RPCClient(self.rpc_addr) self._results = {} - - self._create_mpi_session(model_world_size, mpi_session) - self._shutdown_event = threading.Event() - self.worker_kwargs = worker_kwargs - self.main_loop_task_obj = None self.main_loop = None self.main_loop_thread = None - self.launch_workers() - - # Invoke model creation on the remote - # TBD: Move model creation to the mpi task, or left in RPC? - self.setup_engine_remote() - - # Setup main loop after engine is ready - self.setup_mainloop() - - def launch_workers(self): - logger.debug(f"Launching workers") - assert self.mpi_session is not None - self.mpi_session.submit(RpcWorker.main_task, - rpc_addr=self.rpc_addr, - **self.worker_kwargs) - - async def _generic_fetch_loop_async(self, fetch_method_name: str, - handler_method, method_name: str): - """Generic method for fetching data in a loop from RPC worker. + def setup_mainloop(self, + tasks: Optional[List[Callable]] = None, + thread_name: str = "rpc_proxy_main_loop"): + """Setup main loop thread with custom async tasks. Args: - fetch_method_name: Name of the RPC client method to call - handler_method: The handler method to call with the fetched data - method_name: Name of the method for logging + tasks: List of async coroutine functions to run. 
+ thread_name: Name for the main loop thread """ - try: - fetch_method = getattr(self.rpc_client, fetch_method_name) - async for data in fetch_method().remote_streaming(): - if self._shutdown_event.is_set(): - return - handler_method(data) - except asyncio.CancelledError: - logger.debug(f"{method_name} task cancelled") - except Exception as e: - logger.error(f"Error in {method_name}: {e}") - raise - - async def _fetch_responses_loop_async(self): - await self._generic_fetch_loop_async( - fetch_method_name="fetch_responses_loop_async", - handler_method=self.handle_responses, - method_name="_fetch_responses_loop_async") - - async def _fetch_stats_loop_async(self): - await self._generic_fetch_loop_async( - fetch_method_name="fetch_stats_loop_async", - handler_method=self.handle_stats, - method_name="_fetch_stats_loop_async") - - async def _fetch_kv_cache_events_loop_async(self): - await self._generic_fetch_loop_async( - fetch_method_name="fetch_kv_cache_events_loop_async", - handler_method=self.handle_kv_cache_events, - method_name="_fetch_kv_cache_events_loop_async") - - def setup_mainloop(self): - - async def main_loop_task(): + if tasks is None: tasks = [ - self._fetch_responses_loop_async(), - self._fetch_stats_loop_async(), - self._fetch_kv_cache_events_loop_async(), + self._fetch_responses_loop_async, + self._fetch_stats_loop_async, ] # Only add kv_cache_events loop if it's enabled if self._iter_kv_events_result: - tasks.append(self._fetch_kv_cache_events_loop_async()) - await asyncio.gather(*tasks) + tasks.append(self._fetch_kv_cache_events_loop_async) + + async def main_loop_task(): + await asyncio.gather(*[task() for task in tasks]) def _run_main_loop_task(): """Local method to run the main loop task.""" @@ -151,10 +79,30 @@ def _run_main_loop_task(): self.main_loop_thread = threading.Thread(target=_run_main_loop_task, daemon=True, - name="rpc_proxy_main_loop") + name=thread_name) self.main_loop_thread.start() atexit.register(self.shutdown) + def submit(self, request: GenerationRequest) -> GenerationResult: + request.set_id(self._get_next_client_id()) + logprob_params = self._get_logprob_params(request) + + # submit is a fire-and-forget operation, don't need to wait for response + with nvtx_range_debug("RPCExecutor.submit", + color="green", + category="Proxy"): + self.rpc_client.submit(request).remote(need_response=False) + + result = GenerationResult( + request, + background_error_handler=self._handle_background_error, + executor=self, + disaggregated_params=request.disaggregated_params, + logprob_params=logprob_params) + self._results[request.id] = result + + return result + def handle_responses(self, responses: list[GenerationResult]) -> bool: async_queues = [] event_loop = None @@ -195,6 +143,63 @@ def process_res(res: list): if async_queues: _SyncQueue.notify_many(event_loop, async_queues) + def handle_stats(self, stats): + """Handle stats received from RPC worker and put them into the stats result queue. + + Args: + stats: Statistics data from the RPC worker (can be dict, str, or list) + """ + self._handle_iteration_data(stats, self._iter_stats_result, "stats") + + def handle_kv_cache_events(self, events): + """Handle KV cache events received from RPC worker and put them into the events result queue. 
+ + Args: + events: KV cache events data from the RPC worker (can be dict, str, or list) + """ + self._handle_iteration_data(events, self._iter_kv_events_result, + "kv_cache_events") + + async def _generic_fetch_loop_async(self, fetch_method_name: str, + handler_method: Callable, + method_name: str): + """Generic method for fetching data in a loop from RPC worker. + + Args: + fetch_method_name: Name of the RPC client method to call + handler_method: The handler method to call with the fetched data + method_name: Name of the method for logging + """ + try: + fetch_method = getattr(self.rpc_client, fetch_method_name) + async for data in fetch_method().remote_streaming(): + if self._shutdown_event.is_set(): + return + handler_method(data) + except asyncio.CancelledError: + logger.debug(f"{method_name} task cancelled") + except Exception as e: + logger.error(f"Error in {method_name}: {e}") + raise + + async def _fetch_responses_loop_async(self): + await self._generic_fetch_loop_async( + fetch_method_name="fetch_responses_loop_async", + handler_method=self.handle_responses, + method_name="_fetch_responses_loop_async") + + async def _fetch_stats_loop_async(self): + await self._generic_fetch_loop_async( + fetch_method_name="fetch_stats_loop_async", + handler_method=self.handle_stats, + method_name="_fetch_stats_loop_async") + + async def _fetch_kv_cache_events_loop_async(self): + await self._generic_fetch_loop_async( + fetch_method_name="fetch_kv_cache_events_loop_async", + handler_method=self.handle_kv_cache_events, + method_name="_fetch_kv_cache_events_loop_async") + def _handle_iteration_data(self, data, result_singleton, data_type: str): """Generic method to handle iteration data received from RPC worker. @@ -268,42 +273,74 @@ def _handle_iteration_data(self, data, result_singleton, data_type: str): logger.error(f"rpc_proxy.py: Error in handle_{data_type}: {e}") raise e - def handle_stats(self, stats): - """Handle stats received from RPC worker and put them into the stats result queue. +class GenerationExecutorRpcProxy(RpcExecutorMixin, GenerationExecutor): + # NOTE: this is a global counter for the number of instances of this class + INSTANCE_COUNTER = 0 + + def __init__( + self, + worker_kwargs: dict, + model_world_size: int = 1, + mpi_session: Optional[MpiSession] = None, + *, + postproc_worker_config: Optional[PostprocWorkerConfig] = None, + is_llm_executor: Optional[bool] = None, + ): + """ Args: - stats: Statistics data from the RPC worker (can be dict, str, or list) + worker_kwargs: kwargs for the rpc worker + model_world_size: the world size of the model + mpi_session: the mpi session to use + postproc_worker_config: the postproc worker config + is_llm_executor: whether this is an llm executor """ - self._handle_iteration_data(stats, self._iter_stats_result, "stats") + GenerationExecutorRpcProxy.INSTANCE_COUNTER += 1 + self.init_rpc_executor() - def handle_kv_cache_events(self, events): - """Handle KV cache events received from RPC worker and put them into the events result queue. + postproc_worker_config = postproc_worker_config or PostprocWorkerConfig( + ) - Args: - events: KV cache events data from the RPC worker (can be dict, str, or list) - """ - self._handle_iteration_data(events, self._iter_kv_events_result, - "kv_cache_events") + super().__init__( + num_postprocess_workers=postproc_worker_config. + num_postprocess_workers, + postprocess_tokenizer_dir=postproc_worker_config. 
+ postprocess_tokenizer_dir, + is_llm_executor=is_llm_executor, + ) - def submit(self, request: GenerationRequest) -> GenerationResult: - request.set_id(self._get_next_client_id()) - logprob_params = self._get_logprob_params(request) + self._create_mpi_session(model_world_size, mpi_session) - # submit is a fire-and-forget operation, don't need to wait for response - with nvtx_range_debug("GenerationExecutorRpcProxy.submit", - color="green", - category="Proxy"): - self.rpc_client.submit(request).remote(need_response=False) + self.worker_kwargs = worker_kwargs - result = GenerationResult( - request, - background_error_handler=self._handle_background_error, - executor=self, - disaggregated_params=request.disaggregated_params, - logprob_params=logprob_params) - self._results[request.id] = result + self.launch_workers() - return result + # Invoke model creation on the remote + # TBD: Move model creation to the mpi task, or left in RPC? + self.setup_engine_remote() + + # Setup main loop after engine is ready + self._setup_mainloop_with_tasks() + + def launch_workers(self): + logger.debug(f"Launching workers") + assert self.mpi_session is not None + self.mpi_session.submit(RpcWorker.main_task, + rpc_addr=self.rpc_addr, + **self.worker_kwargs) + + def _setup_mainloop_with_tasks(self): + """Setup mainloop with all tasks needed for RpcProxy.""" + tasks = [ + self._fetch_responses_loop_async, + self._fetch_stats_loop_async, + ] + # Only add kv_cache_events loop if it's enabled + if self._iter_kv_events_result: + tasks.append(self._fetch_kv_cache_events_loop_async) + + # Call mixin's setup_mainloop with custom tasks + self.setup_mainloop(tasks=tasks, thread_name="rpc_proxy_main_loop") def fetch_stats_remote(self): return self.rpc_client.fetch_stats().remote() diff --git a/tensorrt_llm/executor/rpc_worker.py b/tensorrt_llm/executor/rpc_worker.py index 47bcfacdef4..a778ba67ed1 100644 --- a/tensorrt_llm/executor/rpc_worker.py +++ b/tensorrt_llm/executor/rpc_worker.py @@ -22,54 +22,41 @@ from .rpc import RPCServer -class RpcWorker(BaseWorker): - """ - A RPC wrapper for the BaseWorker class. +class RpcWorkerMixin: + """Mixin for workers that serve RPC requests. - Actions: - - `setup_engine`: Setup the engine. - - `submit`: Submit a request to the worker. - - `fetch_responses`: Fetch the latest responses from engine. - - `fetch_stats`: Fetch the latest stats from engine. - - `fetch_kv_cache_events`: Fetch the latest kv cache events from engine. - - `shutdown`: Shutdown the worker. + Provides: + - RPC server initialization + - Response queue management + - Async response fetching methods + - Shutdown logic for RPC components + + The inheriting class should call init_rpc_worker() in its __init__. 
""" # Number of RPC server workers NUM_WORKERS = 6 - def __init__( - self, - engine: Union[Path, Engine], - executor_config: Optional[tllm.ExecutorConfig] = None, - is_llm_executor: Optional[bool] = None, - batched_logits_processor: Optional[BatchedLogitsProcessor] = None, - postproc_worker_config: Optional[PostprocWorkerConfig] = None, - hf_model_dir: Optional[Path] = None, - tokenizer: Optional[TokenizerBase] = None, - llm_args: Optional[BaseLlmArgs] = None, - ) -> None: - super().__init__( - engine=engine, - executor_config=executor_config, - is_llm_executor=is_llm_executor, - llm_args=llm_args, - batched_logits_processor=batched_logits_processor, - postproc_worker_config=postproc_worker_config, - hf_model_dir=hf_model_dir, - tokenizer=tokenizer, - ) + def init_rpc_worker(self, rank: int, rpc_addr: Optional[str]): + if rpc_addr is None: + raise RuntimeError( + "RPC mode enabled but no rpc_addr provided to worker") - # Extract garbage_collection_gen0_threshold from llm_args if available - self.garbage_collection_gen0_threshold = ( - llm_args.garbage_collection_gen0_threshold if llm_args is not None - and hasattr(llm_args, 'garbage_collection_gen0_threshold') else - None) + self.rank = rank self.shutdown_event = Event() - self._response_queue = Queue() self.set_result_queue(self._response_queue) + self.rpc_server = None + self.rpc_addr = rpc_addr + + def start_rpc_server(self): + if self.rank == 0: + self.rpc_server = RPCServer(self, + num_workers=RpcWorkerMixin.NUM_WORKERS) + self.rpc_server.bind(self.rpc_addr) + self.rpc_server.start() + def submit(self, request: GenerationRequest): """ Submits a request to the worker. """ with nvtx_range_debug("RpcWorker.submit", @@ -78,7 +65,8 @@ def submit(self, request: GenerationRequest): super().submit(request) def fetch_responses(self, timeout: Optional[float] = None) -> list: - logger_debug(f"RpcWorker {mpi_rank()} is fetching responses", + """Fetch responses from the response queue (blocking).""" + logger_debug(f"RpcWorker {self.rank} is fetching responses", color="yellow") with nvtx_range_debug("RpcWorker.fetch_responses", color="orange", @@ -98,8 +86,9 @@ def fetch_responses(self, timeout: Optional[float] = None) -> list: async def fetch_responses_async(self, timeout: Optional[float] = None) -> list: + """Async version of fetch_responses using asyncio.to_thread.""" # A really async version of fetch_responses - logger_debug(f"RpcWorker {mpi_rank()} is fetching responses async", + logger_debug(f"RpcWorker {self.rank} is fetching responses async", color="yellow") # First, await any pending responses without blocking the event loop @@ -107,30 +96,51 @@ async def fetch_responses_async(self, timeout=timeout) return responses - async def fetch_stats_async(self, timeout: Optional[float] = None) -> list: - return await asyncio.to_thread(self.fetch_stats) - - async def fetch_kv_cache_events_async(self, - timeout: Optional[float] = None - ) -> list: - return await asyncio.to_thread(self.fetch_kv_cache_events) - - # for streaming performance async def fetch_responses_loop_async(self) -> AsyncGenerator[list, None]: while not self.shutdown_event.is_set(): responses = await self.fetch_responses_async() if responses: # Only yield if there are actual responses logger_debug( - f"RpcWorker {mpi_rank()} is yielding responses: {responses}", + f"RpcWorker {self.rank} is yielding responses: {responses}", color="yellow") yield responses # batching the responses to opt IPC performance else: # Small delay to prevent busy waiting when no responses await 
asyncio.sleep(0) logger_debug( - f"RpcWorker {mpi_rank()} quitting fetch_responses_loop_async", + f"RpcWorker {self.rank} quitting fetch_responses_loop_async", color="yellow") + async def fetch_stats_async(self, timeout: Optional[float] = None) -> list: + """Async version of fetch_stats using asyncio.to_thread.""" + return await asyncio.to_thread(self.fetch_stats) + + async def fetch_kv_cache_events_async(self, + timeout: Optional[float] = None + ) -> list: + """Async version of fetch_kv_cache_events using asyncio.to_thread.""" + return await asyncio.to_thread(self.fetch_kv_cache_events) + + async def fetch_stats_loop_async( + self, + timeout: Optional[float] = None) -> AsyncGenerator[list, None]: + async for data in self._generic_fetch_loop_async( + fetch_method=self.fetch_stats_async, + serializer=self._stats_serializer, + method_name="fetch_stats_loop_async", + timeout=timeout): + yield data + + async def fetch_kv_cache_events_loop_async( + self, + timeout: Optional[float] = None) -> AsyncGenerator[list, None]: + async for data in self._generic_fetch_loop_async( + fetch_method=self.fetch_kv_cache_events_async, + serializer=self._kv_cache_events_serializer, + method_name="fetch_kv_cache_events_loop_async", + timeout=timeout): + yield data + async def _generic_fetch_loop_async( self, fetch_method, @@ -152,28 +162,54 @@ async def _generic_fetch_loop_async( # Always yield data, even if empty, to prevent the client looks like hanging # TODO: Remove the empty data to reduce the IPC overhead yield [serializer(item) for item in data] - logger_debug(f"RpcWorker {mpi_rank()} quitting {method_name}", + logger_debug(f"RpcWorker {self.rank} quitting {method_name}", color="yellow") - async def fetch_stats_loop_async( - self, - timeout: Optional[float] = None) -> AsyncGenerator[list, None]: - async for data in self._generic_fetch_loop_async( - fetch_method=self.fetch_stats_async, - serializer=self._stats_serializer, - method_name="fetch_stats_loop_async", - timeout=timeout): - yield data - async def fetch_kv_cache_events_loop_async( - self, - timeout: Optional[float] = None) -> AsyncGenerator[list, None]: - async for data in self._generic_fetch_loop_async( - fetch_method=self.fetch_kv_cache_events_async, - serializer=self._kv_cache_events_serializer, - method_name="fetch_kv_cache_events_loop_async", - timeout=timeout): - yield data +class RpcWorker(RpcWorkerMixin, BaseWorker): + """ + A RPC wrapper for the BaseWorker class. + + Actions: + - `setup_engine`: Setup the engine. + - `submit`: Submit a request to the worker. + - `fetch_responses`: Fetch the latest responses from engine. + - `fetch_stats`: Fetch the latest stats from engine. + - `fetch_kv_cache_events`: Fetch the latest kv cache events from engine. + - `shutdown`: Shutdown the worker. 
+ """ + + def __init__( + self, + engine: Union[Path, Engine], + executor_config: Optional[tllm.ExecutorConfig] = None, + is_llm_executor: Optional[bool] = None, + batched_logits_processor: Optional[BatchedLogitsProcessor] = None, + postproc_worker_config: Optional[PostprocWorkerConfig] = None, + hf_model_dir: Optional[Path] = None, + tokenizer: Optional[TokenizerBase] = None, + llm_args: Optional[BaseLlmArgs] = None, + ) -> None: + super().__init__( + engine=engine, + executor_config=executor_config, + is_llm_executor=is_llm_executor, + llm_args=llm_args, + batched_logits_processor=batched_logits_processor, + postproc_worker_config=postproc_worker_config, + hf_model_dir=hf_model_dir, + tokenizer=tokenizer, + ) + + # Extract garbage_collection_gen0_threshold from llm_args if available + self.garbage_collection_gen0_threshold = ( + llm_args.garbage_collection_gen0_threshold if llm_args is not None + and hasattr(llm_args, 'garbage_collection_gen0_threshold') else + None) + self.shutdown_event = Event() + + self._response_queue = Queue() + self.set_result_queue(self._response_queue) def setup_engine(self): # Force all the ranks to wait here, and start creating the executor simultaneously. @@ -183,13 +219,6 @@ def setup_engine(self): super().setup_engine() - def shutdown(self): - logger_debug(f"RPC worker {mpi_rank()} is shutting down", - color="yellow") - self.shutdown_event.set() - super().shutdown() - logger_debug(f"RPC worker {mpi_rank()} is shutdown", color="yellow") - def start(self): pass @@ -247,6 +276,13 @@ def main_task( worker.shutdown_event.wait() rpc_server.shutdown() + def shutdown(self): + logger_debug(f"RPC worker {mpi_rank()} is shutting down", + color="yellow") + self.shutdown_event.set() + super().shutdown() + logger_debug(f"RPC worker {mpi_rank()} is shutdown", color="yellow") + def __enter__(self): return self diff --git a/tests/integration/defs/examples/test_ray.py b/tests/integration/defs/examples/test_ray.py index 676fd476caa..ffc3f3f60fb 100644 --- a/tests/integration/defs/examples/test_ray.py +++ b/tests/integration/defs/examples/test_ray.py @@ -12,7 +12,11 @@ def ray_example_root(llm_root): return example_root -def test_llm_inference_async_ray(ray_example_root, llm_venv): +@pytest.mark.parametrize("use_rpc", [True, False], ids=["rpc", "no_rpc"]) +def test_llm_inference_async_ray(ray_example_root, llm_venv, monkeypatch, + use_rpc): + if use_rpc: + monkeypatch.setenv("TLLM_RAY_USE_RPC", "1") script_path = os.path.join(ray_example_root, "llm_inference_async_ray.py") model_path = f"{llm_models_root()}/llama-models-v2/TinyLlama-1.1B-Chat-v1.0" venv_check_call(llm_venv, [script_path, "--model", model_path]) diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml index dc43383222c..6d267b451c6 100644 --- a/tests/integration/test_lists/test-db/l0_h100.yml +++ b/tests/integration/test_lists/test-db/l0_h100.yml @@ -131,7 +131,7 @@ l0_h100: - unittest/_torch/executor - unittest/_torch/ray_orchestrator/single_gpu - unittest/llmapi/test_llm_pytorch.py - - examples/test_ray.py::test_llm_inference_async_ray + - examples/test_ray.py::test_llm_inference_async_ray[no_rpc] - condition: ranges: system_gpu_count: