Skip to content

add feature gate for tensorrt plugin #3518

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open

Conversation

lanluo-nvidia
Copy link
Collaborator

Description

Tensorrt 10.3.0 does not support tensorrt plugin which causing failures.

Fixes # (issue)

Type of change

Please delete options that are not relevant and/or add your own.

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR in so that relevant reviewers are notified

@github-actions github-actions bot added component: conversion Issues re: Conversion stage component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels May 12, 2025
@github-actions github-actions bot requested a review from gs-olive May 12, 2025 22:53
import tensorrt.plugin as trtp

assert trtp
_TENSORRT_PLUGIN_AVAIL = True
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we call this QDP_PLUGIN?

)


def _enabled_features_str() -> str:
enabled = lambda x: "ENABLED" if x else "DISABLED"
out_str: str = f"Enabled Features:\n - Dynamo Frontend: {enabled(_DYNAMO_FE_AVAIL)}\n - Torch-TensorRT Runtime: {enabled(_TORCHTRT_RT_AVAIL)}\n - FX Frontend: {enabled(_FX_FE_AVAIL)}\n - TorchScript Frontend: {enabled(_TS_FE_AVAIL)}\n" # type: ignore[no-untyped-call]
out_str: str = f"Enabled Features:\n - Dynamo Frontend: {enabled(_DYNAMO_FE_AVAIL)}\n - Torch-TensorRT Runtime: {enabled(_TORCHTRT_RT_AVAIL)}\n - FX Frontend: {enabled(_FX_FE_AVAIL)}\n - TorchScript Frontend: {enabled(_TS_FE_AVAIL)}\n - TensorRT Plugin: {enabled(_TENSORRT_PLUGIN_AVAIL)}\n" # type: ignore[no-untyped-call]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think refit got skipped, can we add that as well>

)
else:

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of a giant if else, take a look at

@for_all_methods(needs_torch_tensorrt_runtime)
,
def needs_torch_tensorrt_runtime(f: Callable[..., Any]) -> Callable[..., Any]:

supports_dynamic_shapes=supports_dynamic_shapes,
requires_output_allocator=requires_output_allocator,
assert trtp
except ImportError as e:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here

@shchoi00
Copy link

shchoi00 commented May 15, 2025

+1

My environment

jetpack 6.2
torch 2.7
tensorrt 10.3
cuda 12.6

$ python3 -c "import torch_tensorrt"
Unable to import quantization op. Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling quantized models
TensorRT-LLM is not installed. Please install TensorRT-LLM or set TRTLLM_PLUGINS_PATH to the directory containing libnvinfer_plugin_tensorrt_llm.so to use converters for torch.distributed ops
Traceback (most recent call last):
File "", line 1, in
File "/home/ircv13/.local/lib/python3.10/site-packages/torch_tensorrt/init.py", line 125, in
from torch_tensorrt.runtime import * # noqa: F403
File "/home/ircv13/.local/lib/python3.10/site-packages/torch_tensorrt/runtime/init.py", line 1, in
from torch_tensorrt.dynamo.runtime import ( # noqa: F401
File "/home/ircv13/.local/lib/python3.10/site-packages/torch_tensorrt/dynamo/init.py", line 10, in
from ._compiler import (
File "/home/ircv13/.local/lib/python3.10/site-packages/torch_tensorrt/dynamo/_compiler.py", line 15, in
from torch_tensorrt.dynamo import _defaults, partitioning
File "/home/ircv13/.local/lib/python3.10/site-packages/torch_tensorrt/dynamo/partitioning/init.py", line 1, in
from ._adjacency_partitioner import partition as fast_partition
File "/home/ircv13/.local/lib/python3.10/site-packages/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py", line 20, in
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
File "/home/ircv13/.local/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/init.py", line 1, in
from . import (
File "/home/ircv13/.local/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/plugins/init.py", line 1, in
from torch_tensorrt.dynamo.conversion.plugins._custom_op import custom_op
File "/home/ircv13/.local/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/plugins/_custom_op.py", line 6, in
from torch_tensorrt.dynamo.conversion.plugins._generate_plugin import generate_plugin
File "/home/ircv13/.local/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py", line 5, in
import tensorrt.plugin as trtp
ModuleNotFoundError: No module named 'tensorrt.plugin'

@narendasan
Copy link
Collaborator

@shchoi00 yes this is support of Jetson as part of our build system overhaul

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla signed component: api [Python] Issues re: Python API component: conversion Issues re: Conversion stage component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants