Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disable device node creation in CDI mode #927

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cmd/nvidia-cdi-hook/commands/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/chmod"
symlinks "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/create-symlinks"
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/cudacompat"
disabledevicenodemodification "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/disable-device-node-modification"
ldcache "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/update-ldcache"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
)
Expand All @@ -34,5 +35,6 @@ func New(logger logger.Interface) []*cli.Command {
symlinks.NewCommand(logger),
chmod.NewCommand(logger),
cudacompat.NewCommand(logger),
disabledevicenodemodification.NewCommand(logger),
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
/**
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/

package disabledevicenodemodification

import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"strings"

"github.com/urfave/cli/v2"

"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
)

const (
nvidiaDriverParamsPath = "/proc/driver/nvidia/params"
)

type command struct {
logger logger.Interface
}

type options struct {
containerSpec string
}

// NewCommand constructs an disable-device-node-modification subcommand with the specified logger
func NewCommand(logger logger.Interface) *cli.Command {
c := command{
logger: logger,
}
return c.build()
}

func (m command) build() *cli.Command {
cfg := options{}

c := cli.Command{
Name: "disable-device-node-modification",
Usage: "Ensure that the /proc/driver/nvidia/params file present in the container does not allow device node modifications.",
Before: func(c *cli.Context) error {
return m.validateFlags(c, &cfg)
},
Action: func(c *cli.Context) error {
return m.run(c, &cfg)
},
}

c.Flags = []cli.Flag{
&cli.StringFlag{
Name: "container-spec",
Hidden: true,
Usage: "Specify the path to the OCI container spec. If empty or '-' the spec will be read from STDIN",
Destination: &cfg.containerSpec,
},
}

return &c
}

func (m command) validateFlags(c *cli.Context, cfg *options) error {
return nil
}

func (m command) run(c *cli.Context, cfg *options) error {
// TODO: Do we need to prefix the driver root?
hostNvidiaParamsFile, err := os.Open(nvidiaDriverParamsPath)
if errors.Is(err, os.ErrNotExist) {
return nil
}
if err != nil {
return fmt.Errorf("failed to load params file: %w", err)
}
defer hostNvidiaParamsFile.Close()

s, err := oci.LoadContainerState(cfg.containerSpec)
if err != nil {
return fmt.Errorf("failed to load container state: %w", err)
}

containerRoot, err := s.GetContainerRoot()
if err != nil {
return fmt.Errorf("failed to determined container root: %w", err)
}

return m.updateNvidiaParamsFromReader(hostNvidiaParamsFile, containerRoot)
}

func (m command) updateNvidiaParamsFromReader(r io.Reader, containerRoot string) error {
modifiedContents, err := m.getModifiedParamsFileContentsFromReader(r)
if err != nil {
return fmt.Errorf("failed to generate modified contents: %w", err)
}
if len(modifiedContents) == 0 {
m.logger.Debugf("No modification required")
return nil
}
return createParamsFileInContainer(containerRoot, modifiedContents)
}

// getModifiedParamsFileContentsFromReader returns the contents of a modified params file from the specified reader.
func (m command) getModifiedParamsFileContentsFromReader(r io.Reader) ([]byte, error) {
var modified bytes.Buffer
scanner := bufio.NewScanner(r)

var requiresModification bool
for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, "ModifyDeviceFiles: ") {
if line == "ModifyDeviceFiles: 0" {
m.logger.Debugf("Device node modification is already disabled")
return nil, nil
}
if line == "ModifyDeviceFiles: 1" {
line = "ModifyDeviceFiles: 0"
requiresModification = true
}
}
if _, err := modified.WriteString(line + "\n"); err != nil {
return nil, fmt.Errorf("failed to create output buffer: %w", err)
}
}
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("failed to read params file: %w", err)
}

if !requiresModification {
return nil, nil
}

return modified.Bytes(), nil
}

func createParamsFileInContainer(containerRoot string, contents []byte) error {
if len(contents) == 0 {
return nil
}

tempParamsFileName, err := createFileInTempfs("nvct-params", contents, 0o444)
if err != nil {
return fmt.Errorf("failed to create temporary file: %w", err)
}

if err := bindMountReadonly(tempParamsFileName, filepath.Join(containerRoot, nvidiaDriverParamsPath)); err != nil {
return fmt.Errorf("failed to create temporary params file mount: %w", err)
}

return nil
}

// createFileInTempfs creates a file with the specified name, contents, and mode in a tmpfs.
// A tmpfs is created at /tmp/nvct-emtpy-dir* with a size sufficient for the specified contents.
func createFileInTempfs(name string, contents []byte, mode os.FileMode) (string, error) {
tmpRoot, err := os.MkdirTemp("", "nvct-empty-dir*")
if err != nil {
return "", fmt.Errorf("failed to create temporary folder: %w", err)
}
if err := createTmpFs(tmpRoot, len(contents)); err != nil {
return "", fmt.Errorf("failed to create tmpfs mount for params file: %w", err)
}

filename := filepath.Join(tmpRoot, name)
fileInTempfs, err := os.Create(filename)
if err != nil {
return "", fmt.Errorf("failed to create temporary params file: %w", err)
}
defer fileInTempfs.Close()

if _, err := fileInTempfs.Write(contents); err != nil {
return "", fmt.Errorf("failed to write temporary params file: %w", err)
}

if err := fileInTempfs.Chmod(mode); err != nil {
return "", fmt.Errorf("failed to set permissions on temporary params file: %w", err)
}
return filename, nil
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/**
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/

package disabledevicenodemodification

import (
"bytes"
"testing"

testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
)

func TestGetModifiedParamsFileContentsFromReader(t *testing.T) {
logger, _ := testlog.NewNullLogger()
testCases := map[string]struct {
contents []byte
expectedError error
expectedContents []byte
}{
"no contents": {
contents: nil,
expectedError: nil,
expectedContents: nil,
},
"other contents are ignored": {
contents: []byte(`# Some other content
that we don't care about
`),
expectedError: nil,
expectedContents: nil,
},
"already zero requires no modification": {
contents: []byte("ModifyDeviceFiles: 0"),
expectedError: nil,
expectedContents: nil,
},
"leading spaces require no modification": {
contents: []byte(" ModifyDeviceFiles: 1"),
},
"Trailing spaces require no modification": {
contents: []byte("ModifyDeviceFiles: 1 "),
},
"Not 1 require no modification": {
contents: []byte("ModifyDeviceFiles: 11"),
},
"single line requires modification": {
contents: []byte("ModifyDeviceFiles: 1"),
expectedError: nil,
expectedContents: []byte("ModifyDeviceFiles: 0\n"),
},
"single line with trailing newline requires modification": {
contents: []byte("ModifyDeviceFiles: 1\n"),
expectedError: nil,
expectedContents: []byte("ModifyDeviceFiles: 0\n"),
},
"other content is maintained": {
contents: []byte(`ModifyDeviceFiles: 1
other content
that
is maintained`),
expectedError: nil,
expectedContents: []byte(`ModifyDeviceFiles: 0
other content
that
is maintained
`),
},
}

for description, tc := range testCases {
t.Run(description, func(t *testing.T) {
c := command{
logger: logger,
}
contents, err := c.getModifiedParamsFileContentsFromReader(bytes.NewReader(tc.contents))
require.EqualValues(t, tc.expectedError, err)
require.EqualValues(t, string(tc.expectedContents), string(contents))
})
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
//go:build linux
// +build linux

/**
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/

package disabledevicenodemodification

import (
"fmt"

"golang.org/x/sys/unix"
)

func createTmpFs(target string, size int) error {
return unix.Mount("tmpfs", target, "tmpfs", 0, fmt.Sprintf("size=%d", size))
}

func bindMountReadonly(source string, target string) error {
return unix.Mount(source, target, "", unix.MS_BIND|unix.MS_RDONLY|unix.MS_NOSYMFOLLOW, "")
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
//go:build !linux
// +build !linux

/**
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/

package disabledevicenodemodification

import (
"fmt"
)

func createTmpFs(target string, size int) error {
return fmt.Errorf("not supported")
}

func bindMountReadonly(source string, target string) error {
return fmt.Errorf("not supported")
}
5 changes: 5 additions & 0 deletions cmd/nvidia-ctk/cdi/generate/generate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,11 @@ containerEdits:
- update-ldcache
- --folder
- /lib/x86_64-linux-gnu
- hookName: createContainer
path: /usr/bin/nvidia-cdi-hook
args:
- nvidia-cdi-hook
- disable-device-node-modification
mounts:
- hostPath: {{ .driverRoot }}/lib/x86_64-linux-gnu/libcuda.so.999.88.77
containerPath: /lib/x86_64-linux-gnu/libcuda.so.999.88.77
Expand Down
Loading