190 changes: 100 additions & 90 deletions tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
@@ -1,20 +1,21 @@
import bisect
import contextlib
import weakref
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple
from typing import Any, Callable, Dict, Optional, Tuple

import torch

from tensorrt_llm.mapping import Mapping

from ...inputs.multimodal import MultimodalParams
from ..attention_backend.interface import AttentionMetadata
from ..distributed import MPIDist
from ..expert_statistic import ExpertStatistic
from ..modules.multi_stream_utils import with_multi_stream
from ..speculative import SpecMetadata
from ..utils import make_weak_ref, piecewise_cuda_graph
from .resource_manager import ResourceManager, ResourceManagerType
from .scheduler import ScheduledRequests

if TYPE_CHECKING:
from .model_engine import PyTorchModelEngine

# A large sentinel value (2**64 - 1) used for dummy request IDs to avoid collisions
CUDA_GRAPH_DUMMY_REQUEST_ID = (1 << 64) - 1

@@ -29,37 +30,62 @@ class CUDAGraphRunner:
"""
WARMUP_STEPS = 2

def __init__(self, engine: "PyTorchModelEngine"):
self.engine_ref = weakref.ref(engine)

# High-level configuration
config = engine.pytorch_backend_config
self.enabled = config.use_cuda_graph
self.padding_enabled = config.cuda_graph_padding_enabled
self.supported_batch_sizes = engine._cuda_graph_batch_sizes
self.max_supported_batch_size = engine._max_cuda_graph_batch_size
self.max_beam_width = engine.max_beam_width
self.spec_config = engine.spec_config

def __init__(
self,
*,
use_cuda_graph: bool,
cuda_graph_padding_enabled: bool,
supported_batch_sizes: list[int],
max_supported_batch_size: int,
max_batch_size: int,
max_beam_width: int,
max_draft_len: int,
max_num_tokens: int,
use_mrope: bool,
spec_config: Optional["DecodingBaseConfig"],
cuda_graph_mem_pool: Optional[int],
enable_attention_dp: bool,
mapping: Mapping,
dist: Optional[MPIDist],
kv_cache_manager_key: ResourceManagerType,
):
# --- High-level configuration passed from the engine ---
self.enabled = use_cuda_graph
self.padding_enabled = cuda_graph_padding_enabled
self.supported_batch_sizes = supported_batch_sizes
self.max_supported_batch_size = max_supported_batch_size
self.max_batch_size = max_batch_size
self.max_beam_width = max_beam_width
self.max_draft_len = max_draft_len
self.max_num_tokens = max_num_tokens
self.use_mrope = use_mrope
self.spec_config = spec_config
self.enable_attention_dp = enable_attention_dp
self.mapping = mapping
self.dist = dist
self.kv_cache_manager_key = kv_cache_manager_key

# --- Internal state ---
self.graphs: Dict[Tuple[int, int], torch.cuda.CUDAGraph] = {}
self.graph_outputs: Dict[Tuple[int, int],
Callable[[], Optional[torch.Tensor]]] = {}
self.graph_metadata: Dict[Tuple[int, int], Dict[str, Any]] = {}
self.memory_pool = engine._cuda_graph_mem_pool
self.memory_pool = cuda_graph_mem_pool
self.padding_dummy_request: Optional["Request"] = None

self.shared_static_tensors: Dict[str, torch.Tensor] = {}
if self.enabled:
self._create_shared_static_tensors()
# The max draft length is needed for sizing tensors, but we can't
# know the runtime-enabled draft length here.
# We size for the maximum possible configured draft length.
max_possible_draft_len = self.spec_config.max_draft_len if self.spec_config else 0
self._create_shared_static_tensors(max_possible_draft_len)

def _create_shared_static_tensors(self):
def _create_shared_static_tensors(self, max_draft_len: int):
"""Allocates static tensors sized for the largest possible batch."""
engine = self._get_engine()

token_per_request = self.draft_len + 1
token_per_request = max_draft_len + 1
max_total_tokens = (self.max_supported_batch_size *
self.max_beam_width * token_per_request)
max_total_tokens = min(max_total_tokens, engine.max_num_tokens)
max_total_tokens = min(max_total_tokens, self.max_num_tokens)

self.shared_static_tensors = {
"input_ids":
@@ -68,7 +94,7 @@ def _create_shared_static_tensors(self):
torch.zeros((1, max_total_tokens), device="cuda",
dtype=torch.int32),
}
if engine.use_mrope:
if self.use_mrope:
self.shared_static_tensors["position_ids"] = torch.zeros(
(3, 1, max_total_tokens), device="cuda", dtype=torch.int32)
self.shared_static_tensors["multimodal_params"] = [
@@ -82,38 +108,18 @@ def _create_shared_static_tensors(self):
}) for _ in range(max_total_tokens)
]

@property
def enable_spec_decode(self):
return self._get_engine().enable_spec_decode

@property
def draft_len(self):
return self.spec_config.max_draft_len if self.enable_spec_decode else 0

@property
def spec_metadata(self):
return self._get_engine().spec_metadata

@property
def draft_tokens_cuda(self):
return self._get_engine().draft_tokens_cuda

@property
def attn_metadata(self):
return self._get_engine().attn_metadata

def __del__(self):
self.clear()

def _get_engine(self) -> "PyTorchModelEngine":
"""Safely dereferences the weak reference to the engine."""
engine = self.engine_ref()
if engine is None:
raise RuntimeError(
"The parent PyTorchModelEngine has been garbage collected.")
return engine

def maybe_get_cuda_graph(self, batch: ScheduledRequests):
def maybe_get_cuda_graph(
self,
batch: ScheduledRequests,
iter_counter: int,
is_spec_decode: bool,
attn_metadata: AttentionMetadata,
spec_metadata: Optional[SpecMetadata],
draft_tokens_cuda: torch.Tensor,
) -> Tuple[bool, Optional[AttentionMetadata], Optional[SpecMetadata]]:
"""
Determines if the current batch can be run with a CUDA graph.

@@ -122,17 +128,15 @@ def maybe_get_cuda_graph(self, batch: ScheduledRequests):
- The attn_metadata for the graph, if applicable.
- The spec_metadata for the graph, if applicable.
"""
engine = self._get_engine()

# Disable CUDA graphs while collecting expert statistics
if hasattr(engine, 'iter_counter') and ExpertStatistic.set_iter(
engine.iter_counter):
if ExpertStatistic.set_iter(iter_counter):
return False, None, None

can_run_cuda_graph = batch.can_run_cuda_graph
batch_size = batch.batch_size
if self.enabled and engine.enable_attention_dp and engine.mapping.tp_size > 1:
all_can_graph_batch = engine.dist.tp_allgather(
if self.enabled and self.enable_attention_dp and self.mapping.tp_size > 1:
all_can_graph_batch = self.dist.tp_allgather(
[can_run_cuda_graph, batch_size])
is_all_gen_only = all(all_can_graph[0]
for all_can_graph in all_can_graph_batch)
@@ -146,7 +150,9 @@ def maybe_get_cuda_graph(self, batch: ScheduledRequests):
if not self.enabled or not can_run_cuda_graph:
return False, None, None

key = (batch_size, self.draft_len)
draft_len = self.spec_config.max_draft_len if is_spec_decode else 0
key = (batch_size, draft_len)

if key in self.graphs:
return True, self.graph_metadata[key][
"attn_metadata"], self.graph_metadata[key]["spec_metadata"]
@@ -155,33 +161,36 @@ def maybe_get_cuda_graph(self, batch: ScheduledRequests):
return False, None, None

num_sequences_in_batch = batch_size * self.max_beam_width
attn_metadata = self.attn_metadata.create_cuda_graph_metadata(
num_sequences_in_batch, False, self.draft_len)
assert attn_metadata.is_cuda_graph
graph_attn_metadata = attn_metadata.create_cuda_graph_metadata(
num_sequences_in_batch, False, draft_len)
assert graph_attn_metadata.is_cuda_graph

if self.enable_spec_decode:
spec_metadata = self.spec_metadata.create_cuda_graph_metadata(
graph_spec_metadata = None
if is_spec_decode and spec_metadata:
graph_spec_metadata = spec_metadata.create_cuda_graph_metadata(
num_sequences_in_batch)
spec_metadata.draft_tokens = self.draft_tokens_cuda
else:
spec_metadata = None
return True, attn_metadata, spec_metadata
graph_spec_metadata.draft_tokens = draft_tokens_cuda

return True, graph_attn_metadata, graph_spec_metadata

def needs_capture(self, batch_size: int):
return (batch_size, self.draft_len) not in self.graph_outputs
def needs_capture(self, batch_size: int, is_spec_decode: bool) -> bool:
draft_len = self.spec_config.max_draft_len if is_spec_decode else 0
return (batch_size, draft_len) not in self.graph_outputs

def capture(self,
batch_size: int,
is_spec_decode: bool,
forward_fn: Callable,
initial_inputs: Dict[str, Any],
postprocess_fn: Optional[Callable] = None):
"""Captures the forward pass for a given batch size."""
engine = self._get_engine()
key = (batch_size, self.draft_len)
draft_len = self.spec_config.max_draft_len if is_spec_decode else 0
key = (batch_size, draft_len)

# [CUDA graph spec decode padding]
# We pad input IDs/position IDs to the maximum draft length (token per request).
# We're forced to do this because we cannot reallocate inputs over many graph runs.
token_per_request = self.draft_len + 1
token_per_request = draft_len + 1
num_tokens_for_capture = (batch_size * self.max_beam_width *
token_per_request)

Expand All @@ -192,7 +201,7 @@ def capture(self,
self.shared_static_tensors["position_ids"]
[:, :num_tokens_for_capture],
}
if engine.use_mrope:
if self.use_mrope:
sliced_static_tensors["position_ids"] = self.shared_static_tensors[
"position_ids"][:, :, :num_tokens_for_capture],
sliced_static_tensors[
@@ -226,11 +235,12 @@ def capture(self,
self.graph_outputs[key] = make_weak_ref(output)
self.memory_pool = graph.pool()

def replay(self, batch_size: int,
def replay(self, batch_size: int, is_spec_decode: bool,
current_inputs: Dict[str, Any]) -> Optional[torch.Tensor]:
"""Replays a previously captured graph."""
engine = self._get_engine()
key = (batch_size, self.draft_len)
draft_len = self.spec_config.max_draft_len if is_spec_decode else 0
key = (batch_size, draft_len)

stored_meta = self.graph_metadata[key]
assert current_inputs["attn_metadata"] is stored_meta["attn_metadata"]
if stored_meta["spec_metadata"] is not None:
@@ -244,12 +254,13 @@ def replay(self, batch_size: int,
static_tensors["input_ids"][:seqlen].copy_(input_ids)

position_ids = current_inputs["position_ids"]
if engine.use_mrope and current_inputs.get(
if self.use_mrope and current_inputs.get(
'multimodal_params') is not None:
static_tensors["position_ids"][:, :, :seqlen].copy_(position_ids)
for i, multimodal_param in enumerate(
current_inputs['multimodal_params']):
# NOTE: Currently, we only need 'mrope_position_deltas' on generation phase for multimodal models.
# NOTE: Only 'mrope_position_deltas' is needed on generation
# for multimodal models with CUDA graphs.
static_tensors['multimodal_params'][i].multimodal_data[
'mrope_config']['mrope_position_deltas'].copy_(
multimodal_param.multimodal_data['mrope_config']
Expand All @@ -265,15 +276,14 @@ def replay(self, batch_size: int,

def _get_padded_batch(self, batch: ScheduledRequests,
resource_manager: ResourceManager) -> int:
engine = self._get_engine()
kv_cache_manager = resource_manager.get_resource_manager(
engine.kv_cache_manager_key)
self.kv_cache_manager_key)
can_run_cuda_graph = batch.can_run_cuda_graph
batch_size = batch.batch_size
new_batch_size = batch_size

if self.enabled and engine.enable_attention_dp and engine.mapping.tp_size > 1:
graph_batch_size = engine.dist.tp_allgather(
if self.enabled and self.enable_attention_dp and self.mapping.tp_size > 1:
graph_batch_size = self.dist.tp_allgather(
[can_run_cuda_graph, batch_size])
all_can_graph = all(graph_batch[0]
for graph_batch in graph_batch_size)
@@ -291,7 +301,7 @@ def _get_padded_batch(self, batch: ScheduledRequests,
return 0

padding_size = padded_batch_size - batch_size
if padding_size + batch.batch_size > engine.batch_size:
if padding_size + batch.batch_size > self.max_batch_size:
return 0

# No padding if it would create too many concurrent requests.
@@ -306,9 +316,9 @@ def _get_padded_batch(self, batch: ScheduledRequests,
self.padding_dummy_request = kv_cache_manager.add_dummy_requests(
[CUDA_GRAPH_DUMMY_REQUEST_ID],
is_gen=True,
max_num_draft_tokens=engine.max_draft_len,
use_mrope=engine.use_mrope,
max_beam_width=engine.max_beam_width)[0]
max_num_draft_tokens=self.max_draft_len,
use_mrope=self.use_mrope,
max_beam_width=self.max_beam_width)[0]
self.padding_dummy_request.is_cuda_graph_dummy = True
spec_res_mgr = resource_manager.get_resource_manager(
ResourceManagerType.SPEC_RESOURCE_MANAGER)
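Note: the sketch below is not part of the diff above. It shows one way a call site in PyTorchModelEngine could construct the refactored CUDAGraphRunner with explicit keyword arguments instead of handing it a weakref to the engine. The engine attribute names are taken from the accesses removed in this change (pytorch_backend_config.use_cuda_graph, _cuda_graph_batch_sizes, _cuda_graph_mem_pool, kv_cache_manager_key, and so on); the helper function itself is hypothetical.

from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner


def build_cuda_graph_runner(engine: "PyTorchModelEngine") -> CUDAGraphRunner:
    """Hypothetical wiring helper: maps engine attributes onto the new constructor."""
    config = engine.pytorch_backend_config
    return CUDAGraphRunner(
        use_cuda_graph=config.use_cuda_graph,
        cuda_graph_padding_enabled=config.cuda_graph_padding_enabled,
        supported_batch_sizes=engine._cuda_graph_batch_sizes,
        max_supported_batch_size=engine._max_cuda_graph_batch_size,
        max_batch_size=engine.batch_size,
        max_beam_width=engine.max_beam_width,
        max_draft_len=engine.max_draft_len,
        max_num_tokens=engine.max_num_tokens,
        use_mrope=engine.use_mrope,
        spec_config=engine.spec_config,
        cuda_graph_mem_pool=engine._cuda_graph_mem_pool,
        enable_attention_dp=engine.enable_attention_dp,
        mapping=engine.mapping,
        dist=engine.dist,
        kv_cache_manager_key=engine.kv_cache_manager_key,
    )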