Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 133 additions & 0 deletions cli-plugins/socket/socket_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package socket

import (
"io/fs"
"net"
"os"
"runtime"
"strings"
"testing"
"time"

"gotest.tools/v3/assert"
"gotest.tools/v3/poll"
)

func TestSetupConn(t *testing.T) {
t.Run("updates conn when connected", func(t *testing.T) {
var conn *net.UnixConn
listener, err := SetupConn(&conn)
assert.NilError(t, err)
assert.Check(t, listener != nil, "returned nil listener but no error")
addr, err := net.ResolveUnixAddr("unix", listener.Addr().String())
assert.NilError(t, err, "failed to resolve listener address")

_, err = net.DialUnix("unix", nil, addr)
assert.NilError(t, err, "failed to dial returned listener")

pollConnNotNil(t, &conn)
})

t.Run("allows reconnects", func(t *testing.T) {
var conn *net.UnixConn
listener, err := SetupConn(&conn)
assert.NilError(t, err)
assert.Check(t, listener != nil, "returned nil listener but no error")
addr, err := net.ResolveUnixAddr("unix", listener.Addr().String())
assert.NilError(t, err, "failed to resolve listener address")

otherConn, err := net.DialUnix("unix", nil, addr)
assert.NilError(t, err, "failed to dial returned listener")

otherConn.Close()

_, err = net.DialUnix("unix", nil, addr)
assert.NilError(t, err, "failed to redial listener")
})

t.Run("does not leak sockets to local directory", func(t *testing.T) {
var conn *net.UnixConn
listener, err := SetupConn(&conn)
assert.NilError(t, err)
assert.Check(t, listener != nil, "returned nil listener but no error")
checkDirNoPluginSocket(t)

addr, err := net.ResolveUnixAddr("unix", listener.Addr().String())
assert.NilError(t, err, "failed to resolve listener address")
_, err = net.DialUnix("unix", nil, addr)
assert.NilError(t, err, "failed to dial returned listener")
checkDirNoPluginSocket(t)
})
}

func checkDirNoPluginSocket(t *testing.T) {
t.Helper()

files, err := os.ReadDir(".")
assert.NilError(t, err, "failed to list files in dir to check for leaked sockets")

for _, f := range files {
info, err := f.Info()
assert.NilError(t, err, "failed to check file info")
// check for a socket with `docker_cli_` in the name (from `SetupConn()`)
if strings.Contains(f.Name(), "docker_cli_") && info.Mode().Type() == fs.ModeSocket {
t.Fatal("found socket in a local directory")
}
}
}

func TestConnectAndWait(t *testing.T) {
t.Run("calls cancel func on EOF", func(t *testing.T) {
var conn *net.UnixConn
listener, err := SetupConn(&conn)
assert.NilError(t, err, "failed to setup listener")

done := make(chan struct{})
t.Setenv(EnvKey, listener.Addr().String())
cancelFunc := func() {
done <- struct{}{}
}
ConnectAndWait(cancelFunc)
pollConnNotNil(t, &conn)
conn.Close()

select {
case <-done:
case <-time.After(10 * time.Millisecond):
t.Fatal("cancel function not closed after 10ms")
}
})

// TODO: this test cannot be executed with `t.Parallel()`, due to
// relying on goroutine numbers to ensure correct behaviour
t.Run("connect goroutine exits after EOF", func(t *testing.T) {
var conn *net.UnixConn
listener, err := SetupConn(&conn)
assert.NilError(t, err, "failed to setup listener")
t.Setenv(EnvKey, listener.Addr().String())
numGoroutines := runtime.NumGoroutine()

ConnectAndWait(func() {})
assert.Equal(t, runtime.NumGoroutine(), numGoroutines+1)

pollConnNotNil(t, &conn)
conn.Close()
poll.WaitOn(t, func(t poll.LogT) poll.Result {
if runtime.NumGoroutine() > numGoroutines+1 {
return poll.Continue("waiting for connect goroutine to exit")
}
return poll.Success()
}, poll.WithDelay(1*time.Millisecond), poll.WithTimeout(10*time.Millisecond))
})
}

func pollConnNotNil(t *testing.T, conn **net.UnixConn) {
t.Helper()

poll.WaitOn(t, func(t poll.LogT) poll.Result {
if *conn == nil {
return poll.Continue("waiting for conn to not be nil")
}
return poll.Success()
}, poll.WithDelay(1*time.Millisecond), poll.WithTimeout(10*time.Millisecond))
}
123 changes: 123 additions & 0 deletions e2e/cli-plugins/plugins/presocket/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package main

import (
"fmt"
"os"
"os/signal"
"syscall"
"time"

"github.com/docker/cli/cli-plugins/manager"
"github.com/docker/cli/cli-plugins/plugin"
"github.com/docker/cli/cli/command"
"github.com/spf13/cobra"
)

func main() {
plugin.Run(RootCmd, manager.Metadata{
SchemaVersion: "0.1.0",
Vendor: "Docker Inc.",
Version: "test",
})
}

func RootCmd(dockerCli command.Cli) *cobra.Command {
cmd := cobra.Command{
Use: "presocket",
Short: "testing plugin that does not connect to the socket",
// override PersistentPreRunE so that the plugin default
// PersistentPreRunE doesn't run, simulating a plugin built
// with a pre-socket-communication version of the CLI
PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
return nil
},
}

cmd.AddCommand(&cobra.Command{
Use: "test-no-socket",
Short: "test command that runs until it receives a SIGINT",
RunE: func(cmd *cobra.Command, args []string) error {
go func() {
<-cmd.Context().Done()
fmt.Fprintln(dockerCli.Out(), "context cancelled")
os.Exit(2)
}()
signalCh := make(chan os.Signal, 10)
signal.Notify(signalCh, syscall.SIGINT, syscall.SIGTERM)
go func() {
for range signalCh {
fmt.Fprintln(dockerCli.Out(), "received SIGINT")
}
}()
<-time.After(3 * time.Second)
fmt.Fprintln(dockerCli.Err(), "exit after 3 seconds")
return nil
},
})

cmd.AddCommand(&cobra.Command{
Use: "test-socket",
Short: "test command that runs until it receives a SIGINT",
PreRunE: func(cmd *cobra.Command, args []string) error {
return plugin.PersistentPreRunE(cmd, args)
},
RunE: func(cmd *cobra.Command, args []string) error {
go func() {
<-cmd.Context().Done()
fmt.Fprintln(dockerCli.Out(), "context cancelled")
os.Exit(2)
}()
signalCh := make(chan os.Signal, 10)
signal.Notify(signalCh, syscall.SIGINT, syscall.SIGTERM)
go func() {
for range signalCh {
fmt.Fprintln(dockerCli.Out(), "received SIGINT")
}
}()
<-time.After(3 * time.Second)
fmt.Fprintln(dockerCli.Err(), "exit after 3 seconds")
return nil
},
})

cmd.AddCommand(&cobra.Command{
Use: "test-socket-ignore-context",
Short: "test command that runs until it receives a SIGINT",
PreRunE: func(cmd *cobra.Command, args []string) error {
return plugin.PersistentPreRunE(cmd, args)
},
RunE: func(cmd *cobra.Command, args []string) error {
signalCh := make(chan os.Signal, 10)
signal.Notify(signalCh, syscall.SIGINT, syscall.SIGTERM)
go func() {
for range signalCh {
fmt.Fprintln(dockerCli.Out(), "received SIGINT")
}
}()
<-time.After(3 * time.Second)
fmt.Fprintln(dockerCli.Err(), "exit after 3 seconds")
return nil
},
})

cmd.AddCommand(&cobra.Command{
Use: "tty",
Short: "test command that attempts to read from the TTY",
RunE: func(cmd *cobra.Command, args []string) error {
done := make(chan struct{})
go func() {
b := make([]byte, 1)
_, _ = dockerCli.In().Read(b)
done <- struct{}{}
}()
select {
case <-done:
case <-time.After(2 * time.Second):
fmt.Fprint(dockerCli.Err(), "timeout after 2 seconds")
}
return nil
},
})

return &cmd
}
Loading