
Commit 1a6cae3

add gc trace in LLM
Signed-off-by: Superjomn <[email protected]>
1 parent 098d12b

File tree: 4 files changed (+118, -121 lines)


tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 4 additions & 43 deletions
@@ -1,13 +1,11 @@
 import dataclasses
 import datetime
 import functools
-import gc
 import os
 import pickle  # nosec B403
 import threading
 import time
 import traceback
-import weakref
 from contextlib import contextmanager
 from typing import Dict, Iterable, List, Optional, Tuple, Union

@@ -22,8 +20,9 @@
 
 from tensorrt_llm._torch.pyexecutor.resource_manager import (
     ResourceManagerType, request_context)
-from tensorrt_llm._utils import (customized_gc_thresholds, is_trace_enabled,
-                                 mpi_disabled, nvtx_range, trace_func)
+from tensorrt_llm._utils import (customized_gc_thresholds, gc_nvtx_watcher,
+                                 is_trace_enabled, mpi_disabled, nvtx_range,
+                                 trace_func)
 from tensorrt_llm.bindings.executor import (DisServingRequestStats,
                                             FinishReason, InflightBatchingStats,
                                             IterationStats, KvCacheStats,
@@ -59,10 +58,6 @@
 # Format: "start1-stop1,start2-stop2,..." or single iterations "iter1,iter2,..."
 PROFILE_START_STOP_ENV_VAR_NAME = "TLLM_PROFILE_START_STOP"
 
-# Environment variable to enable garbage collection profiling.
-# Set to "1" to enable recording of garbage collection events during profiling.
-PROFILE_RECORD_GC_ENV_VAR_NAME = "TLLM_PROFILE_RECORD_GC"
-
 # Environment variable to enable PyTorch profiler tracing.
 # Set to a path to save detailed tracing of PyTorch operations.
 PROFILE_TRACE_ENV_VAR_NAME = "TLLM_TORCH_PROFILE_TRACE"
@@ -97,40 +92,6 @@ def _load_iteration_indexes(env_var: str):
     return frozenset(starts), frozenset(stops)
 
 
-class _GCNvtxHandle:
-    pass
-
-
-def _gc_nvtx_watcher():
-    enabled = os.environ.get(PROFILE_RECORD_GC_ENV_VAR_NAME, None)
-    if not enabled:
-        return None
-
-    range_id: Optional[int] = None
-
-    def gc_callback(phase, _):
-        nonlocal range_id
-        if phase == "start":
-            assert range_id is None, "Unexpected state in GC callback: another GC while last GC not finished?"
-            range_id = torch.cuda.nvtx.range_start("Python GC")
-        elif phase == "stop":
-            assert range_id is not None, "Unexpected state in GC callback: no active GC but got GC finished?"
-            torch.cuda.nvtx.range_end(range_id)
-            range_id = None
-
-    gc.callbacks.append(gc_callback)
-
-    def gc_cleanup(callback):
-        try:
-            gc.callbacks.remove(callback)
-        except ValueError:
-            pass
-
-    handle = _GCNvtxHandle()
-    weakref.finalize(handle, gc_cleanup, gc_callback)
-    return handle
-
-
 @dataclasses.dataclass
 class BatchState:
     sample_state: SampleState
@@ -178,7 +139,7 @@ def __init__(self,
         # profile config
         self.profile_start_iters, self.profile_stop_iters = _load_iteration_indexes(
             PROFILE_START_STOP_ENV_VAR_NAME)
-        self.gc_nvtx_watcher_handle = _gc_nvtx_watcher()
+        self.gc_nvtx_watcher_handle = gc_nvtx_watcher()
 
         # related modules
         self.resource_manager = resource_manager

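The watcher itself is unchanged by this move: PyExecutor still stores the returned handle on self.gc_nvtx_watcher_handle, and weakref.finalize unregisters the GC callback once that handle is dropped. For reference, a minimal stdlib-only sketch of the gc.callbacks protocol this relies on (the timing print is illustrative, not part of this commit):

    import gc
    import time

    _t0 = None

    def on_gc(phase, info):
        # gc invokes every registered callback with phase "start" before a
        # collection and "stop" after it; info includes the generation.
        global _t0
        if phase == "start":
            _t0 = time.perf_counter()
        elif phase == "stop":
            dt_ms = (time.perf_counter() - _t0) * 1e3
            print(f"gen-{info['generation']} GC took {dt_ms:.2f} ms")

    gc.callbacks.append(on_gc)
    gc.collect()  # fires one start/stop pair
    gc.callbacks.remove(on_gc)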
tensorrt_llm/_utils.py

Lines changed: 47 additions & 0 deletions
@@ -1191,3 +1191,50 @@ def is_device_integrated() -> bool:
     if not torch.cuda.is_available():
         return False
     return torch.cuda.get_device_properties().is_integrated
+
+
+# Environment variable to enable garbage collection profiling.
+# Set to "1" to enable recording of garbage collection events during profiling.
+PROFILE_RECORD_GC_ENV_VAR_NAME = "TLLM_PROFILE_RECORD_GC"
+
+
+class _GCNvtxHandle:
+    """Handle object for GC NVTX watcher to keep it alive."""
+
+
+def gc_nvtx_watcher() -> Optional[_GCNvtxHandle]:
+    """
+    Set up NVTX range markers for Python garbage collection events.
+    This helps in profiling to visualize when GC occurs during execution.
+
+    Returns:
+        _GCNvtxHandle or None: A handle object that keeps the GC callback alive,
+            or None if GC profiling is not enabled.
+    """
+    enabled = os.environ.get(PROFILE_RECORD_GC_ENV_VAR_NAME, None)
+    if not enabled:
+        return None
+
+    range_id: Optional[int] = None
+
+    def gc_callback(phase, _):
+        nonlocal range_id
+        if phase == "start":
+            assert range_id is None, "Unexpected state in GC callback: another GC while last GC not finished?"
+            range_id = torch.cuda.nvtx.range_start("Python GC")
+        elif phase == "stop":
+            assert range_id is not None, "Unexpected state in GC callback: no active GC but got GC finished?"
+            torch.cuda.nvtx.range_end(range_id)
+            range_id = None
+
+    gc.callbacks.append(gc_callback)
+
+    def gc_cleanup(callback):
+        try:
+            gc.callbacks.remove(callback)
+        except ValueError:
+            pass
+
+    handle = _GCNvtxHandle()
+    weakref.finalize(handle, gc_cleanup, gc_callback)
+    return handle

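Callers only need the new import path; a minimal usage sketch, assuming TLLM_PROFILE_RECORD_GC is set before the watcher is created:

    import gc
    import os

    os.environ["TLLM_PROFILE_RECORD_GC"] = "1"  # must be set before the call

    from tensorrt_llm._utils import gc_nvtx_watcher

    # Keep a strong reference: weakref.finalize unregisters the GC callback
    # as soon as this handle object is itself collected.
    handle = gc_nvtx_watcher()

    gc.collect()  # shows up as a "Python GC" NVTX range in Nsight Systems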
tensorrt_llm/executor/rpc/rpc_client.py

Lines changed: 63 additions & 77 deletions
@@ -8,7 +8,9 @@
 
 from tensorrt_llm._utils import nvtx_mark_debug
 
-from ...llmapi.utils import AsyncQueue, _SyncQueue, logger_debug
+from ..._utils import nvtx_range_debug
+from ...llmapi.utils import (AsyncQueue, _SyncQueue, enable_llmapi_debug,
+                             logger_debug)
 from ...logger import logger
 from ..ipc import ZeroMqQueue
 from .rpc_common import (RPCCancelled, RPCParams, RPCRequest, RPCResponse,
@@ -175,8 +177,10 @@ def _handle_streaming_response(self, response: RPCResponse):
             # put to the sync queue, as the current event loop is
             # different from the one in call_async or call_streaming
             assert isinstance(queue, AsyncQueue)
-            logger_debug(
-                f"RPC Client putting response to AsyncQueue: {response}")
+            if enable_llmapi_debug() or logger.level == 'debug':
+                logger_debug(
+                    f"RPC Client putting response to AsyncQueue: status={response.stream_status}, request_id={response.request_id}"
+                )
             queue.sync_q.put(response)
             # Clean up if stream ended
             if response.stream_status in ['end', 'error']:
@@ -188,32 +192,29 @@ def _handle_regular_response(self, response: RPCResponse):
         Args:
             response: The response to handle
         """
-        logger_debug(
-            f"Handling regular response for request_id: {response.request_id}")
-
         if future_info := self._pending_futures.get(response.request_id):
             future, target_loop = future_info
-            logger_debug(
-                f"Found future for request_id: {response.request_id}, future done: {future.done()}"
-            )
 
             if not future.done():
                 if response.error is None:
-                    logger_debug(
-                        f"Setting result for request_id: {response.request_id}, result: {response.result}"
-                    )
+                    if enable_llmapi_debug() or logger.level == 'debug':
+                        logger_debug(
+                            f"Setting result for request_id: {response.request_id}"
+                        )
                     target_loop.call_soon_threadsafe(future.set_result,
                                                      response.result)
                 else:
                     # Use the original RPCError from the response
-                    logger_debug(
-                        f"Setting exception for request_id: {response.request_id}, error: {response.error}"
-                    )
+                    if enable_llmapi_debug() or logger.level == 'debug':
+                        logger_debug(
+                            f"Setting exception for request_id: {response.request_id}, error: {response.error}"
+                        )
                     target_loop.call_soon_threadsafe(future.set_exception,
                                                      response.error)
         else:
-            logger_debug(
-                f"No future found for request_id: {response.request_id}")
+            if enable_llmapi_debug() or logger.level == 'debug':
+                logger_debug(
+                    f"No future found for request_id: {response.request_id}")
 
         self._pending_futures.pop(response.request_id, None)

@@ -256,35 +257,43 @@ async def _response_reader(self):
         logger_debug("Response reader started")
 
         while not self._stop_event.is_set():
-            try:
-                response = await self._wait_for_response()
-                if response is None:
-                    continue
-
-                nvtx_mark_debug(
-                    f"RPC.response.{'streaming' if response.is_streaming else 'sync'}",
-                    color="black",
-                    category="RPC")
-
-                logger_debug(f"RPC Client received response: {response}")
-                logger_debug(
-                    f"Response request_id: {response.request_id}, is_streaming: {response.is_streaming}"
-                )
-                logger_debug(
-                    f"Pending futures: {list(self._pending_futures.keys())}")
-
-                if response.is_streaming:
-                    self._handle_streaming_response(response)
-                else:
-                    self._handle_regular_response(response)
+            with nvtx_range_debug("response_reader",
+                                  color="cyan",
+                                  category="RPC"):
+                try:
+                    response = await self._wait_for_response()
+
+                    if response is None:
+                        continue
+
+                    nvtx_mark_debug(
+                        f"RPC.response.{'streaming' if response.is_streaming else 'sync'}",
+                        color="black",
+                        category="RPC")
 
-            except asyncio.CancelledError:
-                # Still handle cancellation for backward compatibility
-                logger_debug("Response reader cancelled")
-                break
-            except Exception as e:
-                await self._handle_reader_exception(e)
-                break
+                    # Optimize: Check debug flag before expensive string operations
+                    # This avoids holding GIL for f-string evaluation when debug is disabled
+                    if enable_llmapi_debug() or logger.level == 'debug':
+                        logger_debug(
+                            f"RPC Client received response: request_id={response.request_id}, "
+                            f"is_streaming={response.is_streaming}, "
+                            f"pending_futures={len(self._pending_futures)}")
+
+                    with nvtx_range_debug("handle_response",
+                                          color="purple",
+                                          category="RPC"):
+                        if response.is_streaming:
+                            self._handle_streaming_response(response)
+                        else:
+                            self._handle_regular_response(response)
+
+                except asyncio.CancelledError:
+                    # Still handle cancellation for backward compatibility
+                    logger_debug("Response reader cancelled")
+                    break
+                except Exception as e:
+                    await self._handle_reader_exception(e)
+                    break
 
         logger_debug("Response reader exiting gracefully")
         self._reader_task = None
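Each reader iteration is now wrapped in nvtx_range_debug so reader and dispatch time show up on the profiler timeline. A sketch of what such a debug-gated range helper can look like (the real nvtx_range_debug lives in tensorrt_llm._utils; this reimplementation and its gating variable are assumptions, not its actual code):

    import os
    from contextlib import contextmanager, nullcontext

    import torch

    @contextmanager
    def _nvtx_range(name: str):
        # Emit an NVTX range around the wrapped block.
        torch.cuda.nvtx.range_push(name)
        try:
            yield
        finally:
            torch.cuda.nvtx.range_pop()

    def nvtx_range_debug_sketch(name: str, **_style):
        # No-op unless profiling is requested, so the hot loop pays nothing
        # in production; color/category styling is ignored in this sketch.
        if os.environ.get("TLLM_NVTX_DEBUG"):  # hypothetical gate
            return _nvtx_range(name)
        return nullcontext()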
@@ -310,9 +319,8 @@ async def _call_async(self, method_name, *args, **kwargs):
         Returns:
             The result of the remote method call
         """
-        logger_debug(
-            f"RPC client calling method: {method_name} with args: {args} and kwargs: {kwargs}"
-        )
+        if enable_llmapi_debug() or logger.level == 'debug':
+            logger_debug(f"RPC client calling method: {method_name}")
         nvtx_mark_debug(f"RPC.async.{method_name}",
                         color="yellow",
                         category="RPC")
@@ -331,36 +339,24 @@ async def _call_async(self, method_name, *args, **kwargs):
                              kwargs,
                              need_response,
                              timeout=timeout)
-        logger_debug(f"RPC client sending request: {request}")
         await self._client_socket.put_async(request)
 
         if not need_response:
             return None
 
         loop = asyncio.get_running_loop()
         future = loop.create_future()
-        logger_debug(
-            f"RPC Client _call_async: Created future for request_id: {request_id} in loop: {id(loop)}"
-        )
         self._pending_futures[request_id] = (future, loop)
-        logger_debug(
-            f"RPC Client _call_async: Stored future in pending_futures")
 
         try:
             # If timeout, the remote call should return a timeout error timely,
             # so we add 1 second to the timeout to ensure the client can get
             # that result.
-            logger_debug(
-                f"RPC Client _call_async: Awaiting future for request_id: {request_id}"
-            )
             if timeout is None:
                 res = await future
             else:
                 # Add 1 second to the timeout to ensure the client can get
                 res = await asyncio.wait_for(future, timeout)
-            logger_debug(
-                f"RPC Client _call_async: Got result for request_id: {request_id}: {res}"
-            )
             return res
         except RPCCancelled:
             self._server_stopped = True
@@ -394,22 +390,15 @@ def run_loop():
 
     def _call_sync(self, method_name, *args, **kwargs):
         """Synchronous version of RPC call."""
-        logger_debug(
-            f"RPC Client calling method: {method_name} with args: {args} and kwargs: {kwargs}"
-        )
+        if enable_llmapi_debug() or logger.level == 'debug':
+            logger_debug(f"RPC Client calling method: {method_name}")
         nvtx_mark_debug(f"RPC.sync.{method_name}",
                         color="green",
                         category="RPC")
         self._ensure_event_loop()
-        logger_debug(
-            f"RPC Client _call_sync: Creating future for {method_name}")
         future = asyncio.run_coroutine_threadsafe(
             self._call_async(method_name, *args, **kwargs), self._loop)
-        logger_debug(
-            f"RPC Client _call_sync: Waiting for result of {method_name}")
         result = future.result()
-        logger_debug(
-            f"RPC Client _call_sync: Got result for {method_name}: {result}")
         return result
 
     def _call_future(self, name: str, *args,
@@ -478,24 +467,21 @@ async def _call_streaming(self, name: str, *args,
 
         # Read streaming responses
         while True:
-            logger_debug(f"RPC Client _call_streaming waiting for response",
-                         color="green")
             if timeout is None:
                 response = await queue.get()
             else:
                 response = await asyncio.wait_for(queue.get(),
                                                   timeout=timeout)
 
-            logger_debug(
-                f"RPC Client _call_streaming received [{response.stream_status}] response: {response}",
-                color="green")
+            if enable_llmapi_debug() or logger.level == 'debug':
+                logger_debug(
+                    f"RPC Client _call_streaming received [{response.stream_status}] response",
+                    color="green")
+
             if response.stream_status == 'start':
                 # Start of stream
                 continue
             elif response.stream_status == 'data':
-                logger_debug(
-                    f"RPC Client _call_streaming received data: {response.result}",
-                    color="green")
                 yield response.result
             elif response.stream_status == 'end':
                 # End of stream

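The pattern repeated throughout this file is to test the debug flag before building the log message, so the f-string (and any large repr it triggers) is never evaluated on the hot path. The stdlib equivalent, as a self-contained sketch with a hypothetical payload:

    import logging

    logger = logging.getLogger("rpc")
    response = {"request_id": 42, "result": "x" * 4096}  # hypothetical payload

    # Eager: the message is formatted even when DEBUG is disabled.
    logger.debug(f"got response: {response!r}")

    # Gated: formatting is skipped entirely unless DEBUG is active,
    # mirroring the enable_llmapi_debug()/logger.level checks above.
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug("got response: %r", response)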