diff --git a/libcontainer/nsenter/nsenter_test.go b/libcontainer/nsenter/nsenter_test.go index 9201f02641d..0cbf0aae61b 100644 --- a/libcontainer/nsenter/nsenter_test.go +++ b/libcontainer/nsenter/nsenter_test.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "os" "os/exec" "strings" @@ -17,21 +16,9 @@ import ( "golang.org/x/sys/unix" ) -type pid struct { - Pid int `json:"Pid"` -} - -type logentry struct { - Msg string `json:"msg"` - Level string `json:"level"` -} - func TestNsenterValidPaths(t *testing.T) { args := []string{"nsenter-exec"} - parent, child, err := newPipe() - if err != nil { - t.Fatalf("failed to create pipe %v", err) - } + parent, child := newPipe(t) namespaces := []string{ // join pid ns of the current process @@ -47,8 +34,9 @@ func TestNsenterValidPaths(t *testing.T) { } if err := cmd.Start(); err != nil { - t.Fatalf("nsenter failed to start %v", err) + t.Fatalf("nsenter failed to start: %v", err) } + child.Close() // write cloneFlags r := nl.NewNetlinkRequest(int(libcontainer.InitMsg), 0) @@ -66,36 +54,18 @@ func TestNsenterValidPaths(t *testing.T) { initWaiter(t, parent) - decoder := json.NewDecoder(parent) - var pid *pid - if err := cmd.Wait(); err != nil { - t.Fatalf("nsenter exits with a non-zero exit status") - } - if err := decoder.Decode(&pid); err != nil { - dir, _ := ioutil.ReadDir(fmt.Sprintf("/proc/%d/ns", os.Getpid())) - for _, d := range dir { - t.Log(d.Name()) - } - t.Fatalf("%v", err) + t.Fatalf("nsenter error: %v", err) } - p, err := os.FindProcess(pid.Pid) - if err != nil { - t.Fatalf("%v", err) - } - _, _ = p.Wait() + reapChildren(t, parent) } func TestNsenterInvalidPaths(t *testing.T) { args := []string{"nsenter-exec"} - parent, child, err := newPipe() - if err != nil { - t.Fatalf("failed to create pipe %v", err) - } + parent, child := newPipe(t) namespaces := []string{ - // join pid ns of the current process fmt.Sprintf("pid:/proc/%d/ns/pid", -1), } cmd := &exec.Cmd{ @@ -106,8 +76,10 @@ func TestNsenterInvalidPaths(t *testing.T) { } if err := cmd.Start(); err != nil { - t.Fatal(err) + t.Fatalf("nsenter failed to start: %v", err) } + child.Close() + // write cloneFlags r := nl.NewNetlinkRequest(int(libcontainer.InitMsg), 0) r.AddData(&libcontainer.Int32msg{ @@ -130,13 +102,9 @@ func TestNsenterInvalidPaths(t *testing.T) { func TestNsenterIncorrectPathType(t *testing.T) { args := []string{"nsenter-exec"} - parent, child, err := newPipe() - if err != nil { - t.Fatalf("failed to create pipe %v", err) - } + parent, child := newPipe(t) namespaces := []string{ - // join pid ns of the current process fmt.Sprintf("net:/proc/%d/ns/pid", os.Getpid()), } cmd := &exec.Cmd{ @@ -147,8 +115,10 @@ func TestNsenterIncorrectPathType(t *testing.T) { } if err := cmd.Start(); err != nil { - t.Fatal(err) + t.Fatalf("nsenter failed to start: %v", err) } + child.Close() + // write cloneFlags r := nl.NewNetlinkRequest(int(libcontainer.InitMsg), 0) r.AddData(&libcontainer.Int32msg{ @@ -165,24 +135,14 @@ func TestNsenterIncorrectPathType(t *testing.T) { initWaiter(t, parent) if err := cmd.Wait(); err == nil { - t.Fatalf("nsenter exits with a zero exit status") + t.Fatalf("nsenter error: %v", err) } } func TestNsenterChildLogging(t *testing.T) { args := []string{"nsenter-exec"} - parent, child, err := newPipe() - if err != nil { - t.Fatalf("failed to create exec pipe %v", err) - } - logread, logwrite, err := os.Pipe() - if err != nil { - t.Fatalf("failed to create log pipe %v", err) - } - defer func() { - _ = logwrite.Close() - _ = logread.Close() - }() + parent, child := newPipe(t) + logread, logwrite := newPipe(t) namespaces := []string{ // join pid ns of the current process @@ -198,8 +158,11 @@ func TestNsenterChildLogging(t *testing.T) { } if err := cmd.Start(); err != nil { - t.Fatalf("nsenter failed to start %v", err) + t.Fatalf("nsenter failed to start: %v", err) } + child.Close() + logwrite.Close() + // write cloneFlags r := nl.NewNetlinkRequest(int(libcontainer.InitMsg), 0) r.AddData(&libcontainer.Int32msg{ @@ -216,20 +179,12 @@ func TestNsenterChildLogging(t *testing.T) { initWaiter(t, parent) - logsDecoder := json.NewDecoder(logread) - var logentry *logentry - - err = logsDecoder.Decode(&logentry) - if err != nil { - t.Fatalf("child log: %v", err) - } - if logentry.Level == "" || logentry.Msg == "" { - t.Fatalf("child log: empty log fields: level=\"%s\" msg=\"%s\"", logentry.Level, logentry.Msg) - } - + getLogs(t, logread) if err := cmd.Wait(); err != nil { - t.Fatalf("nsenter exits with a non-zero exit status") + t.Fatalf("nsenter error: %v", err) } + + reapChildren(t, parent) } func init() { @@ -238,12 +193,19 @@ func init() { } } -func newPipe() (parent *os.File, child *os.File, err error) { +func newPipe(t *testing.T) (parent *os.File, child *os.File) { + t.Helper() fds, err := unix.Socketpair(unix.AF_LOCAL, unix.SOCK_STREAM|unix.SOCK_CLOEXEC, 0) if err != nil { - return nil, nil, err + t.Fatal("socketpair failed:", err) } - return os.NewFile(uintptr(fds[1]), "parent"), os.NewFile(uintptr(fds[0]), "child"), nil + parent = os.NewFile(uintptr(fds[1]), "parent") + child = os.NewFile(uintptr(fds[0]), "child") + t.Cleanup(func() { + parent.Close() + child.Close() + }) + return } // initWaiter reads back the initial \0 from runc init @@ -261,3 +223,47 @@ func initWaiter(t *testing.T, r io.Reader) { } t.Fatalf("waiting for init preliminary setup: %v", err) } + +func reapChildren(t *testing.T, parent *os.File) { + t.Helper() + decoder := json.NewDecoder(parent) + decoder.DisallowUnknownFields() + var pid struct { + Pid2 int `json:"stage2_pid"` + Pid1 int `json:"stage1_pid"` + } + if err := decoder.Decode(&pid); err != nil { + t.Fatal(err) + } + + // Reap children. + _, _ = unix.Wait4(pid.Pid1, nil, 0, nil) + _, _ = unix.Wait4(pid.Pid2, nil, 0, nil) + + // Sanity check. + if pid.Pid1 == 0 || pid.Pid2 == 0 { + t.Fatal("got pids:", pid) + } +} + +func getLogs(t *testing.T, logread *os.File) { + logsDecoder := json.NewDecoder(logread) + logsDecoder.DisallowUnknownFields() + var logentry struct { + Level string `json:"level"` + Msg string `json:"msg"` + } + + for { + if err := logsDecoder.Decode(&logentry); err != nil { + if errors.Is(err, io.EOF) { + return + } + t.Fatal("init log decoding error:", err) + } + t.Logf("logentry: %+v", logentry) + if logentry.Level == "" || logentry.Msg == "" { + t.Fatalf("init log: empty log entry: %+v", logentry) + } + } +}