Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions net.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package transport

import (
"context"
"errors"
"io"
"net"
Expand Down Expand Up @@ -208,6 +209,7 @@ type Net interface {
// The following functions are extensions to Go's standard net package

CreateDialer(dialer *net.Dialer) Dialer
CreateListenConfig(listenerConfig *net.ListenConfig) ListenConfig
}

// Dialer is identical to net.Dialer excepts that its methods
Expand All @@ -217,6 +219,14 @@ type Dialer interface {
Dial(network, address string) (net.Conn, error)
}

// ListenConfig is identical to net.ListenConfig except that its methods
// (Listen, ListenPacket) are overridden to use the Net interface.
// Use vnet.Create:ListenConfig() to create an instance of this ListenConfig.
type ListenConfig interface {
Listen(ctx context.Context, network, address string) (net.Listener, error)
ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error)
}

// UDPConn is packet-oriented connection for UDP.
type UDPConn interface {
// Close closes the connection.
Expand Down
18 changes: 18 additions & 0 deletions stdnet/net.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package stdnet

import (
"context"
"fmt"
"net"

Expand Down Expand Up @@ -166,3 +167,20 @@ func (d stdDialer) Dial(network, address string) (net.Conn, error) {
func (n *Net) CreateDialer(d *net.Dialer) transport.Dialer {
return stdDialer{d}
}

type stdListenConfig struct {
*net.ListenConfig
}

func (d stdListenConfig) Listen(ctx context.Context, network, address string) (net.Listener, error) {
return d.ListenConfig.Listen(ctx, network, address)
}

func (d stdListenConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) {
return d.ListenConfig.ListenPacket(ctx, network, address)
}

// CreateListenConfig creates an instance of vnet.ListenConfig.
func (n *Net) CreateListenConfig(d *net.ListenConfig) transport.ListenConfig {
return stdListenConfig{d}
}
65 changes: 64 additions & 1 deletion stdnet/net_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
package stdnet

import (
"context"
"net"
"testing"

"github.com/pion/logging"
"github.com/stretchr/testify/assert"
)

func TestStdNet(t *testing.T) { //nolint:cyclop
func TestStdNet(t *testing.T) { //nolint:cyclop,maintidx
log := logging.NewDefaultLoggerFactory().NewLogger("test")

t.Run("Interfaces", func(t *testing.T) {
Expand Down Expand Up @@ -194,6 +195,68 @@ func TestStdNet(t *testing.T) { //nolint:cyclop
assert.NoError(t, conn.Close(), "should succeed")
})

t.Run("Listen", func(t *testing.T) {
nw, err := NewNet()
assert.Nil(t, err, "should succeed")

listenConfig := nw.CreateListenConfig(&net.ListenConfig{})
listener, err := listenConfig.Listen(context.Background(), "tcp4", "127.0.0.1:1234")
assert.NoError(t, err, "should succeed")

laddr := listener.Addr()
log.Debugf("laddr: %s", laddr.String())

dialer := nw.CreateDialer(&net.Dialer{
LocalAddr: &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 0,
},
})

conn, err := dialer.Dial("tcp4", "127.0.0.1:1234")
assert.NoError(t, err, "should succeed")

raddr := conn.RemoteAddr()
log.Debugf("raddr: %s", raddr.String())

assert.Equal(t, "127.0.0.1", laddr.(*net.TCPAddr).IP.String(), "should match") //nolint:forcetypeassert
assert.True(t, laddr.(*net.TCPAddr).Port != 0, "should match") //nolint:forcetypeassert
assert.Equal(t, "127.0.0.1:1234", raddr.String(), "should match")
assert.NoError(t, conn.Close(), "should succeed")
assert.NoError(t, listener.Close(), "should succeed")
})

t.Run("ListenPacket", func(t *testing.T) {
nw, err := NewNet()
assert.Nil(t, err, "should succeed")

listenConfig := nw.CreateListenConfig(&net.ListenConfig{})
packetListener, err := listenConfig.ListenPacket(context.Background(), udpString, "127.0.0.1:1234")
assert.NoError(t, err, "should succeed")

laddr := packetListener.LocalAddr()
log.Debugf("laddr: %s", laddr.String())

dialer := nw.CreateDialer(&net.Dialer{
LocalAddr: &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 0,
},
})

packetConn, err := dialer.Dial(udpString, "127.0.0.1:1234")
assert.NoError(t, err, "should succeed")

raddr := packetConn.RemoteAddr()
log.Debugf("raddr: %s", raddr.String())

assert.Equal(t, "127.0.0.1", laddr.(*net.UDPAddr).IP.String(), "should match") //nolint:forcetypeassert
assert.True(t, laddr.(*net.UDPAddr).Port != 0, "should match") //nolint:forcetypeassert
assert.Equal(t, "127.0.0.1:1234", raddr.String(), "should match")
assert.NoError(t, packetConn.Close(), "should succeed")
assert.NoError(t, packetListener.Close(), "should succeed")
})

t.Run("Unexpected operations", func(t *testing.T) {
// For portability of test, find a name of loopback interface name first
var loName string
Expand Down
22 changes: 22 additions & 0 deletions vnet/net.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package vnet

import (
"context"
"encoding/binary"
"errors"
"fmt"
Expand Down Expand Up @@ -665,3 +666,24 @@ type dialer struct {
func (d *dialer) Dial(network, address string) (net.Conn, error) {
return d.net.Dial(network, address)
}

// CreateListenConfig creates an instance of vnet.ListenConfig.
func (v *Net) CreateListenConfig(l *net.ListenConfig) transport.ListenConfig {
return &listenConfig{
listenConfig: l,
net: v,
}
}

type listenConfig struct {
listenConfig *net.ListenConfig
net *Net
}

func (l *listenConfig) Listen(ctx context.Context, network, address string) (net.Listener, error) {
return l.listenConfig.Listen(ctx, network, address)
}

func (l *listenConfig) ListenPacket(_ context.Context, network, address string) (net.PacketConn, error) {
return l.net.ListenPacket(network, address)
}
56 changes: 56 additions & 0 deletions vnet/net_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package vnet

import (
"context"
"fmt"
"net"
"testing"
Expand Down Expand Up @@ -710,6 +711,61 @@ func TestNetVirtual(t *testing.T) { //nolint:gocyclo,cyclop,maintidx
assert.Empty(t, nw.udpConns.size(), "should match")
})

t.Run("Listen", func(t *testing.T) {
nw, err := NewNet(&NetConfig{})
assert.Nil(t, err, "should succeed")

listenConfig := nw.CreateListenConfig(&net.ListenConfig{})
listener, err := listenConfig.Listen(context.Background(), "tcp4", "127.0.0.1:1234")
assert.NoError(t, err, "should succeed")

laddr := listener.Addr()
log.Debugf("laddr: %s", laddr.String())

conn, err := net.Dial("tcp4", "127.0.0.1:1234") //nolint:noctx
assert.NoError(t, err, "should succeed")

raddr := conn.RemoteAddr()
log.Debugf("raddr: %s", raddr.String())

assert.Equal(t, "127.0.0.1", laddr.(*net.TCPAddr).IP.String(), "should match") //nolint:forcetypeassert
assert.True(t, laddr.(*net.TCPAddr).Port != 0, "should match") //nolint:forcetypeassert
assert.Equal(t, "127.0.0.1:1234", raddr.String(), "should match")
assert.NoError(t, conn.Close(), "should succeed")
assert.NoError(t, listener.Close(), "should succeed")
})

t.Run("ListenPacket", func(t *testing.T) {
nw, err := NewNet(&NetConfig{})
assert.Nil(t, err, "should succeed")

listenConfig := nw.CreateListenConfig(&net.ListenConfig{})
packetListener, err := listenConfig.ListenPacket(context.Background(), udp, "127.0.0.1:1234")
assert.NoError(t, err, "should succeed")

laddr := packetListener.LocalAddr()
log.Debugf("laddr: %s", laddr.String())

dialer := nw.CreateDialer(&net.Dialer{
LocalAddr: &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 0,
},
})

packetConn, err := dialer.Dial(udp, "127.0.0.1:1234")
assert.NoError(t, err, "should succeed")

raddr := packetConn.RemoteAddr()
log.Debugf("raddr: %s", raddr.String())

assert.Equal(t, "127.0.0.1", laddr.(*net.UDPAddr).IP.String(), "should match") //nolint:forcetypeassert
assert.True(t, laddr.(*net.UDPAddr).Port != 0, "should match") //nolint:forcetypeassert
assert.Equal(t, "127.0.0.1:1234", raddr.String(), "should match")
assert.NoError(t, packetConn.Close(), "should succeed")
assert.NoError(t, packetListener.Close(), "should succeed")
})

t.Run("Two IPs on a NIC", func(t *testing.T) {
doneCh := make(chan struct{})

Expand Down
Loading