Skip to content

Commit f34b473

Browse files
syed-ahmedandrewor14
authored andcommitted
Uses torch.version.cuda to compile CUDA extensions (#2193)
* Uses torch.version.cuda to compile CUDA extensions * lint
1 parent 9b22da8 commit f34b473

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

setup.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -255,17 +255,15 @@ def get_extensions():
255255
print(
256256
"PyTorch GPU support is not available. Skipping compilation of CUDA extensions"
257257
)
258-
if (CUDA_HOME is None and ROCM_HOME is None) and torch.cuda.is_available():
258+
if (CUDA_HOME is None and ROCM_HOME is None) and torch.version.cuda:
259259
print(
260260
"CUDA toolkit or ROCm is not available. Skipping compilation of CUDA extensions"
261261
)
262262
print(
263263
"If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit"
264264
)
265265

266-
use_cuda = torch.cuda.is_available() and (
267-
CUDA_HOME is not None or ROCM_HOME is not None
268-
)
266+
use_cuda = torch.version.cuda and (CUDA_HOME is not None or ROCM_HOME is not None)
269267
extension = CUDAExtension if use_cuda else CppExtension
270268

271269
extra_link_args = []

0 commit comments

Comments
 (0)