Skip to content

Commit c075924

Browse files
committed
changes in the TRT-LLM loading tool — removing install_wget, install_unzip, and install_mpi
1 parent b86278f commit c075924

File tree

3 files changed

+39
-167
lines changed

3 files changed

+39
-167
lines changed

py/torch_tensorrt/dynamo/conversion/converter_utils.py

+39-74
Original file line numberDiff line numberDiff line change
@@ -1008,87 +1008,51 @@ def args_bounds_check(
10081008
return args[i] if len(args) > i and args[i] is not None else replacement
10091009

10101010

1011-
def install_wget(platform: str) -> None:
1012-
if shutil.which("wget"):
1013-
_LOGGER.debug("wget is already installed")
1014-
return
1015-
if platform.startswith("linux"):
1016-
try:
1017-
# if its root
1018-
if os.geteuid() == 0:
1019-
subprocess.run(["apt-get", "update"], check=True)
1020-
subprocess.run(["apt-get", "install", "-y", "wget"], check=True)
1021-
else:
1022-
_LOGGER.debug("Please run with sudo permissions")
1023-
subprocess.run(["sudo", "apt-get", "update"], check=True)
1024-
subprocess.run(["sudo", "apt-get", "install", "-y", "wget"], check=True)
1025-
except subprocess.CalledProcessError as e:
1026-
_LOGGER.debug("Error installing wget:", e)
1027-
1028-
1029-
def install_mpi(platform: str) -> None:
1030-
if platform.startswith("linux"):
1031-
try:
1032-
# if its root
1033-
if os.geteuid() == 0:
1034-
subprocess.run(["apt-get", "update"], check=True)
1035-
subprocess.run(["apt-get", "install", "-y", "libmpich-dev"], check=True)
1036-
subprocess.run(
1037-
["apt-get", "install", "-y", "libopenmpi-dev"], check=True
1038-
)
1039-
else:
1040-
_LOGGER.debug("Please run with sudo permissions")
1041-
subprocess.run(["sudo", "apt-get", "update"], check=True)
1042-
subprocess.run(
1043-
["sudo", "apt-get", "install", "-y", "libmpich-dev"], check=True
1044-
)
1045-
subprocess.run(
1046-
["sudo", "apt-get", "install", "-y", "libopenmpi-dev"], check=True
1047-
)
1048-
except subprocess.CalledProcessError as e:
1049-
_LOGGER.debug("Error installing mpi libs:", e)
1050-
1051-
10521011
def download_plugin_lib_path(py_version: str, platform: str) -> str:
10531012
plugin_lib_path = None
1054-
if py_version not in ("cp310", "cp312"):
1055-
_LOGGER.warning(
1056-
"No available wheel for python versions other than py3.10 and py3.12"
1057-
)
1058-
install_wget(platform)
1013+
1014+
# Downloading TRT-LLM lib
1015+
# TODO: check how to fix the 0.18.0 hardcode below
10591016
base_url = "https://pypi.nvidia.com/tensorrt-llm/"
1060-
file_name = f"tensorrt_llm-0.17.0.post1-{py_version}-{py_version}-{platform}.whl"
1017+
file_name = f"tensorrt_llm-0.18.0.post1-{py_version}-{py_version}-{platform}.whl"
10611018
download_url = base_url + file_name
10621019
cmd = ["wget", download_url]
1063-
try:
1064-
if not (os.path.exists(file_name)):
1065-
_LOGGER.info(f"Running command: {' '.join(cmd)}")
1066-
subprocess.run(cmd)
1067-
_LOGGER.info("Download complete of wheel")
1068-
if os.path.exists(file_name):
1069-
_LOGGER.info("filename now present")
1070-
if os.path.exists("./tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so"):
1071-
plugin_lib_path = (
1072-
"./tensorrt_llm/libs/" + "libnvinfer_plugin_tensorrt_llm.so"
1073-
)
1074-
else:
1075-
import zipfile
1020+
if not (os.path.exists(file_name)):
1021+
try:
1022+
subprocess.run(cmd, check=True)
1023+
_LOGGER.debug("Download succeeded and TRT-LLM wheel is now present")
1024+
except subprocess.CalledProcessError as e:
1025+
_LOGGER.error(
1026+
"Download failed (file not found or connection issue). Error code:",
1027+
e.returncode,
1028+
)
1029+
except FileNotFoundError:
1030+
_LOGGER.error("wget is required but not found. Please install wget.")
10761031

1077-
with zipfile.ZipFile(file_name, "r") as zip_ref:
1078-
zip_ref.extractall(".") # Extract to a folder named 'tensorrt_llm'
1079-
plugin_lib_path = (
1080-
"./tensorrt_llm/libs/" + "libnvinfer_plugin_tensorrt_llm.so"
1081-
)
1082-
except subprocess.CalledProcessError as e:
1083-
_LOGGER.debug(f"Error occurred while trying to download: {e}")
1084-
except Exception as e:
1085-
_LOGGER.debug(f"An unexpected error occurred: {e}")
1032+
# Proceeding with the unzip of the wheel file
1033+
# This will exist if the filename was already downloaded
1034+
if os.path.exists("./tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so"):
1035+
plugin_lib_path = "./tensorrt_llm/libs/" + "libnvinfer_plugin_tensorrt_llm.so"
1036+
else:
1037+
try:
1038+
import zipfile
1039+
except:
1040+
raise ImportError(
1041+
"zipfile module is required but not found. Please install zipfile"
1042+
)
1043+
with zipfile.ZipFile(file_name, "r") as zip_ref:
1044+
zip_ref.extractall(".") # Extract to a folder named 'tensorrt_llm'
1045+
plugin_lib_path = (
1046+
"./tensorrt_llm/libs/" + "libnvinfer_plugin_tensorrt_llm.so"
1047+
)
10861048
return plugin_lib_path
10871049

10881050

10891051
def load_tensorrt_llm() -> bool:
10901052
"""
10911053
Attempts to load the TensorRT-LLM plugin and initialize it.
1054+
Either the env variable TRTLLM_PLUGINS_PATH specifies the path
1055+
If the above is not, the user can specify USE_TRTLLM_PLUGINS as either of 1, true, yes, on to download the TRT-LLM distribution and load it
10921056
10931057
Returns:
10941058
bool: True if the plugin was successfully loaded and initialized, False otherwise.
@@ -1098,8 +1062,9 @@ def load_tensorrt_llm() -> bool:
10981062
_LOGGER.warning(
10991063
"Please set the TRTLLM_PLUGINS_PATH to the directory containing libnvinfer_plugin_tensorrt_llm.so to use converters for torch.distributed ops or else set the USE_TRTLLM_PLUGINS variable to download the shared library",
11001064
)
1101-
for key, value in os.environ.items():
1102-
print(f"{key}: {value}")
1065+
# for key, value in os.environ.items():
1066+
# print(f"{key}: {value}")
1067+
# this option can be used by user if TRTLLM_PLUGINS_PATH is not set by user
11031068
use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in (
11041069
"1",
11051070
"true",
@@ -1112,14 +1077,14 @@ def load_tensorrt_llm() -> bool:
11121077
)
11131078
return False
11141079
else:
1115-
py_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
1080+
# this is used as the default py version
1081+
py_version = f"cp312"
11161082
platform = Platform.current_platform()
11171083

11181084
platform = str(platform).lower()
11191085
plugin_lib_path = download_plugin_lib_path(py_version, platform)
11201086
try:
1121-
# Load the shared
1122-
install_mpi(platform)
1087+
# Load the shared TRT-LLM file
11231088
handle = ctypes.CDLL(plugin_lib_path)
11241089
_LOGGER.info(f"Successfully loaded plugin library: {plugin_lib_path}")
11251090
except OSError as e_os_error:

tests/py/dynamo/conversion/harness.py

-13
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,6 @@ def generate_graph(
353353
enable_passes: bool,
354354
propagate_shapes: bool = False,
355355
settings: CompilationSettings = CompilationSettings(),
356-
fuse_distributed_ops: bool = False,
357356
torch_export_dynamic_shapes: Optional[Any] = None,
358357
):
359358
mod = mod.eval()
@@ -369,16 +368,6 @@ def generate_graph(
369368
tuple(torch_export_inputs),
370369
dynamic_shapes=torch_export_dynamic_shapes,
371370
)
372-
if fuse_distributed_ops:
373-
from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
374-
fuse_distributed_ops,
375-
)
376-
377-
gm = exported_program.graph_module
378-
gm = fuse_distributed_ops(gm, settings)
379-
exported_program = exported_program.run_decompositions(
380-
get_decompositions(False)
381-
)
382371
if enable_passes:
383372
exported_program = pre_export_lowering(exported_program, settings)
384373
exported_program = exported_program.run_decompositions(
@@ -417,7 +406,6 @@ def run_test(
417406
propagate_shapes=False,
418407
int32_reqd=False,
419408
immutable_weights=True,
420-
fuse_distributed_ops=False,
421409
):
422410
# TODO: lan to remove this and set use_dynamo_traccer to True by default
423411
# once all the converter test files are moved to use_dynamo_tracer
@@ -438,7 +426,6 @@ def run_test(
438426
enable_passes=enable_passes,
439427
propagate_shapes=propagate_shapes,
440428
settings=compilation_settings,
441-
fuse_distributed_ops=fuse_distributed_ops,
442429
)
443430

444431
num_inputs = len(inputs)

tests/py/dynamo/conversion/test_nccl_ops.py

-80
This file was deleted.

0 commit comments

Comments
 (0)