diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index f3b444a9760..248599835d7 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -1,13 +1,11 @@ import dataclasses import datetime import functools -import gc import os import pickle # nosec B403 import threading import time import traceback -import weakref from contextlib import contextmanager from typing import Dict, Iterable, List, Optional, Tuple, Union @@ -59,10 +57,6 @@ # Format: "start1-stop1,start2-stop2,..." or single iterations "iter1,iter2,..." PROFILE_START_STOP_ENV_VAR_NAME = "TLLM_PROFILE_START_STOP" -# Environment variable to enable garbage collection profiling. -# Set to "1" to enable recording of garbage collection events during profiling. -PROFILE_RECORD_GC_ENV_VAR_NAME = "TLLM_PROFILE_RECORD_GC" - # Environment variable to enable PyTorch profiler tracing. # Set to a path to save detailed tracing of PyTorch operations. PROFILE_TRACE_ENV_VAR_NAME = "TLLM_TORCH_PROFILE_TRACE" @@ -97,40 +91,6 @@ def _load_iteration_indexes(env_var: str): return frozenset(starts), frozenset(stops) -class _GCNvtxHandle: - pass - - -def _gc_nvtx_watcher(): - enabled = os.environ.get(PROFILE_RECORD_GC_ENV_VAR_NAME, None) - if not enabled: - return None - - range_id: Optional[int] = None - - def gc_callback(phase, _): - nonlocal range_id - if phase == "start": - assert range_id is None, "Unexpected state in GC callback: another GC while last GC not finished?" - range_id = torch.cuda.nvtx.range_start("Python GC") - elif phase == "stop": - assert range_id is not None, "Unexpected state in GC callback: no active GC but got GC finished?" - torch.cuda.nvtx.range_end(range_id) - range_id = None - - gc.callbacks.append(gc_callback) - - def gc_cleanup(callback): - try: - gc.callbacks.remove(callback) - except ValueError: - pass - - handle = _GCNvtxHandle() - weakref.finalize(handle, gc_cleanup, gc_callback) - return handle - - @dataclasses.dataclass class BatchState: sample_state: SampleState @@ -178,7 +138,6 @@ def __init__(self, # profile config self.profile_start_iters, self.profile_stop_iters = _load_iteration_indexes( PROFILE_START_STOP_ENV_VAR_NAME) - self.gc_nvtx_watcher_handle = _gc_nvtx_watcher() # related modules self.resource_manager = resource_manager diff --git a/tensorrt_llm/_utils.py b/tensorrt_llm/_utils.py index a566089124d..a3bf1024eae 100644 --- a/tensorrt_llm/_utils.py +++ b/tensorrt_llm/_utils.py @@ -918,6 +918,18 @@ def nvtx_range_debug(msg: str, return _null_context_manager() +def nvtx_mark_debug(msg: str, + color: str = "grey", + domain: str = "TensorRT-LLM", + category: Optional[str] = None) -> None: + """ + Creates an NVTX marker for debugging purposes. + """ + if os.getenv("TLLM_LLMAPI_ENABLE_NVTX", "0") == "1" or \ + os.getenv("TLLM_NVTX_DEBUG", "0") == "1": + nvtx_mark(msg, color=color, domain=domain, category=category) + + def nvtx_mark(msg: str, color: str = "grey", domain: str = "TensorRT-LLM", @@ -1195,3 +1207,71 @@ def is_device_integrated() -> bool: if not torch.cuda.is_available(): return False return torch.cuda.get_device_properties().is_integrated + + +# Environment variable to enable garbage collection profiling. +# Set to "1" to enable recording of garbage collection events during profiling. +PROFILE_RECORD_GC_ENV_VAR_NAME = "TLLM_PROFILE_RECORD_GC" + + +class _GCNvtxHandle: + """Handle object for GC NVTX watcher to keep it alive.""" + + +# Singleton for the GC NVTX watcher handle. 
+_gc_watcher_handle: Optional[_GCNvtxHandle] = None + + +def _setup_gc_nvtx_profiling() -> Optional[_GCNvtxHandle]: + """ + Set up NVTX range markers for Python garbage collection events (singleton). + This helps in profiling to visualize when GC occurs during execution. + + This function is called automatically at module import time. The environment + variable TLLM_PROFILE_RECORD_GC must be set before importing this module. + + This is an internal function and should not be called directly by users. + + Returns: + _GCNvtxHandle or None: A handle object that keeps the GC callback alive, + or None if GC profiling is not enabled. + """ + global _gc_watcher_handle + + # Return existing handle if already initialized + if _gc_watcher_handle is not None: + return _gc_watcher_handle + + enabled = os.environ.get(PROFILE_RECORD_GC_ENV_VAR_NAME, None) + if not enabled: + return None + + range_id: Optional[int] = None + + def gc_callback(phase, _): + nonlocal range_id + if phase == "start": + assert range_id is None, "Unexpected state in GC callback: another GC while last GC not finished?" + range_id = torch.cuda.nvtx.range_start("Python GC") + elif phase == "stop": + assert range_id is not None, "Unexpected state in GC callback: no active GC but got GC finished?" + torch.cuda.nvtx.range_end(range_id) + range_id = None + + gc.callbacks.append(gc_callback) + + def gc_cleanup(callback): + try: + gc.callbacks.remove(callback) + except ValueError: + pass + + handle = _GCNvtxHandle() + weakref.finalize(handle, gc_cleanup, gc_callback) + + _gc_watcher_handle = handle + return handle + + +# Initialize GC NVTX profiling singleton at module import time +_setup_gc_nvtx_profiling() diff --git a/tensorrt_llm/bench/benchmark/utils/general.py b/tensorrt_llm/bench/benchmark/utils/general.py index d037797e7fa..3a35008daba 100755 --- a/tensorrt_llm/bench/benchmark/utils/general.py +++ b/tensorrt_llm/bench/benchmark/utils/general.py @@ -1,7 +1,6 @@ from __future__ import annotations import json -from importlib.metadata import version from pathlib import Path from random import choices, shuffle from typing import Dict, List, Tuple, Union @@ -170,7 +169,7 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str, backend = params.get("backend", "pytorch") return { - "sw_version": version("tensorrt_llm"), + "sw_version": "1.2", "model_path": model_path, "settings_config": { "max_batch_size": int(max_batch_size), diff --git a/tensorrt_llm/executor/executor.py b/tensorrt_llm/executor/executor.py index 800715e3011..e7ab9192ad1 100644 --- a/tensorrt_llm/executor/executor.py +++ b/tensorrt_llm/executor/executor.py @@ -368,6 +368,7 @@ def _create_ray_executor( is_llm_executor: bool, tp_size: int, ): + logger.warning(f"Orchestrator is creating Ray executor") from .ray_executor import RayExecutor return RayExecutor(worker_kwargs, @@ -386,6 +387,7 @@ def _create_rpc_executor( ): """Create RPC-based executor (GenerationExecutorRpcProxy).""" from .rpc_proxy import GenerationExecutorRpcProxy + logger.warning(f"Orchestrator is creating RPC executor") return GenerationExecutorRpcProxy( worker_kwargs, model_world_size=model_world_size, @@ -408,6 +410,7 @@ def _create_ipc_executor( use_worker: If True, creates GenerationExecutorWorker (single process). If False, creates GenerationExecutorProxy (multi-process with IPC). 
""" + logger.warning(f"Orchestrator is creating IPC executor") if use_worker: from .worker import GenerationExecutorWorker return GenerationExecutorWorker(**worker_kwargs, diff --git a/tensorrt_llm/executor/rpc/rpc_client.py b/tensorrt_llm/executor/rpc/rpc_client.py index 03d43ce1d0b..07657786277 100644 --- a/tensorrt_llm/executor/rpc/rpc_client.py +++ b/tensorrt_llm/executor/rpc/rpc_client.py @@ -4,7 +4,13 @@ import uuid from typing import Any, AsyncIterator, Dict, Optional -from ...llmapi.utils import AsyncQueue, _SyncQueue, logger_debug +import zmq + +from tensorrt_llm._utils import (customized_gc_thresholds, nvtx_mark_debug, + nvtx_range_debug) + +from ...llmapi.utils import (AsyncQueue, _SyncQueue, enable_llmapi_debug, + logger_debug) from ...logger import logger from ..ipc import ZeroMqQueue from .rpc_common import (RPCCancelled, RPCParams, RPCRequest, RPCResponse, @@ -90,7 +96,8 @@ def __init__(self, self._client_socket = ZeroMqQueue(address=(address, hmac_key), is_server=False, is_async=True, - use_hmac_encryption=False) + use_hmac_encryption=False, + socket_type=zmq.DEALER) self._pending_futures = {} # map request_id to the queue for streaming responses self._streaming_queues: Dict[str, AsyncQueue] = {} @@ -100,9 +107,9 @@ def __init__(self, self._server_stopped = False self._closed = False - self._stop_event = None self._loop = None self._loop_thread = None + self._reader_asyncio_task = None # Track the asyncio task for proper cancellation logger_debug(f"RPC Client initialized. Connected to {self._address}") @@ -120,141 +127,216 @@ def close(self): if self._closed: return - # stop the main loop self._closed = True logger_debug("RPC Client closing") - if self._stop_event and self._loop: - # Use call_soon_threadsafe since set() is not a coroutine - self._loop.call_soon_threadsafe(self._stop_event.set) - - if self._reader_task: - try: - self._reader_task.result(timeout=1.0) - except concurrent.futures.TimeoutError: - logger.warning( - "Reader task did not exit gracefully, cancelling") - self._reader_task.cancel() - except Exception as e: - # Task might have already finished or been cancelled - logger_debug(f"Reader task cleanup: {e}") + # Cancel the reader task first to avoid socket closure errors + if self._reader_task and not self._reader_task.done(): + if self._loop and self._loop.is_running( + ) and self._reader_asyncio_task: + try: + # Cancel the asyncio task in its event loop + async def cancel_reader_task(): + if self._reader_asyncio_task and not self._reader_asyncio_task.done( + ): + self._reader_asyncio_task.cancel() + try: + await self._reader_asyncio_task + except asyncio.CancelledError: + pass # Expected + + cancel_future = asyncio.run_coroutine_threadsafe( + cancel_reader_task(), self._loop) + cancel_future.result(timeout=2.0) + logger_debug("Reader task cancelled successfully") + except concurrent.futures.TimeoutError: + logger.warning("Reader task did not exit gracefully") + except Exception as e: + logger_debug(f"Reader task cleanup: {e}") self._reader_task = None + self._reader_asyncio_task = None + + # Now close the socket after reader has stopped + if self._client_socket: + self._client_socket.close() + self._client_socket = None + # Stop the event loop if self._loop and self._loop.is_running(): self._loop.call_soon_threadsafe(self._loop.stop) if self._loop_thread: - self._loop_thread.join() + self._loop_thread.join(timeout=2.0) self._loop_thread = None + if self._executor: self._executor.shutdown(wait=True) - if self._client_socket: - self._client_socket.close() 
- self._client_socket = None - logger_debug("RPC Client closed") - async def _response_reader(self): - """Task to read responses from the socket and set results on futures.""" - logger_debug("Response reader started") + def _handle_streaming_response(self, response: RPCResponse): + """Handle a streaming response by putting it in the appropriate queue. - while not self._stop_event.is_set(): - try: - # Use wait_for with a short timeout to periodically check stop event - try: - response: RPCResponse = await asyncio.wait_for( - self._client_socket.get_async(), - timeout=0.1 # Check stop event every 100ms - ) - except asyncio.TimeoutError: - # Timeout is expected - just check stop event and continue - continue - - logger_debug(f"RPC Client received response: {response}") + Args: + response: The streaming response to handle + """ + assert response.stream_status in [ + 'start', 'data', 'end', 'error' + ], f"Invalid stream status: {response.stream_status}" + + queue = self._streaming_queues.get(response.request_id) + if queue: + # put to the sync queue, as the current event loop is + # different from the one in call_async or call_streaming + assert isinstance(queue, AsyncQueue) + if enable_llmapi_debug() or logger.level == 'debug': logger_debug( - f"Response request_id: {response.request_id}, is_streaming: {response.is_streaming}" + f"RPC Client putting response to AsyncQueue: status={response.stream_status}, request_id={response.request_id}" ) - logger_debug( - f"Pending futures: {list(self._pending_futures.keys())}") - - # Handle streaming responses - if response.is_streaming: - assert response.stream_status in [ - 'start', 'data', 'end', 'error' - ], f"Invalid stream status: {response.stream_status}" - queue = self._streaming_queues.get(response.request_id) - if queue: - # put to the sync queue, as the current event loop is - # different from the one in call_async or call_streaming - assert isinstance(queue, AsyncQueue) - logger_debug( - f"RPC Client putting response to AsyncQueue: {response}" - ) - queue.sync_q.put(response) - # Clean up if stream ended - if response.stream_status in ['end', 'error']: - self._streaming_queues.pop(response.request_id, - None) - else: - # Handle regular responses - logger_debug( - f"Handling regular response for request_id: {response.request_id}" - ) - if future_info := self._pending_futures.get( - response.request_id): - future, target_loop = future_info - logger_debug( - f"Found future for request_id: {response.request_id}, future done: {future.done()}" - ) + queue.sync_q.put(response) + # Clean up if stream ended + if response.stream_status in ['end', 'error']: + self._streaming_queues.pop(response.request_id, None) + def _handle_regular_response(self, response: RPCResponse): + """Handle a regular (non-streaming) response by setting the future result. 
+ + Args: + response: The response to handle + """ + if future_info := self._pending_futures.get(response.request_id): + future, target_loop = future_info + + if not future.done(): + + def safe_set_result(): + """Safely set result on future, handling race conditions.""" + try: if not future.done(): if response.error is None: - logger_debug( - f"Setting result for request_id: {response.request_id}, result: {response.result}" - ) - target_loop.call_soon_threadsafe( - future.set_result, response.result) + future.set_result(response.result) else: - # Use the original RPCError from the response - logger_debug( - f"Setting exception for request_id: {response.request_id}, error: {response.error}" - ) - target_loop.call_soon_threadsafe( - future.set_exception, response.error) + future.set_exception(response.error) + except asyncio.InvalidStateError: + # Future was cancelled or completed between the check and set + # This is expected in high-load scenarios, just log and continue + if enable_llmapi_debug() or logger.level == 'debug': + logger_debug( + f"Future already done for request_id: {response.request_id}, skipping" + ) + + if enable_llmapi_debug() or logger.level == 'debug': + if response.error is None: + logger_debug( + f"Setting result for request_id: {response.request_id}" + ) else: logger_debug( - f"No future found for request_id: {response.request_id}" + f"Setting exception for request_id: {response.request_id}, error: {response.error}" ) - self._pending_futures.pop(response.request_id, None) - - except asyncio.CancelledError: - # Still handle cancellation for backward compatibility - logger_debug("Response reader cancelled") - break - except Exception as e: - logger.error(f"Exception in RPC response reader: {e}") - # Propagate exception to all pending futures - for (future, target_loop) in self._pending_futures.values(): - - if not future.done(): - target_loop.call_soon_threadsafe( - future.set_exception, e) - # Also signal error to streaming queues - for queue in self._streaming_queues.values(): - await queue.put(RPCResponse("", None, e, False, 0, 'error')) - break - - logger_debug("Response reader exiting gracefully") - self._reader_task = None + + target_loop.call_soon_threadsafe(safe_set_result) + else: + if enable_llmapi_debug() or logger.level == 'debug': + logger_debug( + f"No future found for request_id: {response.request_id}") + + self._pending_futures.pop(response.request_id, None) + + async def _handle_reader_exception(self, exception: Exception): + """Propagate an exception to all pending futures and streaming queues. + + Args: + exception: The exception to propagate + """ + logger.error(f"Exception in RPC response reader: {exception}") + + # Propagate exception to all pending futures + for (future, target_loop) in self._pending_futures.values(): + if not future.done(): + + def safe_set_exception(f=future, exc=exception): + """Safely set exception on future, handling race conditions.""" + try: + if not f.done(): + f.set_exception(exc) + except asyncio.InvalidStateError: + # Future was cancelled or completed, this is fine + pass + + target_loop.call_soon_threadsafe(safe_set_exception) + + # Also signal error to streaming queues + for queue in self._streaming_queues.values(): + await queue.put(RPCResponse("", None, exception, False, 0, 'error')) + + async def _wait_for_response(self) -> RPCResponse: + """Wait for a response from the socket. 
+ + Returns: + RPCResponse from the server + """ + # Directly await the socket - cancellation will be handled by task cancellation + return await self._client_socket.get_async() + + async def _response_reader(self): + """Task to read responses from the socket and set results on futures.""" + logger_debug("Response reader started") + + try: + with customized_gc_thresholds(10000): + while True: + with nvtx_range_debug("response_reader", + color="cyan", + category="RPC"): + try: + response = await self._wait_for_response() + + nvtx_mark_debug( + f"RPC.response.{'streaming' if response.is_streaming else 'sync'}", + color="black", + category="RPC") + + # Optimize: Check debug flag before expensive string operations + # This avoids holding GIL for f-string evaluation when debug is disabled + if enable_llmapi_debug() or logger.level == 'debug': + logger_debug( + f"RPC Client received response: request_id={response.request_id}, " + f"is_streaming={response.is_streaming}, " + f"pending_futures={len(self._pending_futures)}" + ) + + with nvtx_range_debug("handle_response", + color="purple", + category="RPC"): + if response.is_streaming: + self._handle_streaming_response(response) + else: + self._handle_regular_response(response) + + except Exception as e: + await self._handle_reader_exception(e) + break + + except asyncio.CancelledError: + logger_debug("Response reader cancelled") + finally: + logger_debug("Response reader exiting gracefully") + self._reader_task = None + self._reader_asyncio_task = None def _start_response_reader_lazily(self): if self._reader_task is None or self._reader_task.done(): # Ensure we have a persistent background loop self._ensure_event_loop() - # Always create the reader task on the persistent loop - future = asyncio.run_coroutine_threadsafe(self._response_reader(), - self._loop) + + # Wrapper to track the asyncio task + async def run_reader(): + self._reader_asyncio_task = asyncio.current_task() + await self._response_reader() + + # Start the reader task on the persistent loop + future = asyncio.run_coroutine_threadsafe(run_reader(), self._loop) # Store the concurrent.futures.Future self._reader_task = future @@ -269,9 +351,11 @@ async def _call_async(self, method_name, *args, **kwargs): Returns: The result of the remote method call """ - logger_debug( - f"RPC client calling method: {method_name} with args: {args} and kwargs: {kwargs}" - ) + if enable_llmapi_debug() or logger.level == 'debug': + logger_debug(f"RPC client calling method: {method_name}") + nvtx_mark_debug(f"RPC.async.{method_name}", + color="yellow", + category="RPC") if self._server_stopped: raise RPCCancelled("Server is shutting down, request cancelled") @@ -287,7 +371,6 @@ async def _call_async(self, method_name, *args, **kwargs): kwargs, need_response, timeout=timeout) - logger_debug(f"RPC client sending request: {request}") await self._client_socket.put_async(request) if not need_response: @@ -295,28 +378,17 @@ async def _call_async(self, method_name, *args, **kwargs): loop = asyncio.get_running_loop() future = loop.create_future() - logger_debug( - f"RPC Client _call_async: Created future for request_id: {request_id} in loop: {id(loop)}" - ) self._pending_futures[request_id] = (future, loop) - logger_debug( - f"RPC Client _call_async: Stored future in pending_futures") try: # If timeout, the remote call should return a timeout error timely, # so we add 1 second to the timeout to ensure the client can get # that result. 
- logger_debug( - f"RPC Client _call_async: Awaiting future for request_id: {request_id}" - ) if timeout is None: res = await future else: # Add 1 second to the timeout to ensure the client can get res = await asyncio.wait_for(future, timeout) - logger_debug( - f"RPC Client _call_async: Got result for request_id: {request_id}: {res}" - ) return res except RPCCancelled: self._server_stopped = True @@ -336,7 +408,6 @@ def _ensure_event_loop(self): def run_loop(): asyncio.set_event_loop(self._loop) - self._stop_event = asyncio.Event() self._loop.run_forever() self._loop_thread = threading.Thread(target=run_loop, @@ -350,19 +421,15 @@ def run_loop(): def _call_sync(self, method_name, *args, **kwargs): """Synchronous version of RPC call.""" - logger_debug( - f"RPC Client calling method: {method_name} with args: {args} and kwargs: {kwargs}" - ) + if enable_llmapi_debug() or logger.level == 'debug': + logger_debug(f"RPC Client calling method: {method_name}") + nvtx_mark_debug(f"RPC.sync.{method_name}", + color="green", + category="RPC") self._ensure_event_loop() - logger_debug( - f"RPC Client _call_sync: Creating future for {method_name}") future = asyncio.run_coroutine_threadsafe( self._call_async(method_name, *args, **kwargs), self._loop) - logger_debug( - f"RPC Client _call_sync: Waiting for result of {method_name}") result = future.result() - logger_debug( - f"RPC Client _call_sync: Got result for {method_name}: {result}") return result def _call_future(self, name: str, *args, @@ -378,6 +445,7 @@ def _call_future(self, name: str, *args, Returns: A Future object that can be used to retrieve the result """ + nvtx_mark_debug(f"RPC.future.{name}", color="blue", category="RPC") def _async_to_sync(): self._ensure_event_loop() @@ -400,6 +468,8 @@ async def _call_streaming(self, name: str, *args, Yields: Results from the remote async generator """ + nvtx_mark_debug(f"RPC.streaming.{name}", color="red", category="RPC") + if self._server_stopped: raise RPCCancelled("Server is shutting down, request cancelled") @@ -428,24 +498,21 @@ async def _call_streaming(self, name: str, *args, # Read streaming responses while True: - logger_debug(f"RPC Client _call_streaming waiting for response", - color="green") if timeout is None: response = await queue.get() else: response = await asyncio.wait_for(queue.get(), timeout=timeout) - logger_debug( - f"RPC Client _call_streaming received [{response.stream_status}] response: {response}", - color="green") + if enable_llmapi_debug() or logger.level == 'debug': + logger_debug( + f"RPC Client _call_streaming received [{response.stream_status}] response", + color="green") + if response.stream_status == 'start': # Start of stream continue elif response.stream_status == 'data': - logger_debug( - f"RPC Client _call_streaming received data: {response.result}", - color="green") yield response.result elif response.stream_status == 'end': # End of stream diff --git a/tensorrt_llm/executor/rpc/rpc_server.py b/tensorrt_llm/executor/rpc/rpc_server.py index 268bb6012f2..f96890bbf76 100644 --- a/tensorrt_llm/executor/rpc/rpc_server.py +++ b/tensorrt_llm/executor/rpc/rpc_server.py @@ -7,6 +7,8 @@ from concurrent.futures import ThreadPoolExecutor from typing import Optional +import zmq + from ...llmapi.utils import ManagedThread, logger_debug from ...logger import logger from ..ipc import ZeroMqQueue @@ -90,7 +92,8 @@ def bind(self, address="tcp://*:5555"): self._client_socket = ZeroMqQueue(address=(address, self._hmac_key), is_server=True, is_async=True, - use_hmac_encryption=False) + 
use_hmac_encryption=False, + socket_type=zmq.ROUTER) logger.info(f"RPC Server bound to {self._address}") def shutdown(self, is_remote_call: bool = False): diff --git a/tensorrt_llm/executor/rpc_proxy.py b/tensorrt_llm/executor/rpc_proxy.py index d2208f32d1d..0fb67d5baaa 100644 --- a/tensorrt_llm/executor/rpc_proxy.py +++ b/tensorrt_llm/executor/rpc_proxy.py @@ -4,6 +4,7 @@ import threading from typing import Optional +from .._utils import nvtx_range_debug from ..llmapi.mpi_session import MpiPoolSession, MpiSession from ..llmapi.tracer import global_tracer from ..llmapi.utils import AsyncQueue, _SyncQueue, logger_debug @@ -64,6 +65,7 @@ def __init__( self.main_loop_task_obj = None self.main_loop = None + self.main_loop_thread = None self.launch_workers() @@ -148,7 +150,8 @@ def _run_main_loop_task(): self.main_loop.close() self.main_loop_thread = threading.Thread(target=_run_main_loop_task, - daemon=True) + daemon=True, + name="rpc_proxy_main_loop") self.main_loop_thread.start() atexit.register(self.shutdown) @@ -287,7 +290,10 @@ def submit(self, request: GenerationRequest) -> GenerationResult: logprob_params = self._get_logprob_params(request) # submit is a fire-and-forget operation, don't need to wait for response - self.rpc_client.submit(request).remote(need_response=False) + with nvtx_range_debug("GenerationExecutorRpcProxy.submit", + color="green", + category="Proxy"): + self.rpc_client.submit(request).remote(need_response=False) result = GenerationResult( request, @@ -333,7 +339,11 @@ def shutdown(self): logger_debug(f"Error cancelling main loop task: {e}", color="yellow") - self.main_loop_thread.join() + # Only join if we're not calling from the main_loop_thread itself + # (e.g., during garbage collection in that thread) + if self.main_loop_thread and threading.current_thread( + ) != self.main_loop_thread: + self.main_loop_thread.join() # 3. shutdown the mpi session, this should wait until all the PyExecutor # processes are shutdown diff --git a/tensorrt_llm/executor/rpc_worker.py b/tensorrt_llm/executor/rpc_worker.py index b08516ceecc..47bcfacdef4 100644 --- a/tensorrt_llm/executor/rpc_worker.py +++ b/tensorrt_llm/executor/rpc_worker.py @@ -4,10 +4,12 @@ from threading import Event from typing import AsyncGenerator, Optional, Union +import nvtx + from tensorrt_llm._utils import mpi_comm from tensorrt_llm.llmapi.utils import enable_llm_debug, logger_debug -from .._utils import mpi_rank +from .._utils import mpi_rank, nvtx_range_debug from ..bindings import executor as tllm from ..builder import Engine from ..llmapi.llm_args import BaseLlmArgs @@ -57,6 +59,7 @@ def __init__( hf_model_dir=hf_model_dir, tokenizer=tokenizer, ) + # Extract garbage_collection_gen0_threshold from llm_args if available self.garbage_collection_gen0_threshold = ( llm_args.garbage_collection_gen0_threshold if llm_args is not None @@ -69,14 +72,20 @@ def __init__( def submit(self, request: GenerationRequest): """ Submits a request to the worker. """ - super().submit(request) + with nvtx_range_debug("RpcWorker.submit", + color="blue", + category="Worker"): + super().submit(request) def fetch_responses(self, timeout: Optional[float] = None) -> list: logger_debug(f"RpcWorker {mpi_rank()} is fetching responses", color="yellow") - # NOTE: This is a blocking call, it will wait for the responses to be available. 
- responses = super().await_responses(timeout) - self._await_response_helper.responses_handler(responses) + with nvtx_range_debug("RpcWorker.fetch_responses", + color="orange", + category="Worker"): + # NOTE: This is a blocking call, it will wait for the responses to be available. + responses = super().await_responses(timeout) + self._await_response_helper.responses_handler(responses) qsize = self._response_queue.qsize() logger_debug(f"RpcWorker returning {qsize} responses", color="yellow") @@ -198,6 +207,8 @@ def main_task( tokenizer: Optional[TokenizerBase] = None, **kwargs, ) -> None: + nvtx.push_range(f"RpcWorker.main_task_{mpi_rank()}", color="pink") + if enable_llm_debug(): set_level("debug")
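
Notes on enabling the profiling hooks added above: the GC watcher now lives in tensorrt_llm/_utils.py and is installed as a side effect of importing that module, so TLLM_PROFILE_RECORD_GC has to be set before tensorrt_llm is imported; the new nvtx_mark_debug()/nvtx_range_debug() markers are gated by TLLM_LLMAPI_ENABLE_NVTX or TLLM_NVTX_DEBUG. A minimal usage sketch, assuming only the environment variable names from this patch and an illustrative Nsight Systems launch (not a definitive workflow):

    # enable_profiling_example.py -- illustrative sketch only
    import os

    # Must be set before tensorrt_llm is imported: the GC NVTX watcher is
    # installed at import time of tensorrt_llm._utils in this patch.
    os.environ["TLLM_PROFILE_RECORD_GC"] = "1"
    # Gates the nvtx_mark_debug()/nvtx_range_debug() calls added in this patch.
    os.environ["TLLM_NVTX_DEBUG"] = "1"

    import tensorrt_llm  # noqa: E402  -- deliberately imported after the env vars

    # Build an LLM and issue requests as usual. When the process is launched
    # under Nsight Systems (e.g. "nsys profile python enable_profiling_example.py"),
    # Python garbage collections appear as "Python GC" ranges and the RPC
    # client/proxy/worker paths show the RPC.*, Proxy, and Worker ranges above.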
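
The ZeroMqQueue changes switch the RPC client to a zmq.DEALER socket and the server to zmq.ROUTER. Unlike REQ/REP, a DEALER/ROUTER pair does not enforce strict send/receive alternation, and ROUTER prefixes each incoming message with the peer identity so replies can be routed back per client, which is what lets the refactored reader task consume out-of-order responses. A standalone pyzmq sketch of that pairing (independent of ZeroMqQueue; the address and payloads are illustrative):

    import threading

    import zmq


    def router_server(ctx: zmq.Context, addr: str) -> None:
        sock = ctx.socket(zmq.ROUTER)
        sock.bind(addr)
        # ROUTER receives [identity, payload] and must echo the identity frame
        # back so the reply is routed to the correct DEALER peer.
        identity, payload = sock.recv_multipart()
        sock.send_multipart([identity, b"pong:" + payload])
        sock.close()


    if __name__ == "__main__":
        ctx = zmq.Context.instance()
        addr = "tcp://127.0.0.1:5591"  # illustrative port
        server = threading.Thread(target=router_server, args=(ctx, addr))
        server.start()

        client = ctx.socket(zmq.DEALER)
        client.connect(addr)
        # DEALER can pipeline many sends before reading any reply, unlike REQ,
        # which is what allows several in-flight RPC requests on one socket.
        client.send(b"ping")
        print(client.recv())  # b"pong:ping"

        client.close()
        server.join()
        ctx.term()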