diff --git a/cmd/nvidia-dra-plugin/nvlib.go b/cmd/nvidia-dra-plugin/nvlib.go
index 92f5bac0..ee293c50 100644
--- a/cmd/nvidia-dra-plugin/nvlib.go
+++ b/cmd/nvidia-dra-plugin/nvlib.go
@@ -31,8 +31,9 @@ import (
 
 type deviceLib struct {
 	nvdev.Interface
-	nvmllib       nvml.Interface
-	nvidiaSMIPath string
+	nvmllib           nvml.Interface
+	driverLibraryPath string
+	nvidiaSMIPath     string
 }
 
 func newDeviceLib(driverRoot root) (*deviceLib, error) {
@@ -46,32 +47,40 @@ func newDeviceLib(driverRoot root) (*deviceLib, error) {
 		return nil, fmt.Errorf("failed to locate nvidia-smi: %w", err)
 	}
 
-	// In order for nvidia-smi to run, we need to set the PATH to the parent of
-	// the nvidia-smi executable and update LD_PRELOAD to include the path to
-	// libnvidia-ml.so.1
-	updatePathListEnvvar("LD_PRELOAD", driverLibraryPath)
-	updatePathListEnvvar("PATH", filepath.Dir(nvidiaSMIPath))
-
 	// We construct an NVML library specifying the path to libnvidia-ml.so.1
 	// explicitly so that we don't have to rely on the library path.
 	nvmllib := nvml.New(
 		nvml.WithLibraryPath(driverLibraryPath),
 	)
 	d := deviceLib{
-		Interface:     nvdev.New(nvdev.WithNvml(nvmllib)),
-		nvmllib:       nvmllib,
-		nvidiaSMIPath: nvidiaSMIPath,
+		Interface:         nvdev.New(nvdev.WithNvml(nvmllib)),
+		nvmllib:           nvmllib,
+		driverLibraryPath: driverLibraryPath,
+		nvidiaSMIPath:     nvidiaSMIPath,
 	}
 	return &d, nil
 }
 
-// updatePathListEnvvar prepends a specified list of strings to a specified envvar.
-func updatePathListEnvvar(envvar string, prepend ...string) {
+// prependPathListEnvvar prepends a specified list of strings to a specified envvar and returns its value.
+func prependPathListEnvvar(envvar string, prepend ...string) string {
 	if len(prepend) == 0 {
-		return
+		return os.Getenv(envvar)
 	}
 	current := filepath.SplitList(os.Getenv(envvar))
-	os.Setenv(envvar, strings.Join(append(prepend, current...), string(filepath.ListSeparator)))
+	return strings.Join(append(prepend, current...), string(filepath.ListSeparator))
+}
+
+// setOrOverrideEnvvar drops any existing entry for key from the specified KEY=VALUE envvars, appends key=value, and returns the resulting list.
+func setOrOverrideEnvvar(envvars []string, key, value string) []string {
+	var updated []string
+	for _, envvar := range envvars {
+		pair := strings.SplitN(envvar, "=", 2)
+		if pair[0] == key {
+			continue
+		}
+		updated = append(updated, envvar)
+	}
+	return append(updated, fmt.Sprintf("%s=%s", key, value))
 }
 
 func (l deviceLib) Init() error {
@@ -481,10 +490,14 @@ func walkMigDevices(d nvml.Device, f func(i int, d nvml.Device) error) error {
 func (l deviceLib) setTimeSlice(uuids []string, timeSlice int) error {
 	for _, uuid := range uuids {
 		cmd := exec.Command(
-			"nvidia-smi",
+			l.nvidiaSMIPath,
 			"compute-policy",
 			"-i", uuid,
 			"--set-timeslice", fmt.Sprintf("%d", timeSlice))
+
+		// In order for nvidia-smi to run, we need to update LD_PRELOAD to include the path to libnvidia-ml.so.1.
+		cmd.Env = setOrOverrideEnvvar(os.Environ(), "LD_PRELOAD", prependPathListEnvvar("LD_PRELOAD", l.driverLibraryPath))
+
 		output, err := cmd.CombinedOutput()
 		if err != nil {
 			klog.Errorf("\n%v", string(output))
@@ -497,9 +510,13 @@ func (l deviceLib) setTimeSlice(uuids []string, timeSlice int) error {
 func (l deviceLib) setComputeMode(uuids []string, mode string) error {
 	for _, uuid := range uuids {
 		cmd := exec.Command(
-			"nvidia-smi",
+			l.nvidiaSMIPath,
 			"-i", uuid,
 			"-c", mode)
+
+		// In order for nvidia-smi to run, we need to update LD_PRELOAD to include the path to libnvidia-ml.so.1.
+		cmd.Env = setOrOverrideEnvvar(os.Environ(), "LD_PRELOAD", prependPathListEnvvar("LD_PRELOAD", l.driverLibraryPath))
+
 		output, err := cmd.CombinedOutput()
 		if err != nil {
 			klog.Errorf("\n%v", string(output))