@@ -1008,87 +1008,51 @@ def args_bounds_check(
1008
1008
return args [i ] if len (args ) > i and args [i ] is not None else replacement
1009
1009
1010
1010
1011
- def install_wget (platform : str ) -> None :
1012
- if shutil .which ("wget" ):
1013
- _LOGGER .debug ("wget is already installed" )
1014
- return
1015
- if platform .startswith ("linux" ):
1016
- try :
1017
- # if its root
1018
- if os .geteuid () == 0 :
1019
- subprocess .run (["apt-get" , "update" ], check = True )
1020
- subprocess .run (["apt-get" , "install" , "-y" , "wget" ], check = True )
1021
- else :
1022
- _LOGGER .debug ("Please run with sudo permissions" )
1023
- subprocess .run (["sudo" , "apt-get" , "update" ], check = True )
1024
- subprocess .run (["sudo" , "apt-get" , "install" , "-y" , "wget" ], check = True )
1025
- except subprocess .CalledProcessError as e :
1026
- _LOGGER .debug ("Error installing wget:" , e )
1027
-
1028
-
1029
- def install_mpi (platform : str ) -> None :
1030
- if platform .startswith ("linux" ):
1031
- try :
1032
- # if its root
1033
- if os .geteuid () == 0 :
1034
- subprocess .run (["apt-get" , "update" ], check = True )
1035
- subprocess .run (["apt-get" , "install" , "-y" , "libmpich-dev" ], check = True )
1036
- subprocess .run (
1037
- ["apt-get" , "install" , "-y" , "libopenmpi-dev" ], check = True
1038
- )
1039
- else :
1040
- _LOGGER .debug ("Please run with sudo permissions" )
1041
- subprocess .run (["sudo" , "apt-get" , "update" ], check = True )
1042
- subprocess .run (
1043
- ["sudo" , "apt-get" , "install" , "-y" , "libmpich-dev" ], check = True
1044
- )
1045
- subprocess .run (
1046
- ["sudo" , "apt-get" , "install" , "-y" , "libopenmpi-dev" ], check = True
1047
- )
1048
- except subprocess .CalledProcessError as e :
1049
- _LOGGER .debug ("Error installing mpi libs:" , e )
1050
-
1051
-
1052
1011
def download_plugin_lib_path (py_version : str , platform : str ) -> str :
1053
1012
plugin_lib_path = None
1054
- if py_version not in ("cp310" , "cp312" ):
1055
- _LOGGER .warning (
1056
- "No available wheel for python versions other than py3.10 and py3.12"
1057
- )
1058
- install_wget (platform )
1013
+
1014
+ # Downloading TRT-LLM lib
1015
+ # TODO: check how to fix the 0.18.0 hardcode below
1059
1016
base_url = "https://pypi.nvidia.com/tensorrt-llm/"
1060
- file_name = f"tensorrt_llm-0.17 .0.post1-{ py_version } -{ py_version } -{ platform } .whl"
1017
+ file_name = f"tensorrt_llm-0.18 .0.post1-{ py_version } -{ py_version } -{ platform } .whl"
1061
1018
download_url = base_url + file_name
1062
1019
cmd = ["wget" , download_url ]
1063
- try :
1064
- if not (os .path .exists (file_name )):
1065
- _LOGGER .info (f"Running command: { ' ' .join (cmd )} " )
1066
- subprocess .run (cmd )
1067
- _LOGGER .info ("Download complete of wheel" )
1068
- if os .path .exists (file_name ):
1069
- _LOGGER .info ("filename now present" )
1070
- if os .path .exists ("./tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so" ):
1071
- plugin_lib_path = (
1072
- "./tensorrt_llm/libs/" + "libnvinfer_plugin_tensorrt_llm.so"
1073
- )
1074
- else :
1075
- import zipfile
1020
+ if not (os .path .exists (file_name )):
1021
+ try :
1022
+ subprocess .run (cmd , check = True )
1023
+ _LOGGER .debug ("Download succeeded and TRT-LLM wheel is now present" )
1024
+ except subprocess .CalledProcessError as e :
1025
+ _LOGGER .error (
1026
+ "Download failed (file not found or connection issue). Error code:" ,
1027
+ e .returncode ,
1028
+ )
1029
+ except FileNotFoundError :
1030
+ _LOGGER .error ("wget is required but not found. Please install wget." )
1076
1031
1077
- with zipfile .ZipFile (file_name , "r" ) as zip_ref :
1078
- zip_ref .extractall ("." ) # Extract to a folder named 'tensorrt_llm'
1079
- plugin_lib_path = (
1080
- "./tensorrt_llm/libs/" + "libnvinfer_plugin_tensorrt_llm.so"
1081
- )
1082
- except subprocess .CalledProcessError as e :
1083
- _LOGGER .debug (f"Error occurred while trying to download: { e } " )
1084
- except Exception as e :
1085
- _LOGGER .debug (f"An unexpected error occurred: { e } " )
1032
+ # Proceeding with the unzip of the wheel file
1033
+ # This will exist if the filename was already downloaded
1034
+ if os .path .exists ("./tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so" ):
1035
+ plugin_lib_path = "./tensorrt_llm/libs/" + "libnvinfer_plugin_tensorrt_llm.so"
1036
+ else :
1037
+ try :
1038
+ import zipfile
1039
+ except :
1040
+ raise ImportError (
1041
+ "zipfile module is required but not found. Please install zipfile"
1042
+ )
1043
+ with zipfile .ZipFile (file_name , "r" ) as zip_ref :
1044
+ zip_ref .extractall ("." ) # Extract to a folder named 'tensorrt_llm'
1045
+ plugin_lib_path = (
1046
+ "./tensorrt_llm/libs/" + "libnvinfer_plugin_tensorrt_llm.so"
1047
+ )
1086
1048
return plugin_lib_path
1087
1049
1088
1050
1089
1051
def load_tensorrt_llm () -> bool :
1090
1052
"""
1091
1053
Attempts to load the TensorRT-LLM plugin and initialize it.
1054
+ Either the env variable TRTLLM_PLUGINS_PATH specifies the path
1055
+ If the above is not, the user can specify USE_TRTLLM_PLUGINS as either of 1, true, yes, on to download the TRT-LLM distribution and load it
1092
1056
1093
1057
Returns:
1094
1058
bool: True if the plugin was successfully loaded and initialized, False otherwise.
@@ -1098,8 +1062,9 @@ def load_tensorrt_llm() -> bool:
1098
1062
_LOGGER .warning (
1099
1063
"Please set the TRTLLM_PLUGINS_PATH to the directory containing libnvinfer_plugin_tensorrt_llm.so to use converters for torch.distributed ops or else set the USE_TRTLLM_PLUGINS variable to download the shared library" ,
1100
1064
)
1101
- for key , value in os .environ .items ():
1102
- print (f"{ key } : { value } " )
1065
+ # for key, value in os.environ.items():
1066
+ # print(f"{key}: {value}")
1067
+ # this option can be used by user if TRTLLM_PLUGINS_PATH is not set by user
1103
1068
use_trtllm_plugin = os .environ .get ("USE_TRTLLM_PLUGINS" , "0" ).lower () in (
1104
1069
"1" ,
1105
1070
"true" ,
@@ -1112,14 +1077,14 @@ def load_tensorrt_llm() -> bool:
1112
1077
)
1113
1078
return False
1114
1079
else :
1115
- py_version = f"cp{ sys .version_info .major } { sys .version_info .minor } "
1080
+ # this is used as the default py version
1081
+ py_version = f"cp312"
1116
1082
platform = Platform .current_platform ()
1117
1083
1118
1084
platform = str (platform ).lower ()
1119
1085
plugin_lib_path = download_plugin_lib_path (py_version , platform )
1120
1086
try :
1121
- # Load the shared
1122
- install_mpi (platform )
1087
+ # Load the shared TRT-LLM file
1123
1088
handle = ctypes .CDLL (plugin_lib_path )
1124
1089
_LOGGER .info (f"Successfully loaded plugin library: { plugin_lib_path } " )
1125
1090
except OSError as e_os_error :
0 commit comments