Skip to content

Commit c307a3f

Browse files
authored
[1.8] Do not print warning if CUDA driver not found (pytorch#51806) (pytorch#52050)
Summary: This frequently happens when PyTorch compiled with CUDA support is installed on a machine that does not have NVIDIA GPUs. Fixes pytorch#47038. Pull Request resolved: pytorch#51806 Reviewed By: ezyang Differential Revision: D26285827 Pulled By: malfet fbshipit-source-id: 9fd5e690d0135a2b219c1afa803fb69de9729f5e
1 parent f071020 commit c307a3f

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

c10/cuda/CUDAFunctions.cpp

+9-4
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ int32_t driver_version() {
1616
return driver_version;
1717
}
1818

19-
int device_count_impl() {
19+
int device_count_impl(bool fail_if_no_driver) {
2020
int count;
2121
auto err = cudaGetDeviceCount(&count);
2222
if (err == cudaSuccess) {
@@ -34,6 +34,11 @@ int device_count_impl() {
3434
case cudaErrorInsufficientDriver: {
3535
auto version = driver_version();
3636
if (version <= 0) {
37+
if (!fail_if_no_driver) {
38+
// No CUDA driver means no devices
39+
count = 0;
40+
break;
41+
}
3742
TORCH_CHECK(
3843
false,
3944
"Found no NVIDIA driver on your system. Please check that you "
@@ -95,9 +100,9 @@ DeviceIndex device_count() noexcept {
95100
// initialize number of devices only once
96101
static int count = []() {
97102
try {
98-
auto result = device_count_impl();
103+
auto result = device_count_impl(/*fail_if_no_driver=*/false);
99104
TORCH_INTERNAL_ASSERT(result <= std::numeric_limits<DeviceIndex>::max(), "Too many CUDA devices, DeviceIndex overflowed");
100-
return device_count_impl();
105+
return result;
101106
} catch (const c10::Error& ex) {
102107
// We don't want to fail, but still log the warning
103108
// msg() returns the message without the stack trace
@@ -110,7 +115,7 @@ DeviceIndex device_count() noexcept {
110115

111116
DeviceIndex device_count_ensure_non_zero() {
112117
// Call the implementation every time to throw the exception
113-
int count = device_count_impl();
118+
int count = device_count_impl(/*fail_if_no_driver=*/true);
114119
// Zero gpus doesn't produce a warning in `device_count` but we fail here
115120
TORCH_CHECK(count, "No CUDA GPUs are available");
116121
return static_cast<DeviceIndex>(count);

0 commit comments

Comments
 (0)