diff --git a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py index f00ab47eeb5..c0bfa1cbb36 100644 --- a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py +++ b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py @@ -1,20 +1,21 @@ import bisect import contextlib -import weakref -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple +from typing import Any, Callable, Dict, Optional, Tuple import torch +from tensorrt_llm.mapping import Mapping + from ...inputs.multimodal import MultimodalParams +from ..attention_backend.interface import AttentionMetadata +from ..distributed import MPIDist from ..expert_statistic import ExpertStatistic from ..modules.multi_stream_utils import with_multi_stream +from ..speculative import SpecMetadata from ..utils import make_weak_ref, piecewise_cuda_graph from .resource_manager import ResourceManager, ResourceManagerType from .scheduler import ScheduledRequests -if TYPE_CHECKING: - from .model_engine import PyTorchModelEngine - # A large prime number used for dummy request IDs to avoid collisions CUDA_GRAPH_DUMMY_REQUEST_ID = (1 << 64) - 1 @@ -29,37 +30,62 @@ class CUDAGraphRunner: """ WARMUP_STEPS = 2 - def __init__(self, engine: "PyTorchModelEngine"): - self.engine_ref = weakref.ref(engine) - - # High-level configuration - config = engine.pytorch_backend_config - self.enabled = config.use_cuda_graph - self.padding_enabled = config.cuda_graph_padding_enabled - self.supported_batch_sizes = engine._cuda_graph_batch_sizes - self.max_supported_batch_size = engine._max_cuda_graph_batch_size - self.max_beam_width = engine.max_beam_width - self.spec_config = engine.spec_config - + def __init__( + self, + *, + use_cuda_graph: bool, + cuda_graph_padding_enabled: bool, + supported_batch_sizes: list[int], + max_supported_batch_size: int, + max_batch_size: int, + max_beam_width: int, + max_draft_len: int, + max_num_tokens: int, + use_mrope: bool, + spec_config: Optional["DecodingBaseConfig"], + cuda_graph_mem_pool: Optional[int], + enable_attention_dp: bool, + mapping: Mapping, + dist: Optional[MPIDist], + kv_cache_manager_key: ResourceManagerType, + ): + # --- High-level configuration passed from the engine --- + self.enabled = use_cuda_graph + self.padding_enabled = cuda_graph_padding_enabled + self.supported_batch_sizes = supported_batch_sizes + self.max_supported_batch_size = max_supported_batch_size + self.max_batch_size = max_batch_size + self.max_beam_width = max_beam_width + self.max_draft_len = max_draft_len + self.max_num_tokens = max_num_tokens + self.use_mrope = use_mrope + self.spec_config = spec_config + self.enable_attention_dp = enable_attention_dp + self.mapping = mapping + self.dist = dist + self.kv_cache_manager_key = kv_cache_manager_key + + # --- Internal state --- self.graphs: Dict[Tuple[int, int], torch.cuda.CUDAGraph] = {} self.graph_outputs: Dict[Tuple[int, int], Callable[[], Optional[torch.Tensor]]] = {} self.graph_metadata: Dict[Tuple[int, int], Dict[str, Any]] = {} - self.memory_pool = engine._cuda_graph_mem_pool + self.memory_pool = cuda_graph_mem_pool self.padding_dummy_request: Optional["Request"] = None - self.shared_static_tensors: Dict[str, torch.Tensor] = {} if self.enabled: - self._create_shared_static_tensors() + # The max draft length is needed for sizing tensors, but we can't + # know the runtime-enabled draft length here. + # We size for the maximum possible configured draft length. 
+ max_possible_draft_len = self.spec_config.max_draft_len if self.spec_config else 0 + self._create_shared_static_tensors(max_possible_draft_len) - def _create_shared_static_tensors(self): + def _create_shared_static_tensors(self, max_draft_len: int): """Allocates static tensors sized for the largest possible batch.""" - engine = self._get_engine() - - token_per_request = self.draft_len + 1 + token_per_request = max_draft_len + 1 max_total_tokens = (self.max_supported_batch_size * self.max_beam_width * token_per_request) - max_total_tokens = min(max_total_tokens, engine.max_num_tokens) + max_total_tokens = min(max_total_tokens, self.max_num_tokens) self.shared_static_tensors = { "input_ids": @@ -68,7 +94,7 @@ def _create_shared_static_tensors(self): torch.zeros((1, max_total_tokens), device="cuda", dtype=torch.int32), } - if engine.use_mrope: + if self.use_mrope: self.shared_static_tensors["position_ids"] = torch.zeros( (3, 1, max_total_tokens), device="cuda", dtype=torch.int32) self.shared_static_tensors["multimodal_params"] = [ @@ -82,38 +108,18 @@ def _create_shared_static_tensors(self): }) for _ in range(max_total_tokens) ] - @property - def enable_spec_decode(self): - return self._get_engine().enable_spec_decode - - @property - def draft_len(self): - return self.spec_config.max_draft_len if self.enable_spec_decode else 0 - - @property - def spec_metadata(self): - return self._get_engine().spec_metadata - - @property - def draft_tokens_cuda(self): - return self._get_engine().draft_tokens_cuda - - @property - def attn_metadata(self): - return self._get_engine().attn_metadata - def __del__(self): self.clear() - def _get_engine(self) -> "PyTorchModelEngine": - """Safely dereferences the weak reference to the engine.""" - engine = self.engine_ref() - if engine is None: - raise RuntimeError( - "The parent PyTorchModelEngine has been garbage collected.") - return engine - - def maybe_get_cuda_graph(self, batch: ScheduledRequests): + def maybe_get_cuda_graph( + self, + batch: ScheduledRequests, + iter_counter: int, + is_spec_decode: bool, + attn_metadata: AttentionMetadata, + spec_metadata: Optional[SpecMetadata], + draft_tokens_cuda: torch.Tensor, + ) -> Tuple[bool, Optional[AttentionMetadata], Optional[SpecMetadata]]: """ Determines if the current batch can be run with a CUDA graph. @@ -122,17 +128,15 @@ def maybe_get_cuda_graph(self, batch: ScheduledRequests): - The attn_metadata for the graph, if applicable. - The spec_metadata for the graph, if applicable. 
""" - engine = self._get_engine() - # disable when doing statistic - if hasattr(engine, 'iter_counter') and ExpertStatistic.set_iter( - engine.iter_counter): + if hasattr(self, + 'iter_counter') and ExpertStatistic.set_iter(iter_counter): return False, None, None can_run_cuda_graph = batch.can_run_cuda_graph batch_size = batch.batch_size - if self.enabled and engine.enable_attention_dp and engine.mapping.tp_size > 1: - all_can_graph_batch = engine.dist.tp_allgather( + if self.enabled and self.enable_attention_dp and self.mapping.tp_size > 1: + all_can_graph_batch = self.dist.tp_allgather( [can_run_cuda_graph, batch_size]) is_all_gen_only = all(all_can_graph[0] for all_can_graph in all_can_graph_batch) @@ -146,7 +150,9 @@ def maybe_get_cuda_graph(self, batch: ScheduledRequests): if not self.enabled or not can_run_cuda_graph: return False, None, None - key = (batch_size, self.draft_len) + draft_len = self.spec_config.max_draft_len if is_spec_decode else 0 + key = (batch_size, draft_len) + if key in self.graphs: return True, self.graph_metadata[key][ "attn_metadata"], self.graph_metadata[key]["spec_metadata"] @@ -155,33 +161,36 @@ def maybe_get_cuda_graph(self, batch: ScheduledRequests): return False, None, None num_sequences_in_batch = batch_size * self.max_beam_width - attn_metadata = self.attn_metadata.create_cuda_graph_metadata( - num_sequences_in_batch, False, self.draft_len) - assert attn_metadata.is_cuda_graph + graph_attn_metadata = attn_metadata.create_cuda_graph_metadata( + num_sequences_in_batch, False, draft_len) + assert graph_attn_metadata.is_cuda_graph - if self.enable_spec_decode: - spec_metadata = self.spec_metadata.create_cuda_graph_metadata( + graph_spec_metadata = None + if is_spec_decode and spec_metadata: + graph_spec_metadata = spec_metadata.create_cuda_graph_metadata( num_sequences_in_batch) - spec_metadata.draft_tokens = self.draft_tokens_cuda - else: - spec_metadata = None - return True, attn_metadata, spec_metadata + graph_spec_metadata.draft_tokens = draft_tokens_cuda + + return True, graph_attn_metadata, graph_spec_metadata - def needs_capture(self, batch_size: int): - return (batch_size, self.draft_len) not in self.graph_outputs + def needs_capture(self, batch_size: int, is_spec_decode: bool) -> bool: + draft_len = self.spec_config.max_draft_len if is_spec_decode else 0 + return (batch_size, draft_len) not in self.graph_outputs def capture(self, batch_size: int, + is_spec_decode: bool, forward_fn: Callable, initial_inputs: Dict[str, Any], postprocess_fn: Optional[Callable] = None): """Captures the forward pass for a given batch size.""" - engine = self._get_engine() - key = (batch_size, self.draft_len) + draft_len = self.spec_config.max_draft_len if is_spec_decode else 0 + key = (batch_size, draft_len) + # [CUDA graph spec decode padding] # We pad input IDs/position IDs to the maximum draft length (token per request). # We're forced to do this because we cannot reallocate inputs over many graph runs. 
- token_per_request = self.draft_len + 1 + token_per_request = draft_len + 1 num_tokens_for_capture = (batch_size * self.max_beam_width * token_per_request) @@ -192,7 +201,7 @@ def capture(self, self.shared_static_tensors["position_ids"] [:, :num_tokens_for_capture], } - if engine.use_mrope: + if self.use_mrope: sliced_static_tensors["position_ids"] = self.shared_static_tensors[ "position_ids"][:, :, :num_tokens_for_capture], sliced_static_tensors[ @@ -226,11 +235,12 @@ def capture(self, self.graph_outputs[key] = make_weak_ref(output) self.memory_pool = graph.pool() - def replay(self, batch_size: int, + def replay(self, batch_size: int, is_spec_decode: bool, current_inputs: Dict[str, Any]) -> Optional[torch.Tensor]: """Replays a previously captured graph.""" - engine = self._get_engine() - key = (batch_size, self.draft_len) + draft_len = self.spec_config.max_draft_len if is_spec_decode else 0 + key = (batch_size, draft_len) + stored_meta = self.graph_metadata[key] assert current_inputs["attn_metadata"] is stored_meta["attn_metadata"] if stored_meta["spec_metadata"] is not None: @@ -244,12 +254,13 @@ def replay(self, batch_size: int, static_tensors["input_ids"][:seqlen].copy_(input_ids) position_ids = current_inputs["position_ids"] - if engine.use_mrope and current_inputs.get( + if self.use_mrope and current_inputs.get( 'multimodal_params') is not None: static_tensors["position_ids"][:, :, :seqlen].copy_(position_ids) for i, multimodal_param in enumerate( current_inputs['multimodal_params']): - # NOTE: Currently, we only need 'mrope_position_deltas' on generation phase for multimodal models. + # NOTE: Only 'mrope_position_deltas' is needed on generation + # for multimodal models with CUDA graphs. static_tensors['multimodal_params'][i].multimodal_data[ 'mrope_config']['mrope_position_deltas'].copy_( multimodal_param.multimodal_data['mrope_config'] @@ -265,15 +276,14 @@ def replay(self, batch_size: int, def _get_padded_batch(self, batch: ScheduledRequests, resource_manager: ResourceManager) -> int: - engine = self._get_engine() kv_cache_manager = resource_manager.get_resource_manager( - engine.kv_cache_manager_key) + self.kv_cache_manager_key) can_run_cuda_graph = batch.can_run_cuda_graph batch_size = batch.batch_size new_batch_size = batch_size - if self.enabled and engine.enable_attention_dp and engine.mapping.tp_size > 1: - graph_batch_size = engine.dist.tp_allgather( + if self.enabled and self.enable_attention_dp and self.mapping.tp_size > 1: + graph_batch_size = self.dist.tp_allgather( [can_run_cuda_graph, batch_size]) all_can_graph = all(graph_batch[0] for graph_batch in graph_batch_size) @@ -291,7 +301,7 @@ def _get_padded_batch(self, batch: ScheduledRequests, return 0 padding_size = padded_batch_size - batch_size - if padding_size + batch.batch_size > engine.batch_size: + if padding_size + batch.batch_size > self.max_batch_size: return 0 # No padding if it would create too many concurrent requests. 
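Note on the hunks above: they all compute the graph key the same way. The caller now passes is_spec_decode explicitly and the runner maps it to a draft length, instead of reaching back into the engine through properties. A minimal sketch of that keying and of the token sizing it implies; the helper names (graph_key, tokens_for_capture) are illustrative only and not part of the patch:

from typing import Tuple

def graph_key(batch_size: int, is_spec_decode: bool,
              max_draft_len: int) -> Tuple[int, int]:
    # Same pattern as needs_capture/capture/replay above: spec decode
    # contributes max_draft_len draft tokens per request, otherwise 0.
    draft_len = max_draft_len if is_spec_decode else 0
    return (batch_size, draft_len)

def tokens_for_capture(batch_size: int, max_beam_width: int,
                       draft_len: int, max_num_tokens: int) -> int:
    # Each sequence carries one token plus draft_len draft tokens; the
    # shared static buffers are sized for this product and clamped to
    # max_num_tokens, mirroring _create_shared_static_tensors and capture.
    token_per_request = draft_len + 1
    return min(batch_size * max_beam_width * token_per_request,
               max_num_tokens)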
@@ -306,9 +316,9 @@ def _get_padded_batch(self, batch: ScheduledRequests, self.padding_dummy_request = kv_cache_manager.add_dummy_requests( [CUDA_GRAPH_DUMMY_REQUEST_ID], is_gen=True, - max_num_draft_tokens=engine.max_draft_len, - use_mrope=engine.use_mrope, - max_beam_width=engine.max_beam_width)[0] + max_num_draft_tokens=self.max_draft_len, + use_mrope=self.use_mrope, + max_beam_width=self.max_beam_width)[0] self.padding_dummy_request.is_cuda_graph_dummy = True spec_res_mgr = resource_manager.get_resource_manager( ResourceManagerType.SPEC_RESOURCE_MANAGER) diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 115bd2ce393..0752c53da5b 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -256,6 +256,7 @@ def __init__( self.attn_backend = get_attention_backend( pytorch_backend_config.attn_backend) + self.draft_tokens_cuda = None if self.is_spec_decode: self.spec_metadata = None @@ -327,7 +328,22 @@ def __init__( # with different KV cache managers. self.kv_cache_manager_key = ResourceManagerType.KV_CACHE_MANAGER self.lora_model_config: Optional[LoraModelConfig] = None - self.cuda_graph_runner = CUDAGraphRunner(self) + self.cuda_graph_runner = CUDAGraphRunner( + use_cuda_graph=pytorch_backend_config.use_cuda_graph, + cuda_graph_padding_enabled=self._cuda_graph_padding_enabled, + supported_batch_sizes=self._cuda_graph_batch_sizes, + max_supported_batch_size=self._max_cuda_graph_batch_size, + max_batch_size=self.batch_size, + max_beam_width=self.max_beam_width, + max_draft_len=self.max_draft_len, + max_num_tokens=self.max_num_tokens, + use_mrope=self.use_mrope, + spec_config=self.spec_config, + cuda_graph_mem_pool=self._cuda_graph_mem_pool, + enable_attention_dp=self.enable_attention_dp, + mapping=self.mapping, + dist=self.dist, + kv_cache_manager_key=self.kv_cache_manager_key) # Setup the local cache indirection buffer only once and reuse it. # This way it can also be used for CUDA graphs. @@ -2095,7 +2111,13 @@ def forward( scheduled_requests, resource_manager) as padded_requests: maybe_graph, maybe_attn_metadata, maybe_spec_metadata = self.cuda_graph_runner.maybe_get_cuda_graph( - padded_requests) + padded_requests, + iter_counter=self.iter_counter, + is_spec_decode=self.enable_spec_decode, + attn_metadata=attn_metadata, + spec_metadata=spec_metadata, + draft_tokens_cuda=self.draft_tokens_cuda, + ) if maybe_graph: attn_metadata = maybe_attn_metadata spec_metadata = maybe_spec_metadata @@ -2119,7 +2141,8 @@ def forward( gather_context_logits) else: batch_size = len(padded_requests.generation_requests) - if self.cuda_graph_runner.needs_capture(batch_size): + if self.cuda_graph_runner.needs_capture( + batch_size, self.enable_spec_decode): def capture_forward_fn(inputs: Dict[str, Any]): with MoeLoadBalancerIterContext(moe_load_balancer): @@ -2132,16 +2155,18 @@ def capture_postprocess_fn(inputs: Dict[str, Any]): self._postprocess_inputs(inputs) self.cuda_graph_runner.capture(batch_size, + self.enable_spec_decode, capture_forward_fn, inputs, capture_postprocess_fn) # here we don't need to use context since cuda graph capture didn't run kernel. # maybe we need a cleaner way to do this. 
- outputs = self.cuda_graph_runner.replay(batch_size, inputs) + outputs = self.cuda_graph_runner.replay( + batch_size, self.enable_spec_decode, inputs) else: with MoeLoadBalancerIterContext(moe_load_balancer): outputs = self.cuda_graph_runner.replay( - batch_size, inputs) + batch_size, self.enable_spec_decode, inputs) self._execute_logit_post_processors(scheduled_requests, outputs) diff --git a/tests/unittest/_torch/helpers.py b/tests/unittest/_torch/helpers.py index 18a30242faa..b961477abf9 100644 --- a/tests/unittest/_torch/helpers.py +++ b/tests/unittest/_torch/helpers.py @@ -3,6 +3,10 @@ import torch import torch.nn.functional as F +from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner +from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType +from tensorrt_llm.mapping import Mapping + def ceil_div(x: int, y: int) -> int: return (x + y - 1) // y @@ -164,32 +168,20 @@ def block_scale_gemm(mat_a: torch.Tensor, mat_scale_a: torch.Tensor, return results.view_as(x) -class MockPytorchBackendConfig: - - def __init__(self, use_cuda_graph, cuda_graph_padding_enabled): - self.use_cuda_graph = use_cuda_graph - self.cuda_graph_padding_enabled = cuda_graph_padding_enabled - - -class MockEngine: - """A replacement for SimpleNamespace that supports weak references.""" - - def __init__(self, **kwargs): - self.__dict__.update(kwargs) - - -def create_mock_engine(batch_size: int): - - return MockEngine( - pytorch_backend_config=MockPytorchBackendConfig( - use_cuda_graph=True, cuda_graph_padding_enabled=False), - _cuda_graph_batch_sizes=[batch_size], - _max_cuda_graph_batch_size=batch_size, +def create_mock_cuda_graph_runner(batch_size: int, use_mrope: bool = False): + return CUDAGraphRunner( + use_cuda_graph=True, + cuda_graph_padding_enabled=False, + supported_batch_sizes=[batch_size], + max_supported_batch_size=batch_size, + max_batch_size=batch_size, max_beam_width=1, - max_num_tokens=8192, - is_spec_decode=False, - enable_spec_decode=False, + max_draft_len=0, + max_num_tokens=1, + use_mrope=use_mrope, spec_config=None, - _cuda_graph_mem_pool=None, - use_mrope=False, - ) + cuda_graph_mem_pool=None, + enable_attention_dp=False, + mapping=Mapping(), + dist=None, + kv_cache_manager_key=ResourceManagerType.KV_CACHE_MANAGER) diff --git a/tests/unittest/_torch/modeling/test_modeling_exaone4.py b/tests/unittest/_torch/modeling/test_modeling_exaone4.py index ebf496b2c14..2069936cbce 100644 --- a/tests/unittest/_torch/modeling/test_modeling_exaone4.py +++ b/tests/unittest/_torch/modeling/test_modeling_exaone4.py @@ -22,7 +22,7 @@ class Exaone4Config(PretrainedConfig): # TODO: Remove this once we have a proper config for Exaone4 SKIP_EXAONE4_HF_ACCURACY_TEST = True -from _torch.helpers import create_mock_engine +from _torch.helpers import create_mock_cuda_graph_runner from transformers.cache_utils import HybridCache from utils.util import getSMVersion @@ -31,7 +31,6 @@ class Exaone4Config(PretrainedConfig): from tensorrt_llm._torch.metadata import KVCacheParams from tensorrt_llm._torch.model_config import ModelConfig from tensorrt_llm._torch.models.modeling_exaone4 import Exaone4ForCausalLM -from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager from tensorrt_llm.bindings.executor import KvCacheConfig from tensorrt_llm.mapping import Mapping @@ -338,10 +337,8 @@ def test_exaone4_allclose_to_hf(self, scenario: Scenario) -> None: ] gen_position_ids = 
torch.cat(gen_position_ids).unsqueeze(0).cuda() - graph_runner = None - if scenario.use_cuda_graph: - mock_engine = create_mock_engine(1) - graph_runner = CUDAGraphRunner(mock_engine) + graph_runner = create_mock_cuda_graph_runner( + 1) if scenario.use_cuda_graph else None def run_forward(input_ids, position_ids, attn_metadata): attn_metadata.prepare() @@ -355,7 +352,7 @@ def run_forward(input_ids, position_ids, attn_metadata): "position_ids": position_ids, "attn_metadata": attn_metadata, } - graph_runner.capture(1, + graph_runner.capture(1, False, lambda inputs: exaone4.forward(**inputs), inputs) @@ -363,7 +360,7 @@ def run_forward(input_ids, position_ids, attn_metadata): # Run it twice. This helps us catch problems if buffers are accidentally reallocated # in prepare(). attn_metadata.prepare() - logits = graph_runner.replay(1, inputs) + logits = graph_runner.replay(1, False, inputs) return logits if scenario.use_cuda_graph: diff --git a/tests/unittest/_torch/modeling/test_modeling_llama.py b/tests/unittest/_torch/modeling/test_modeling_llama.py index 73cd4bf9bac..c9a1e3d8a2c 100644 --- a/tests/unittest/_torch/modeling/test_modeling_llama.py +++ b/tests/unittest/_torch/modeling/test_modeling_llama.py @@ -4,7 +4,7 @@ from typing import Any import torch -from _torch.helpers import create_mock_engine +from _torch.helpers import create_mock_cuda_graph_runner from parameterized import parameterized from transformers import LlamaConfig from transformers import LlamaForCausalLM as HFLlamaForCausalLM @@ -15,7 +15,6 @@ from tensorrt_llm._torch.metadata import KVCacheParams from tensorrt_llm._torch.model_config import ModelConfig from tensorrt_llm._torch.models.modeling_llama import LlamaForCausalLM -from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager from tensorrt_llm.bindings.executor import KvCacheConfig from tensorrt_llm.mapping import Mapping @@ -326,10 +325,8 @@ def test_llama_allclose_to_hf(self, scenario: Scenario) -> None: ] gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda() - graph_runner = None - if scenario.use_cuda_graph: - mock_engine = create_mock_engine(1) - graph_runner = CUDAGraphRunner(mock_engine) + graph_runner = create_mock_cuda_graph_runner( + 1) if scenario.use_cuda_graph else None def run_forward(input_ids, position_ids, attn_metadata): attn_metadata.prepare() @@ -343,13 +340,14 @@ def run_forward(input_ids, position_ids, attn_metadata): "position_ids": position_ids, "attn_metadata": attn_metadata, } - graph_runner.capture(1, lambda inputs: llama.forward(**inputs), + graph_runner.capture(1, False, + lambda inputs: llama.forward(**inputs), inputs) for _ in range(2): # Run it twice. This helps us catch problems if buffers are accidentally reallocated # in prepare(). 
attn_metadata.prepare() - logits = graph_runner.replay(1, inputs) + logits = graph_runner.replay(1, False, inputs) return logits if scenario.use_cuda_graph: diff --git a/tests/unittest/_torch/modeling/test_modeling_llama_min_latency.py b/tests/unittest/_torch/modeling/test_modeling_llama_min_latency.py index 2f7618cb39b..4ebcb204b0f 100644 --- a/tests/unittest/_torch/modeling/test_modeling_llama_min_latency.py +++ b/tests/unittest/_torch/modeling/test_modeling_llama_min_latency.py @@ -4,7 +4,7 @@ import torch import transformers -from _torch.helpers import create_mock_engine +from _torch.helpers import create_mock_cuda_graph_runner from parameterized import parameterized from transformers import Llama4Config from transformers import \ @@ -21,7 +21,6 @@ from tensorrt_llm._torch.models.modeling_llama import \ Llama4ForConditionalGeneration from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig -from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager from tensorrt_llm.bindings.executor import KvCacheConfig from tensorrt_llm.mapping import Mapping @@ -403,10 +402,9 @@ def test_llama_allclose_to_hf(self, scenario: AllCloseScenario) -> None: input_ids.size(-1) + gen_input_ids.size(-1)) ] gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda() - graph_runner = None - if scenario.use_cuda_graph: - mock_engine = create_mock_engine(1) - graph_runner = CUDAGraphRunner(mock_engine) + + graph_runner = create_mock_cuda_graph_runner( + 1) if scenario.use_cuda_graph else None def run_forward(input_ids, position_ids, attn_metadata): attn_metadata.prepare() @@ -420,14 +418,15 @@ def run_forward(input_ids, position_ids, attn_metadata): "position_ids": position_ids, "attn_metadata": attn_metadata, } - graph_runner.capture(1, lambda inputs: llama.forward(**inputs), + graph_runner.capture(1, False, + lambda inputs: llama.forward(**inputs), inputs) for _ in range(2): # Run it twice. This helps us catch problems if buffers are accidentally reallocated # in prepare(). 
attn_metadata.prepare() - logits = graph_runner.replay(1, inputs) + logits = graph_runner.replay(1, False, inputs) return logits if scenario.use_cuda_graph: diff --git a/tests/unittest/_torch/modeling/test_modeling_mistral.py b/tests/unittest/_torch/modeling/test_modeling_mistral.py index 94f3226c64e..383fe7d4a16 100644 --- a/tests/unittest/_torch/modeling/test_modeling_mistral.py +++ b/tests/unittest/_torch/modeling/test_modeling_mistral.py @@ -8,7 +8,7 @@ import torch import transformers import transformers.models.mistral3 -from _torch.helpers import create_mock_engine +from _torch.helpers import create_mock_cuda_graph_runner from PIL import Image from utils.util import getSMVersion @@ -19,7 +19,6 @@ from tensorrt_llm._torch.attention_backend import utils as attention_utils from tensorrt_llm._torch.models import modeling_mistral from tensorrt_llm._torch.pyexecutor import resource_manager -from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner from tensorrt_llm.bindings import executor as executor_lib from tensorrt_llm.models import modeling_utils @@ -404,10 +403,7 @@ def test_mistral_3_vlm_allclose_to_hf(mistral_small_3_1_24b_config, backend, use ] gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda() - graph_runner = None - if use_cuda_graph: - mock_engine = create_mock_engine(1) - graph_runner = CUDAGraphRunner(mock_engine) + graph_runner = create_mock_cuda_graph_runner(1) if use_cuda_graph else None def run_forward(input_ids, position_ids, attn_metadata): attn_metadata.prepare() @@ -421,13 +417,13 @@ def run_forward(input_ids, position_ids, attn_metadata): "position_ids": position_ids, "attn_metadata": attn_metadata, } - graph_runner.capture(1, lambda inputs: mistral.forward(**inputs), inputs) + graph_runner.capture(1, False, lambda inputs: mistral.forward(**inputs), inputs) for _ in range(2): # Run it twice. This helps us catch problems if buffers are accidentally reallocated # in prepare(). 
attn_metadata.prepare() - logits = graph_runner.replay(1, inputs) + logits = graph_runner.replay(1, False, inputs) return logits if use_cuda_graph: diff --git a/tests/unittest/_torch/modeling/test_modeling_mixtral.py b/tests/unittest/_torch/modeling/test_modeling_mixtral.py index 1637120b304..2fe2fdd3ce8 100644 --- a/tests/unittest/_torch/modeling/test_modeling_mixtral.py +++ b/tests/unittest/_torch/modeling/test_modeling_mixtral.py @@ -3,7 +3,7 @@ from dataclasses import dataclass import torch -from _torch.helpers import create_mock_engine +from _torch.helpers import create_mock_cuda_graph_runner from parameterized import parameterized from transformers import MixtralConfig from transformers import MixtralForCausalLM as HFMixtralForCausalLM @@ -16,7 +16,6 @@ from tensorrt_llm._torch.models.checkpoints.hf.mixtral_weight_mapper import \ MixtralHfWeightMapper from tensorrt_llm._torch.models.modeling_mixtral import MixtralForCausalLM -from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager from tensorrt_llm.bindings.executor import KvCacheConfig from tensorrt_llm.mapping import Mapping @@ -310,10 +309,8 @@ def test_mixtral_allclose_to_hf(self, scenario: Scenario): ] gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda() - graph_runner = None - if scenario.use_cuda_graph: - mock_engine = create_mock_engine(1) - graph_runner = CUDAGraphRunner(mock_engine) + graph_runner = create_mock_cuda_graph_runner( + 1) if scenario.use_cuda_graph else None def run_forward(input_ids, position_ids, attn_metadata): attn_metadata.prepare() @@ -327,7 +324,7 @@ def run_forward(input_ids, position_ids, attn_metadata): "position_ids": position_ids, "attn_metadata": attn_metadata, } - graph_runner.capture(1, + graph_runner.capture(1, False, lambda inputs: mixtral.forward(**inputs), inputs) @@ -335,7 +332,7 @@ def run_forward(input_ids, position_ids, attn_metadata): # Run it twice. This helps us catch problems if buffers are accidentally reallocated # in prepare(). 
attn_metadata.prepare() - logits = graph_runner.replay(1, inputs) + logits = graph_runner.replay(1, False, inputs) return logits if scenario.use_cuda_graph: diff --git a/tests/unittest/_torch/modeling/test_modeling_mllama.py b/tests/unittest/_torch/modeling/test_modeling_mllama.py index 72f5287c6a6..e7e50b655ad 100644 --- a/tests/unittest/_torch/modeling/test_modeling_mllama.py +++ b/tests/unittest/_torch/modeling/test_modeling_mllama.py @@ -4,7 +4,7 @@ import pytest import torch -from _torch.helpers import create_mock_engine +from _torch.helpers import create_mock_cuda_graph_runner from parameterized import parameterized from test_modeling_llama import Scenario, reduce_llama_config from transformers import MllamaConfig @@ -17,7 +17,6 @@ from tensorrt_llm._torch.model_config import ModelConfig from tensorrt_llm._torch.models.modeling_mllama import \ MllamaForConditionalGeneration -from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager from tensorrt_llm.bindings.executor import KvCacheConfig from tensorrt_llm.mapping import Mapping @@ -420,10 +419,8 @@ def test_mllama_allclose_to_hf_text_only(self, scenario: Scenario) -> None: ] gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda() - graph_runner = None - if scenario.use_cuda_graph: - mock_engine = create_mock_engine(1) - graph_runner = CUDAGraphRunner(mock_engine) + graph_runner = create_mock_cuda_graph_runner( + 1) if scenario.use_cuda_graph else None def run_forward(input_ids, position_ids, attn_metadata): attn_metadata.prepare() @@ -437,14 +434,15 @@ def run_forward(input_ids, position_ids, attn_metadata): "position_ids": position_ids, "attn_metadata": attn_metadata, } - graph_runner.capture(1, lambda inputs: mllama.forward(**inputs), + graph_runner.capture(1, False, + lambda inputs: mllama.forward(**inputs), inputs) for _ in range(2): # Run it twice. This helps us catch problems if buffers are accidentally reallocated # in prepare(). 
attn_metadata.prepare() - logits = graph_runner.replay(1, inputs) + logits = graph_runner.replay(1, False, inputs) return logits if scenario.use_cuda_graph: diff --git a/tests/unittest/_torch/modeling/test_modeling_nemotron.py b/tests/unittest/_torch/modeling/test_modeling_nemotron.py index 11456d0f099..3811448b647 100644 --- a/tests/unittest/_torch/modeling/test_modeling_nemotron.py +++ b/tests/unittest/_torch/modeling/test_modeling_nemotron.py @@ -4,7 +4,7 @@ from typing import Any import torch -from _torch.helpers import create_mock_engine +from _torch.helpers import create_mock_cuda_graph_runner from parameterized import parameterized from transformers import NemotronConfig from transformers import NemotronForCausalLM as HFNemotronForCausalLM @@ -15,7 +15,6 @@ from tensorrt_llm._torch.metadata import KVCacheParams from tensorrt_llm._torch.model_config import ModelConfig from tensorrt_llm._torch.models.modeling_nemotron import NemotronForCausalLM -from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager from tensorrt_llm.bindings.executor import KvCacheConfig from tensorrt_llm.mapping import Mapping @@ -318,10 +317,8 @@ def test_nemotron_allclose_to_hf(self, scenario: Scenario) -> None: ] gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda() - graph_runner = None - if scenario.use_cuda_graph: - mock_engine = create_mock_engine(1) - graph_runner = CUDAGraphRunner(mock_engine) + graph_runner = create_mock_cuda_graph_runner( + 1) if scenario.use_cuda_graph else None def run_forward(input_ids, position_ids, attn_metadata): attn_metadata.prepare() @@ -335,7 +332,7 @@ def run_forward(input_ids, position_ids, attn_metadata): "position_ids": position_ids, "attn_metadata": attn_metadata, } - graph_runner.capture(1, + graph_runner.capture(1, False, lambda inputs: nemotron.forward(**inputs), inputs) @@ -343,7 +340,7 @@ def run_forward(input_ids, position_ids, attn_metadata): # Run it twice. This helps us catch problems if buffers are accidentally reallocated # in prepare(). 
attn_metadata.prepare() - logits = graph_runner.replay(1, inputs) + logits = graph_runner.replay(1, False, inputs) return logits if scenario.use_cuda_graph: diff --git a/tests/unittest/_torch/modeling/test_modeling_phi3.py b/tests/unittest/_torch/modeling/test_modeling_phi3.py index 7c5ffd94141..d9adad5b64a 100644 --- a/tests/unittest/_torch/modeling/test_modeling_phi3.py +++ b/tests/unittest/_torch/modeling/test_modeling_phi3.py @@ -4,7 +4,7 @@ from typing import Any import torch -from _torch.helpers import create_mock_engine +from _torch.helpers import create_mock_cuda_graph_runner from transformers import Phi3Config from transformers import Phi3ForCausalLM as HFPhi3ForCausalLM from utils.util import default_dtype @@ -14,7 +14,6 @@ from tensorrt_llm._torch.metadata import KVCacheParams from tensorrt_llm._torch.model_config import ModelConfig from tensorrt_llm._torch.models.modeling_phi3 import Phi3ForCausalLM -from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager from tensorrt_llm.bindings.executor import KvCacheConfig from tensorrt_llm.mapping import Mapping @@ -310,10 +309,8 @@ def test_phi3_allclose_to_hf(self) -> None: ] gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda() - graph_runner = None - if scenario.use_cuda_graph: - mock_engine = create_mock_engine(1) - graph_runner = CUDAGraphRunner(mock_engine) + graph_runner = create_mock_cuda_graph_runner( + 1) if scenario.use_cuda_graph else None def run_forward(input_ids, position_ids, attn_metadata): attn_metadata.prepare() @@ -327,14 +324,15 @@ def run_forward(input_ids, position_ids, attn_metadata): "position_ids": position_ids, "attn_metadata": attn_metadata, } - graph_runner.capture(1, lambda inputs: phi3.forward(**inputs), + graph_runner.capture(1, False, + lambda inputs: phi3.forward(**inputs), inputs) for _ in range(2): # Run it twice. This helps us catch problems if buffers are accidentally reallocated # in prepare(). 
attn_metadata.prepare() - logits = graph_runner.replay(1, inputs) + logits = graph_runner.replay(1, False, inputs) return logits if scenario.use_cuda_graph: diff --git a/tests/unittest/_torch/modeling/test_modeling_qwen.py b/tests/unittest/_torch/modeling/test_modeling_qwen.py index d1d129de083..50d32ca0caa 100644 --- a/tests/unittest/_torch/modeling/test_modeling_qwen.py +++ b/tests/unittest/_torch/modeling/test_modeling_qwen.py @@ -17,12 +17,11 @@ from tensorrt_llm._torch.models.modeling_qwen import ( Qwen2ForCausalLM, Qwen2ForProcessRewardModel) # yapf: enable -from _torch.helpers import create_mock_engine +from _torch.helpers import create_mock_cuda_graph_runner from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager from tensorrt_llm.bindings.executor import KvCacheConfig from tensorrt_llm.mapping import Mapping from tensorrt_llm.models.modeling_utils import QuantConfig -from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner from utils.llm_data import llm_models_root from utils.util import getSMVersion @@ -265,10 +264,8 @@ def test_qwen_allclose_to_hf(self, scenario: Scenario) -> None: ] gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda() - graph_runner = None - if scenario.use_cuda_graph: - mock_engine = create_mock_engine(1) - graph_runner = CUDAGraphRunner(mock_engine) + graph_runner = create_mock_cuda_graph_runner( + 1) if scenario.use_cuda_graph else None def run_forward(input_ids, position_ids, attn_metadata): attn_metadata.prepare() @@ -282,14 +279,15 @@ def run_forward(input_ids, position_ids, attn_metadata): "position_ids": position_ids, "attn_metadata": attn_metadata, } - graph_runner.capture(1, lambda inputs: qwen.forward(**inputs), + graph_runner.capture(1, False, + lambda inputs: qwen.forward(**inputs), inputs) for _ in range(2): # Run it twice. This helps us catch problems if buffers are accidentally reallocated # in prepare(). 
attn_metadata.prepare() - logits = graph_runner.replay(1, inputs) + logits = graph_runner.replay(1, False, inputs) return logits if scenario.use_cuda_graph: diff --git a/tests/unittest/_torch/modeling/test_modeling_qwen2_5vl.py b/tests/unittest/_torch/modeling/test_modeling_qwen2_5vl.py index dd1ff90aa4c..a2097b02d64 100644 --- a/tests/unittest/_torch/modeling/test_modeling_qwen2_5vl.py +++ b/tests/unittest/_torch/modeling/test_modeling_qwen2_5vl.py @@ -5,7 +5,7 @@ from typing import List import torch -from _torch.helpers import create_mock_engine +from _torch.helpers import create_mock_cuda_graph_runner from parameterized import parameterized from transformers import AutoProcessor, AutoTokenizer, Qwen2_5_VLConfig from transformers import \ @@ -18,7 +18,6 @@ from tensorrt_llm._torch.models.checkpoints.hf.qwen2vl_weight_mapper import \ Qwen2VLHfWeightMapper from tensorrt_llm._torch.models.modeling_qwen2vl import Qwen2_5_VLModel -from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager from tensorrt_llm.bindings.executor import KvCacheConfig from tensorrt_llm.inputs import (create_input_processor, @@ -477,11 +476,8 @@ def test_qwen2_5_vl_allclose_to_hf(self, scenario: Scenario) -> None: target_keywords=["mrope_config.mrope_position_deltas"]) gen_multimodal_params_list.append(multimodal_param) - graph_runner = None - if scenario.use_cuda_graph: - mock_engine = create_mock_engine(1) - mock_engine.use_mrope = True - graph_runner = CUDAGraphRunner(mock_engine) + graph_runner = create_mock_cuda_graph_runner( + 1, True) if scenario.use_cuda_graph else None def run_forward(input_ids, position_ids, attn_metadata, multimodal_params): @@ -500,6 +496,7 @@ def run_forward(input_ids, position_ids, attn_metadata, } graph_runner.capture( batch_size=1, + is_spec_decode=False, forward_fn=lambda inputs: qwen2_5_vl.forward(**inputs), initial_inputs=inputs) @@ -508,6 +505,7 @@ def run_forward(input_ids, position_ids, attn_metadata, # in prepare(). 
attn_metadata.prepare() logits = graph_runner.replay(batch_size=1, + is_spec_decode=False, current_inputs=inputs) return logits diff --git a/tests/unittest/_torch/modeling/test_modeling_qwen_moe.py b/tests/unittest/_torch/modeling/test_modeling_qwen_moe.py index 8658ae0e242..e57a8a1e60e 100644 --- a/tests/unittest/_torch/modeling/test_modeling_qwen_moe.py +++ b/tests/unittest/_torch/modeling/test_modeling_qwen_moe.py @@ -3,7 +3,7 @@ from dataclasses import dataclass import torch -from _torch.helpers import create_mock_engine +from _torch.helpers import create_mock_cuda_graph_runner from parameterized import parameterized from transformers import Qwen2MoeConfig from transformers import Qwen2MoeForCausalLM as HFQwen2MoeForCausalLM @@ -16,7 +16,6 @@ from tensorrt_llm._torch.models.checkpoints.hf.qwen2_moe_weight_mapper import \ Qwen2MoeHfWeightMapper from tensorrt_llm._torch.models.modeling_qwen_moe import Qwen2MoeForCausalLM -from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager from tensorrt_llm.bindings.executor import KvCacheConfig from tensorrt_llm.mapping import Mapping @@ -315,10 +314,8 @@ def test_qwen_moe_allclose_to_hf(self, scenario: Scenario): ] gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda() - graph_runner = None - if scenario.use_cuda_graph: - mock_engine = create_mock_engine(1) - graph_runner = CUDAGraphRunner(mock_engine) + graph_runner = create_mock_cuda_graph_runner( + 1) if scenario.use_cuda_graph else None def run_forward(input_ids, position_ids, attn_metadata): attn_metadata.prepare() @@ -332,7 +329,7 @@ def run_forward(input_ids, position_ids, attn_metadata): "position_ids": position_ids, "attn_metadata": attn_metadata, } - graph_runner.capture(1, + graph_runner.capture(1, False, lambda inputs: qwen_moe.forward(**inputs), inputs) @@ -340,7 +337,7 @@ def run_forward(input_ids, position_ids, attn_metadata): # Run it twice. This helps us catch problems if buffers are accidentally reallocated # in prepare(). attn_metadata.prepare() - logits = graph_runner.replay(1, inputs) + logits = graph_runner.replay(1, False, inputs) return logits if scenario.use_cuda_graph:
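Taken together, the test changes show the intended standalone usage of the decoupled runner. A hedged end-to-end sketch, based on the create_mock_cuda_graph_runner helper and the capture/replay call sites above; `model` and `inputs` stand in for a real module and its prepared input dict (input_ids, position_ids, attn_metadata) and are not part of the patch:

from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
from tensorrt_llm.mapping import Mapping

# Construct the runner directly from plain configuration values; no
# weakref to a PyTorchModelEngine is needed anymore.
runner = CUDAGraphRunner(
    use_cuda_graph=True,
    cuda_graph_padding_enabled=False,
    supported_batch_sizes=[1],
    max_supported_batch_size=1,
    max_batch_size=1,
    max_beam_width=1,
    max_draft_len=0,
    max_num_tokens=1,
    use_mrope=False,
    spec_config=None,
    cuda_graph_mem_pool=None,
    enable_attention_dp=False,
    mapping=Mapping(),
    dist=None,
    kv_cache_manager_key=ResourceManagerType.KV_CACHE_MANAGER)

def run_graph_forward(model, inputs, runner=runner):
    # Capture once per (batch_size, draft_len) key, then replay against
    # the same static input dict; is_spec_decode=False, as in the tests.
    if runner.needs_capture(1, False):
        runner.capture(1, False, lambda inp: model.forward(**inp), inputs)
    return runner.replay(1, False, inputs)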