tensorrt_llm/_torch/pyexecutor/py_executor.py (19 changes: 15 additions & 4 deletions)
@@ -504,6 +504,15 @@ def should_stop_processing(self):
         return self.is_shutdown and len(self.active_requests) == 0 and \
             self.executor_request_queue.get_waiting_queue_size() == 0
 
+    def _get_ranked_trace_path(self):
+        """Return a per-rank torch trace path based on TLLM_TORCH_PROFILE_TRACE."""
+        rank = getattr(self.dist, "rank", 0)
+        trace_path = self.torch_trace_path
+        if os.path.isdir(trace_path):
+            return os.path.join(trace_path, f"trace_{rank}.json")
+        base, ext = os.path.splitext(trace_path)
+        return f"{base}_{rank}{ext or '.json'}"
+
     @contextmanager
     def _profiler(self):
         it = -1
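
Note (not part of the diff): the new helper resolves one trace file per rank. Below is a minimal standalone sketch of the same path logic, with the rank passed in explicitly instead of read from self.dist; ranked_trace_path is a hypothetical name used only for illustration.

import os

def ranked_trace_path(trace_path, rank):
    # Directory target: write one trace_<rank>.json per rank inside it.
    if os.path.isdir(trace_path):
        return os.path.join(trace_path, f"trace_{rank}.json")
    # File target: splice the rank in before the extension, defaulting to .json.
    base, ext = os.path.splitext(trace_path)
    return f"{base}_{rank}{ext or '.json'}"

assert ranked_trace_path("/tmp/trace.json", 1) == "/tmp/trace_1.json"
assert ranked_trace_path("/tmp/trace", 2) == "/tmp/trace_2.json"  # assuming /tmp/trace is not an existing directory
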
@@ -518,12 +527,12 @@ def _profiler(self):
         start_event_2 = None
         end_event_2 = torch.cuda.Event(enable_timing=True)
         prev_device_step_time = None
 
-        torch_trace_path = os.environ.get(PROFILE_TRACE_ENV_VAR_NAME, None)
+        self.torch_trace_path = os.environ.get(PROFILE_TRACE_ENV_VAR_NAME, None)
         profile_start_stop = os.environ.get(PROFILE_START_STOP_ENV_VAR_NAME,
                                             None)
-        enable_torch_trace = bool(torch_trace_path and profile_start_stop)
-        if torch_trace_path and profile_start_stop is None:
+        enable_torch_trace = self.torch_trace_path and profile_start_stop
+        if self.torch_trace_path is not None and profile_start_stop is None:
             logger.warning(
                 f"{PROFILE_START_STOP_ENV_VAR_NAME} environment variable "
                 "needs to be set to enable the torch trace. Example to profile "
@@ -546,6 +555,7 @@ def profile_step():
                 assert enabled, "Inconsistent CUDA profiling state"
                 if enable_torch_trace:
                     torch_profiler.stop()
+                    torch_trace_path = self._get_ranked_trace_path()
                     torch_profiler.export_chrome_trace(torch_trace_path)
                     logger.info(f"Profiling stopped at iteration {it}, "
                                 f"trace saved to {torch_trace_path}")
@@ -612,6 +622,7 @@ def profile_step():
            # Stop on early exit / exception
            if enable_torch_trace:
                torch_profiler.stop()
+                torch_trace_path = self._get_ranked_trace_path()
                torch_profiler.export_chrome_trace(torch_trace_path)
                logger.info(f"Profiling stopped at iteration {it}, "
                            f"trace saved to {torch_trace_path}")