File tree Expand file tree Collapse file tree 1 file changed +2
-4
lines changed Expand file tree Collapse file tree 1 file changed +2
-4
lines changed Original file line number Diff line number Diff line change @@ -255,17 +255,15 @@ def get_extensions():
255
255
print (
256
256
"PyTorch GPU support is not available. Skipping compilation of CUDA extensions"
257
257
)
258
- if (CUDA_HOME is None and ROCM_HOME is None ) and torch .cuda . is_available () :
258
+ if (CUDA_HOME is None and ROCM_HOME is None ) and torch .version . cuda :
259
259
print (
260
260
"CUDA toolkit or ROCm is not available. Skipping compilation of CUDA extensions"
261
261
)
262
262
print (
263
263
"If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit"
264
264
)
265
265
266
- use_cuda = torch .cuda .is_available () and (
267
- CUDA_HOME is not None or ROCM_HOME is not None
268
- )
266
+ use_cuda = torch .version .cuda and (CUDA_HOME is not None or ROCM_HOME is not None )
269
267
extension = CUDAExtension if use_cuda else CppExtension
270
268
271
269
extra_link_args = []
You can’t perform that action at this time.
0 commit comments