diff --git a/pkg/gpu/gpu.go b/pkg/gpu/gpu.go index fd48802f..b1e1e9fd 100644 --- a/pkg/gpu/gpu.go +++ b/pkg/gpu/gpu.go @@ -2,6 +2,7 @@ package gpu import ( "context" + "os/exec" "github.com/docker/docker/client" ) @@ -18,12 +19,30 @@ const ( // ProbeGPUSupport determines whether or not the Docker engine has GPU support. func ProbeGPUSupport(ctx context.Context, dockerClient client.SystemAPIClient) (GPUSupport, error) { - info, err := dockerClient.Info(ctx) + // First, search for nvidia-container-runtime on PATH. + if _, err := exec.LookPath("nvidia-container-runtime"); err == nil { + return GPUSupportCUDA, nil + } + + // Next, look for an explicitly configured nvidia runtime. This is not required in Docker 19.03+ but + // may be configured on some systems. + hasNvidia, err := HasNVIDIARuntime(ctx, dockerClient) if err != nil { return GPUSupportNone, err } - if _, hasNvidia := info.Runtimes["nvidia"]; hasNvidia { + if hasNvidia { return GPUSupportCUDA, nil } + return GPUSupportNone, nil } + +// HasNVIDIARuntime reports whether the Docker engine lists a runtime named "nvidia". +func HasNVIDIARuntime(ctx context.Context, dockerClient client.SystemAPIClient) (bool, error) { + info, err := dockerClient.Info(ctx) + if err != nil { + return false, err + } + _, hasNvidia := info.Runtimes["nvidia"] + return hasNvidia, nil +} diff --git a/pkg/standalone/containers.go b/pkg/standalone/containers.go index 78759e30..a5671215 100644 --- a/pkg/standalone/containers.go +++ b/pkg/standalone/containers.go @@ -268,7 +268,9 @@ func CreateControllerContainer(ctx context.Context, dockerClient *client.Client, nat.Port(portStr + "/tcp"): portBindings, } if gpu == gpupkg.GPUSupportCUDA { - hostConfig.Runtime = "nvidia" + if ok, err := gpupkg.HasNVIDIARuntime(ctx, dockerClient); err == nil && ok { + hostConfig.Runtime = "nvidia" + } hostConfig.DeviceRequests = []container.DeviceRequest{{Count: -1, Capabilities: [][]string{{"gpu"}}}} }