Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 78 additions & 72 deletions libcontainer/nsenter/nsenter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"errors"
"fmt"
"io"
"io/ioutil"
"os"
"os/exec"
"strings"
Expand All @@ -17,21 +16,9 @@ import (
"golang.org/x/sys/unix"
)

type pid struct {
Pid int `json:"Pid"`
}

type logentry struct {
Msg string `json:"msg"`
Level string `json:"level"`
}

func TestNsenterValidPaths(t *testing.T) {
args := []string{"nsenter-exec"}
parent, child, err := newPipe()
if err != nil {
t.Fatalf("failed to create pipe %v", err)
}
parent, child := newPipe(t)

namespaces := []string{
// join pid ns of the current process
Expand All @@ -47,8 +34,9 @@ func TestNsenterValidPaths(t *testing.T) {
}

if err := cmd.Start(); err != nil {
t.Fatalf("nsenter failed to start %v", err)
t.Fatalf("nsenter failed to start: %v", err)
}
child.Close()

// write cloneFlags
r := nl.NewNetlinkRequest(int(libcontainer.InitMsg), 0)
Expand All @@ -66,36 +54,18 @@ func TestNsenterValidPaths(t *testing.T) {

initWaiter(t, parent)

decoder := json.NewDecoder(parent)
var pid *pid

if err := cmd.Wait(); err != nil {
t.Fatalf("nsenter exits with a non-zero exit status")
}
if err := decoder.Decode(&pid); err != nil {
dir, _ := ioutil.ReadDir(fmt.Sprintf("/proc/%d/ns", os.Getpid()))
for _, d := range dir {
t.Log(d.Name())
}
t.Fatalf("%v", err)
t.Fatalf("nsenter error: %v", err)
}

p, err := os.FindProcess(pid.Pid)
if err != nil {
t.Fatalf("%v", err)
}
_, _ = p.Wait()
reapChildren(t, parent)
}

func TestNsenterInvalidPaths(t *testing.T) {
args := []string{"nsenter-exec"}
parent, child, err := newPipe()
if err != nil {
t.Fatalf("failed to create pipe %v", err)
}
parent, child := newPipe(t)

namespaces := []string{
// join pid ns of the current process
fmt.Sprintf("pid:/proc/%d/ns/pid", -1),
}
cmd := &exec.Cmd{
Expand All @@ -106,8 +76,10 @@ func TestNsenterInvalidPaths(t *testing.T) {
}

if err := cmd.Start(); err != nil {
t.Fatal(err)
t.Fatalf("nsenter failed to start: %v", err)
}
child.Close()

// write cloneFlags
r := nl.NewNetlinkRequest(int(libcontainer.InitMsg), 0)
r.AddData(&libcontainer.Int32msg{
Expand All @@ -130,13 +102,9 @@ func TestNsenterInvalidPaths(t *testing.T) {

func TestNsenterIncorrectPathType(t *testing.T) {
args := []string{"nsenter-exec"}
parent, child, err := newPipe()
if err != nil {
t.Fatalf("failed to create pipe %v", err)
}
parent, child := newPipe(t)

namespaces := []string{
// join pid ns of the current process
fmt.Sprintf("net:/proc/%d/ns/pid", os.Getpid()),
}
cmd := &exec.Cmd{
Expand All @@ -147,8 +115,10 @@ func TestNsenterIncorrectPathType(t *testing.T) {
}

if err := cmd.Start(); err != nil {
t.Fatal(err)
t.Fatalf("nsenter failed to start: %v", err)
}
child.Close()

// write cloneFlags
r := nl.NewNetlinkRequest(int(libcontainer.InitMsg), 0)
r.AddData(&libcontainer.Int32msg{
Expand All @@ -165,24 +135,14 @@ func TestNsenterIncorrectPathType(t *testing.T) {

initWaiter(t, parent)
if err := cmd.Wait(); err == nil {
t.Fatalf("nsenter exits with a zero exit status")
t.Fatalf("nsenter error: %v", err)
}
}

func TestNsenterChildLogging(t *testing.T) {
args := []string{"nsenter-exec"}
parent, child, err := newPipe()
if err != nil {
t.Fatalf("failed to create exec pipe %v", err)
}
logread, logwrite, err := os.Pipe()
if err != nil {
t.Fatalf("failed to create log pipe %v", err)
}
defer func() {
_ = logwrite.Close()
_ = logread.Close()
}()
parent, child := newPipe(t)
logread, logwrite := newPipe(t)

namespaces := []string{
// join pid ns of the current process
Expand All @@ -198,8 +158,11 @@ func TestNsenterChildLogging(t *testing.T) {
}

if err := cmd.Start(); err != nil {
t.Fatalf("nsenter failed to start %v", err)
t.Fatalf("nsenter failed to start: %v", err)
}
child.Close()
logwrite.Close()

// write cloneFlags
r := nl.NewNetlinkRequest(int(libcontainer.InitMsg), 0)
r.AddData(&libcontainer.Int32msg{
Expand All @@ -216,20 +179,12 @@ func TestNsenterChildLogging(t *testing.T) {

initWaiter(t, parent)

logsDecoder := json.NewDecoder(logread)
var logentry *logentry

err = logsDecoder.Decode(&logentry)
if err != nil {
t.Fatalf("child log: %v", err)
}
if logentry.Level == "" || logentry.Msg == "" {
t.Fatalf("child log: empty log fields: level=\"%s\" msg=\"%s\"", logentry.Level, logentry.Msg)
}

getLogs(t, logread)
if err := cmd.Wait(); err != nil {
t.Fatalf("nsenter exits with a non-zero exit status")
t.Fatalf("nsenter error: %v", err)
}

reapChildren(t, parent)
}

func init() {
Expand All @@ -238,12 +193,19 @@ func init() {
}
}

func newPipe() (parent *os.File, child *os.File, err error) {
func newPipe(t *testing.T) (parent *os.File, child *os.File) {
t.Helper()
fds, err := unix.Socketpair(unix.AF_LOCAL, unix.SOCK_STREAM|unix.SOCK_CLOEXEC, 0)
if err != nil {
return nil, nil, err
t.Fatal("socketpair failed:", err)
}
return os.NewFile(uintptr(fds[1]), "parent"), os.NewFile(uintptr(fds[0]), "child"), nil
parent = os.NewFile(uintptr(fds[1]), "parent")
child = os.NewFile(uintptr(fds[0]), "child")
t.Cleanup(func() {
parent.Close()
child.Close()
})
return
}

// initWaiter reads back the initial \0 from runc init
Expand All @@ -261,3 +223,47 @@ func initWaiter(t *testing.T, r io.Reader) {
}
t.Fatalf("waiting for init preliminary setup: %v", err)
}

func reapChildren(t *testing.T, parent *os.File) {
t.Helper()
decoder := json.NewDecoder(parent)
decoder.DisallowUnknownFields()
var pid struct {
Pid2 int `json:"stage2_pid"`
Pid1 int `json:"stage1_pid"`
}
if err := decoder.Decode(&pid); err != nil {
t.Fatal(err)
}

// Reap children.
_, _ = unix.Wait4(pid.Pid1, nil, 0, nil)
_, _ = unix.Wait4(pid.Pid2, nil, 0, nil)

// Sanity check.
if pid.Pid1 == 0 || pid.Pid2 == 0 {
t.Fatal("got pids:", pid)
}
}

func getLogs(t *testing.T, logread *os.File) {
logsDecoder := json.NewDecoder(logread)
logsDecoder.DisallowUnknownFields()
var logentry struct {
Level string `json:"level"`
Msg string `json:"msg"`
}

for {
if err := logsDecoder.Decode(&logentry); err != nil {
if errors.Is(err, io.EOF) {
return
}
t.Fatal("init log decoding error:", err)
}
t.Logf("logentry: %+v", logentry)
if logentry.Level == "" || logentry.Msg == "" {
t.Fatalf("init log: empty log entry: %+v", logentry)
}
}
}