diff --git a/net.go b/net.go index 10e30a0..1de02ed 100644 --- a/net.go +++ b/net.go @@ -6,6 +6,7 @@ package transport import ( + "context" "errors" "io" "net" @@ -208,6 +209,7 @@ type Net interface { // The following functions are extensions to Go's standard net package CreateDialer(dialer *net.Dialer) Dialer + CreateListenConfig(listenerConfig *net.ListenConfig) ListenConfig } // Dialer is identical to net.Dialer excepts that its methods @@ -217,6 +219,14 @@ type Dialer interface { Dial(network, address string) (net.Conn, error) } +// ListenConfig is identical to net.ListenConfig except that its methods +// (Listen, ListenPacket) are overridden to use the Net interface. +// Use vnet.Create:ListenConfig() to create an instance of this ListenConfig. +type ListenConfig interface { + Listen(ctx context.Context, network, address string) (net.Listener, error) + ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) +} + // UDPConn is packet-oriented connection for UDP. type UDPConn interface { // Close closes the connection. diff --git a/stdnet/net.go b/stdnet/net.go index e0b923c..1a9c23f 100644 --- a/stdnet/net.go +++ b/stdnet/net.go @@ -6,6 +6,7 @@ package stdnet import ( + "context" "fmt" "net" @@ -166,3 +167,20 @@ func (d stdDialer) Dial(network, address string) (net.Conn, error) { func (n *Net) CreateDialer(d *net.Dialer) transport.Dialer { return stdDialer{d} } + +type stdListenConfig struct { + *net.ListenConfig +} + +func (d stdListenConfig) Listen(ctx context.Context, network, address string) (net.Listener, error) { + return d.ListenConfig.Listen(ctx, network, address) +} + +func (d stdListenConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) { + return d.ListenConfig.ListenPacket(ctx, network, address) +} + +// CreateListenConfig creates an instance of vnet.ListenConfig. +func (n *Net) CreateListenConfig(d *net.ListenConfig) transport.ListenConfig { + return stdListenConfig{d} +} diff --git a/stdnet/net_test.go b/stdnet/net_test.go index db1c69f..e0c9fb3 100644 --- a/stdnet/net_test.go +++ b/stdnet/net_test.go @@ -7,6 +7,7 @@ package stdnet import ( + "context" "net" "testing" @@ -14,7 +15,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestStdNet(t *testing.T) { //nolint:cyclop +func TestStdNet(t *testing.T) { //nolint:cyclop,maintidx log := logging.NewDefaultLoggerFactory().NewLogger("test") t.Run("Interfaces", func(t *testing.T) { @@ -194,6 +195,68 @@ func TestStdNet(t *testing.T) { //nolint:cyclop assert.NoError(t, conn.Close(), "should succeed") }) + t.Run("Listen", func(t *testing.T) { + nw, err := NewNet() + assert.Nil(t, err, "should succeed") + + listenConfig := nw.CreateListenConfig(&net.ListenConfig{}) + listener, err := listenConfig.Listen(context.Background(), "tcp4", "127.0.0.1:1234") + assert.NoError(t, err, "should succeed") + + laddr := listener.Addr() + log.Debugf("laddr: %s", laddr.String()) + + dialer := nw.CreateDialer(&net.Dialer{ + LocalAddr: &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 0, + }, + }) + + conn, err := dialer.Dial("tcp4", "127.0.0.1:1234") + assert.NoError(t, err, "should succeed") + + raddr := conn.RemoteAddr() + log.Debugf("raddr: %s", raddr.String()) + + assert.Equal(t, "127.0.0.1", laddr.(*net.TCPAddr).IP.String(), "should match") //nolint:forcetypeassert + assert.True(t, laddr.(*net.TCPAddr).Port != 0, "should match") //nolint:forcetypeassert + assert.Equal(t, "127.0.0.1:1234", raddr.String(), "should match") + assert.NoError(t, conn.Close(), "should succeed") + assert.NoError(t, listener.Close(), "should succeed") + }) + + t.Run("ListenPacket", func(t *testing.T) { + nw, err := NewNet() + assert.Nil(t, err, "should succeed") + + listenConfig := nw.CreateListenConfig(&net.ListenConfig{}) + packetListener, err := listenConfig.ListenPacket(context.Background(), udpString, "127.0.0.1:1234") + assert.NoError(t, err, "should succeed") + + laddr := packetListener.LocalAddr() + log.Debugf("laddr: %s", laddr.String()) + + dialer := nw.CreateDialer(&net.Dialer{ + LocalAddr: &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 0, + }, + }) + + packetConn, err := dialer.Dial(udpString, "127.0.0.1:1234") + assert.NoError(t, err, "should succeed") + + raddr := packetConn.RemoteAddr() + log.Debugf("raddr: %s", raddr.String()) + + assert.Equal(t, "127.0.0.1", laddr.(*net.UDPAddr).IP.String(), "should match") //nolint:forcetypeassert + assert.True(t, laddr.(*net.UDPAddr).Port != 0, "should match") //nolint:forcetypeassert + assert.Equal(t, "127.0.0.1:1234", raddr.String(), "should match") + assert.NoError(t, packetConn.Close(), "should succeed") + assert.NoError(t, packetListener.Close(), "should succeed") + }) + t.Run("Unexpected operations", func(t *testing.T) { // For portability of test, find a name of loopback interface name first var loName string diff --git a/vnet/net.go b/vnet/net.go index 8fb5134..ced35f2 100644 --- a/vnet/net.go +++ b/vnet/net.go @@ -4,6 +4,7 @@ package vnet import ( + "context" "encoding/binary" "errors" "fmt" @@ -665,3 +666,24 @@ type dialer struct { func (d *dialer) Dial(network, address string) (net.Conn, error) { return d.net.Dial(network, address) } + +// CreateListenConfig creates an instance of vnet.ListenConfig. +func (v *Net) CreateListenConfig(l *net.ListenConfig) transport.ListenConfig { + return &listenConfig{ + listenConfig: l, + net: v, + } +} + +type listenConfig struct { + listenConfig *net.ListenConfig + net *Net +} + +func (l *listenConfig) Listen(ctx context.Context, network, address string) (net.Listener, error) { + return l.listenConfig.Listen(ctx, network, address) +} + +func (l *listenConfig) ListenPacket(_ context.Context, network, address string) (net.PacketConn, error) { + return l.net.ListenPacket(network, address) +} diff --git a/vnet/net_test.go b/vnet/net_test.go index 79c3104..af22ac8 100644 --- a/vnet/net_test.go +++ b/vnet/net_test.go @@ -4,6 +4,7 @@ package vnet import ( + "context" "fmt" "net" "testing" @@ -710,6 +711,61 @@ func TestNetVirtual(t *testing.T) { //nolint:gocyclo,cyclop,maintidx assert.Empty(t, nw.udpConns.size(), "should match") }) + t.Run("Listen", func(t *testing.T) { + nw, err := NewNet(&NetConfig{}) + assert.Nil(t, err, "should succeed") + + listenConfig := nw.CreateListenConfig(&net.ListenConfig{}) + listener, err := listenConfig.Listen(context.Background(), "tcp4", "127.0.0.1:1234") + assert.NoError(t, err, "should succeed") + + laddr := listener.Addr() + log.Debugf("laddr: %s", laddr.String()) + + conn, err := net.Dial("tcp4", "127.0.0.1:1234") //nolint:noctx + assert.NoError(t, err, "should succeed") + + raddr := conn.RemoteAddr() + log.Debugf("raddr: %s", raddr.String()) + + assert.Equal(t, "127.0.0.1", laddr.(*net.TCPAddr).IP.String(), "should match") //nolint:forcetypeassert + assert.True(t, laddr.(*net.TCPAddr).Port != 0, "should match") //nolint:forcetypeassert + assert.Equal(t, "127.0.0.1:1234", raddr.String(), "should match") + assert.NoError(t, conn.Close(), "should succeed") + assert.NoError(t, listener.Close(), "should succeed") + }) + + t.Run("ListenPacket", func(t *testing.T) { + nw, err := NewNet(&NetConfig{}) + assert.Nil(t, err, "should succeed") + + listenConfig := nw.CreateListenConfig(&net.ListenConfig{}) + packetListener, err := listenConfig.ListenPacket(context.Background(), udp, "127.0.0.1:1234") + assert.NoError(t, err, "should succeed") + + laddr := packetListener.LocalAddr() + log.Debugf("laddr: %s", laddr.String()) + + dialer := nw.CreateDialer(&net.Dialer{ + LocalAddr: &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 0, + }, + }) + + packetConn, err := dialer.Dial(udp, "127.0.0.1:1234") + assert.NoError(t, err, "should succeed") + + raddr := packetConn.RemoteAddr() + log.Debugf("raddr: %s", raddr.String()) + + assert.Equal(t, "127.0.0.1", laddr.(*net.UDPAddr).IP.String(), "should match") //nolint:forcetypeassert + assert.True(t, laddr.(*net.UDPAddr).Port != 0, "should match") //nolint:forcetypeassert + assert.Equal(t, "127.0.0.1:1234", raddr.String(), "should match") + assert.NoError(t, packetConn.Close(), "should succeed") + assert.NoError(t, packetListener.Close(), "should succeed") + }) + t.Run("Two IPs on a NIC", func(t *testing.T) { doneCh := make(chan struct{})