Skip to content

Commit 62483b5

Browse files
committed
Add a hook to update nvidia params
If required, this hook creates a modified params file (with ModifyDeviceFiles: 0) in a tmpfs and mounts this over /proc/driver/nvidia/params. This prevents device node creation when running tools such as nvidia-smi. Signed-off-by: Evan Lezar <[email protected]>
1 parent e7a0067 commit 62483b5

File tree

6 files changed

+351
-0
lines changed

6 files changed

+351
-0
lines changed

Diff for: cmd/nvidia-cdi-hook/commands/commands.go

+2
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
symlinks "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/create-symlinks"
2424
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/cudacompat"
2525
ldcache "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/update-ldcache"
26+
nvidiaparams "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/update-nvidia-params"
2627
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
2728
)
2829

@@ -34,5 +35,6 @@ func New(logger logger.Interface) []*cli.Command {
3435
symlinks.NewCommand(logger),
3536
chmod.NewCommand(logger),
3637
cudacompat.NewCommand(logger),
38+
nvidiaparams.NewCommand(logger),
3739
}
3840
}
+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
//go:build linux
2+
// +build linux
3+
4+
/**
5+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
6+
#
7+
# Licensed under the Apache License, Version 2.0 (the "License");
8+
# you may not use this file except in compliance with the License.
9+
# You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing, software
14+
# distributed under the License is distributed on an "AS IS" BASIS,
15+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
# See the License for the specific language governing permissions and
17+
# limitations under the License.
18+
**/
19+
20+
package nvidiaparams
21+
22+
import (
23+
"fmt"
24+
25+
"golang.org/x/sys/unix"
26+
)
27+
28+
func createTmpFs(target string, size int) error {
29+
return unix.Mount("tmpfs", target, "tmpfs", 0, fmt.Sprintf("size=%d", size))
30+
}
31+
32+
func bindMountReadonly(source string, target string) error {
33+
return unix.Mount(source, target, "", unix.MS_BIND|unix.MS_RDONLY|unix.MS_NOSYMFOLLOW, "")
34+
}
+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
//go:build !linux
2+
// +build !linux
3+
4+
/**
5+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
6+
#
7+
# Licensed under the Apache License, Version 2.0 (the "License");
8+
# you may not use this file except in compliance with the License.
9+
# You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing, software
14+
# distributed under the License is distributed on an "AS IS" BASIS,
15+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
# See the License for the specific language governing permissions and
17+
# limitations under the License.
18+
**/
19+
20+
package nvidiaparams
21+
22+
import (
23+
"fmt"
24+
)
25+
26+
// createTmpFs is the stub used on non-linux platforms. It performs no work and
// always reports that the operation is not supported.
func createTmpFs(target string, size int) error {
	err := fmt.Errorf("not supported")
	return err
}
29+
30+
// bindMountReadonly is the stub used on non-linux platforms. It performs no
// work and always reports that the operation is not supported.
func bindMountReadonly(source string, target string) error {
	err := fmt.Errorf("not supported")
	return err
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
/**
2+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
**/
16+
17+
package nvidiaparams
18+
19+
import (
20+
"bufio"
21+
"bytes"
22+
"errors"
23+
"fmt"
24+
"io"
25+
"os"
26+
"path/filepath"
27+
"strings"
28+
29+
"github.com/urfave/cli/v2"
30+
31+
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
32+
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
33+
)
34+
35+
// nvidiaDriverParamsPath is the path of the NVIDIA driver params file. It is
// read on the host and, joined onto the container root, is also the location
// that the modified copy is mounted over.
const nvidiaDriverParamsPath = "/proc/driver/nvidia/params"
38+
39+
type command struct {
40+
logger logger.Interface
41+
}
42+
43+
// options holds the values parsed from the command line flags of the
// update-nvidia-params subcommand.
type options struct {
	// containerSpec is the value of the --container-spec flag and is passed
	// on when the container state is loaded.
	containerSpec string
}
46+
47+
// NewCommand constructs an update-nvidia-params command with the specified logger
48+
func NewCommand(logger logger.Interface) *cli.Command {
49+
c := command{
50+
logger: logger,
51+
}
52+
return c.build()
53+
}
54+
55+
// build the update-nvidia-params command
56+
func (m command) build() *cli.Command {
57+
cfg := options{}
58+
59+
// Create the 'update-nvidia-params' command
60+
c := cli.Command{
61+
Name: "update-nvidia-params",
62+
Usage: "Update the /proc/driver/nvidia/params file in the container to disable device node modification.",
63+
Before: func(c *cli.Context) error {
64+
return m.validateFlags(c, &cfg)
65+
},
66+
Action: func(c *cli.Context) error {
67+
return m.run(c, &cfg)
68+
},
69+
}
70+
71+
c.Flags = []cli.Flag{
72+
&cli.StringFlag{
73+
Name: "container-spec",
74+
Hidden: true,
75+
Usage: "Specify the path to the OCI container spec. If empty or '-' the spec will be read from STDIN",
76+
Destination: &cfg.containerSpec,
77+
},
78+
}
79+
80+
return &c
81+
}
82+
83+
func (m command) validateFlags(c *cli.Context, cfg *options) error {
84+
return nil
85+
}
86+
87+
func (m command) run(c *cli.Context, cfg *options) error {
88+
// TODO: Do we need to prefix the driver root?
89+
hostNvidiaParamsFile, err := os.Open(nvidiaDriverParamsPath)
90+
if errors.Is(err, os.ErrNotExist) {
91+
return nil
92+
}
93+
if err != nil {
94+
return fmt.Errorf("failed to load params file: %w", err)
95+
}
96+
defer hostNvidiaParamsFile.Close()
97+
98+
s, err := oci.LoadContainerState(cfg.containerSpec)
99+
if err != nil {
100+
return fmt.Errorf("failed to load container state: %v", err)
101+
}
102+
103+
containerRoot, err := s.GetContainerRoot()
104+
if err != nil {
105+
return fmt.Errorf("failed to determined container root: %v", err)
106+
}
107+
108+
return m.updateNvidiaParamsFromReader(hostNvidiaParamsFile, containerRoot)
109+
}
110+
111+
func (m command) updateNvidiaParamsFromReader(r io.Reader, containerRoot string) error {
112+
modifiedContents, err := m.getModifiedParamsFileContentsFromReader(r)
113+
if err != nil {
114+
return fmt.Errorf("failed to generate modified contents: %w", err)
115+
}
116+
if len(modifiedContents) == 0 {
117+
m.logger.Debugf("No modification required")
118+
return nil
119+
}
120+
return createParamsFileInContainer(containerRoot, modifiedContents)
121+
}
122+
123+
// getModifiedParamsFileContentsFromReader returns the contents of a modified params file from the specified reader.
124+
func (m command) getModifiedParamsFileContentsFromReader(r io.Reader) ([]byte, error) {
125+
var modified bytes.Buffer
126+
scanner := bufio.NewScanner(r)
127+
128+
var requiresModification bool
129+
for scanner.Scan() {
130+
line := scanner.Text()
131+
if strings.HasPrefix(line, "ModifyDeviceFiles: ") {
132+
if line == "ModifyDeviceFiles: 0" {
133+
m.logger.Debugf("Device node modification is already disabled")
134+
return nil, nil
135+
}
136+
if line == "ModifyDeviceFiles: 1" {
137+
line = "ModifyDeviceFiles: 0"
138+
requiresModification = true
139+
}
140+
}
141+
if _, err := modified.WriteString(line + "\n"); err != nil {
142+
return nil, fmt.Errorf("failed to create output buffer: %w", err)
143+
}
144+
}
145+
if err := scanner.Err(); err != nil {
146+
return nil, fmt.Errorf("failed to read params file: %w", err)
147+
}
148+
149+
if !requiresModification {
150+
return nil, nil
151+
}
152+
153+
return modified.Bytes(), nil
154+
}
155+
156+
func createParamsFileInContainer(containerRoot string, contents []byte) error {
157+
if len(contents) == 0 {
158+
return nil
159+
}
160+
161+
tempParamsFileName, err := createFileInTempfs("nvct-params", contents, 0o444)
162+
if err != nil {
163+
return fmt.Errorf("failed to create temporary file: %w", err)
164+
}
165+
166+
if err := bindMountReadonly(tempParamsFileName, filepath.Join(containerRoot, nvidiaDriverParamsPath)); err != nil {
167+
return fmt.Errorf("failed to create temporary parms file mount: %w", err)
168+
}
169+
170+
return nil
171+
}
172+
173+
// createFileInTempfs creates a file with the specified name, contents, and mode in a tmpfs.
174+
// A tmpfs is created at /tmp/nvct-emtpy-dir* with a size sufficient for the specified contents.
175+
func createFileInTempfs(name string, contents []byte, mode os.FileMode) (string, error) {
176+
tmpRoot, err := os.MkdirTemp("", "nvct-empty-dir*")
177+
if err != nil {
178+
return "", fmt.Errorf("failed to create temporary folder: %w", err)
179+
}
180+
if err := createTmpFs(tmpRoot, len(contents)); err != nil {
181+
return "", fmt.Errorf("failed to create tmpfs mount for params file: %w", err)
182+
}
183+
184+
filename := filepath.Join(tmpRoot, name)
185+
fileInTempfs, err := os.Create(filename)
186+
if err != nil {
187+
return "", fmt.Errorf("failed to create temporary params file: %w", err)
188+
}
189+
defer fileInTempfs.Close()
190+
191+
if _, err := fileInTempfs.Write(contents); err != nil {
192+
return "", fmt.Errorf("failed to write temporary params file: %w", err)
193+
}
194+
195+
if err := fileInTempfs.Chmod(mode); err != nil {
196+
return "", fmt.Errorf("failed to set permissions on temporary params file: %w", err)
197+
}
198+
return filename, nil
199+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package nvidiaparams
2+
3+
import (
4+
"bytes"
5+
"testing"
6+
7+
testlog "github.com/sirupsen/logrus/hooks/test"
8+
"github.com/stretchr/testify/require"
9+
)
10+
11+
func TestGetModifiedParamsFileContentsFromReader(t *testing.T) {
12+
logger, _ := testlog.NewNullLogger()
13+
testCases := map[string]struct {
14+
contents []byte
15+
expectedError error
16+
expectedContents []byte
17+
}{
18+
"no contents": {
19+
contents: nil,
20+
expectedError: nil,
21+
expectedContents: nil,
22+
},
23+
"other contents are ignored": {
24+
contents: []byte(`# Some other content
25+
that we don't care about
26+
`),
27+
expectedError: nil,
28+
expectedContents: nil,
29+
},
30+
"already zero requires no modification": {
31+
contents: []byte("ModifyDeviceFiles: 0"),
32+
expectedError: nil,
33+
expectedContents: nil,
34+
},
35+
"leading spaces require no modification": {
36+
contents: []byte(" ModifyDeviceFiles: 1"),
37+
},
38+
"Trailing spaces require no modification": {
39+
contents: []byte("ModifyDeviceFiles: 1 "),
40+
},
41+
"Not 1 require no modification": {
42+
contents: []byte("ModifyDeviceFiles: 11"),
43+
},
44+
"single line requires modification": {
45+
contents: []byte("ModifyDeviceFiles: 1"),
46+
expectedError: nil,
47+
expectedContents: []byte("ModifyDeviceFiles: 0\n"),
48+
},
49+
"single line with trailing newline requires modification": {
50+
contents: []byte("ModifyDeviceFiles: 1\n"),
51+
expectedError: nil,
52+
expectedContents: []byte("ModifyDeviceFiles: 0\n"),
53+
},
54+
"other content is maintained": {
55+
contents: []byte(`ModifyDeviceFiles: 1
56+
other content
57+
that
58+
is maintained`),
59+
expectedError: nil,
60+
expectedContents: []byte(`ModifyDeviceFiles: 0
61+
other content
62+
that
63+
is maintained
64+
`),
65+
},
66+
}
67+
68+
for description, tc := range testCases {
69+
t.Run(description, func(t *testing.T) {
70+
c := command{
71+
logger: logger,
72+
}
73+
contents, err := c.getModifiedParamsFileContentsFromReader(bytes.NewReader(tc.contents))
74+
require.EqualValues(t, tc.expectedError, err)
75+
require.EqualValues(t, string(tc.expectedContents), string(contents))
76+
})
77+
}
78+
79+
}

Diff for: cmd/nvidia-ctk-installer/container/toolkit/toolkit_test.go

+5
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,11 @@ containerEdits:
102102
- update-ldcache
103103
- --folder
104104
- /lib/x86_64-linux-gnu
105+
- hookName: createContainer
106+
path: {{ .toolkitRoot }}/nvidia-cdi-hook
107+
args:
108+
- nvidia-cdi-hook
109+
- update-nvidia-params
105110
mounts:
106111
- hostPath: /host/driver/root/lib/x86_64-linux-gnu/libcuda.so.999.88.77
107112
containerPath: /lib/x86_64-linux-gnu/libcuda.so.999.88.77

0 commit comments

Comments
 (0)