Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 154 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,9 @@ func TestTCPClientDial(t *testing.T) {
inboundTCPConn, inboundErr := remotePeerConn.Accept()
assert.NoError(t, inboundErr)

remoteAddr := inboundTCPConn.RemoteAddr()
assert.Equal(t, allocation.Addr().String(), remoteAddr.String(), "peer conn: remote address = relay address")

_, inboundErr = inboundTCPConn.Write(expectedMsg)
assert.NoError(t, inboundErr)

Expand All @@ -450,6 +453,9 @@ func TestTCPClientDial(t *testing.T) {
connectionBindConn, err := allocation.Dial("tcp4", remotePeerAddr.String())
assert.NoError(t, err)

localAddr := connectionBindConn.LocalAddr()
assert.Equal(t, allocation.Addr().String(), localAddr.String(), "client conn: local address = relay address")

channelBindConnBuffer := make([]byte, len(expectedMsg))
_, err = connectionBindConn.Read(channelBindConnBuffer)
assert.NoError(t, err)
Expand Down Expand Up @@ -544,6 +550,154 @@ func TestTCPClientAccept(t *testing.T) {
assert.NoError(t, server.Close())
}

func TestTCPClientMultipleConns(t *testing.T) {
tcpListener, err := net.Listen("tcp4", "0.0.0.0:3478") //nolint: gosec,noctx
require.NoError(t, err)

server, err := NewServer(ServerConfig{
AuthHandler: func(ra *RequestAttributes) (key []byte, ok bool) {
return GenerateAuthKey(ra.Username, ra.Realm, "pass"), true
},
ListenerConfigs: []ListenerConfig{
{
Listener: tcpListener,
RelayAddressGenerator: &RelayAddressGeneratorStatic{
RelayAddress: net.ParseIP("127.0.0.1"),
Address: "0.0.0.0",
},
},
},
Realm: "pion.ly",
})
require.NoError(t, err)

clientConn, err := net.Dial("tcp", testAddr) // nolint: noctx
require.NoError(t, err)

client, err := NewClient(&ClientConfig{
Conn: NewSTUNConn(clientConn),
STUNServerAddr: testAddr,
TURNServerAddr: testAddr,
Username: "foo",
Password: "pass",
})
require.NoError(t, err)
require.NoError(t, client.Listen())

allocation, err := client.AllocateTCP()
assert.NoError(t, err)

runPeerDialer := func(i int) net.Conn {
relayAddr, ok := allocation.Addr().(*net.TCPAddr)
assert.True(t, ok)

expectedMsg := []byte{0xDE, 0xAD, 0xBE, 0xEF, byte(i)}
peerConn, peerErr := net.DialTCP("tcp4", nil, relayAddr)
assert.NoError(t, peerErr)

peerBuffer := make([]byte, len(expectedMsg))
_, peerErr = peerConn.Read(peerBuffer)
assert.NoError(t, peerErr)
assert.Equal(t, expectedMsg, peerBuffer)

_, peerErr = peerConn.Write(expectedMsg)
assert.NoError(t, peerErr)

return peerConn
}

runPeerAcceptor := func(ctx context.Context, i int) (net.Listener, net.Addr) {
remotePeerListener, err := net.Listen("tcp4", "127.0.0.1:0") // nolint: noctx
assert.NoError(t, err)

expectedMsg := []byte{0xDE, 0xAD, 0xBE, 0xEF, byte(i)}
go func() {
peerConn, peerErr := remotePeerListener.Accept()
assert.NoError(t, peerErr)

peerBuffer := make([]byte, len(expectedMsg))
_, peerErr = peerConn.Read(peerBuffer)
assert.NoError(t, peerErr)
assert.Equal(t, expectedMsg, peerBuffer)

_, peerErr = peerConn.Write(expectedMsg)
assert.NoError(t, peerErr)

<-ctx.Done()

assert.NoError(t, peerConn.Close())
}()

return remotePeerListener, remotePeerListener.Addr()
}

runClientDialer := func(remotePeerAddr net.Addr, i int) net.Conn {
dialerConn, err := allocation.Dial("tcp4", remotePeerAddr.String())
assert.NoError(t, err)

expectedMsg := []byte{0xDE, 0xAD, 0xBE, 0xEF, byte(i)}
_, err = dialerConn.Write(expectedMsg)
assert.NoError(t, err)

clientBuffer := make([]byte, len(expectedMsg))
_, err = dialerConn.Read(clientBuffer)
assert.NoError(t, err)
assert.Equal(t, expectedMsg, clientBuffer)

return dialerConn
}

runClientAcceptor := func(ctx context.Context, i int) {
go func() {
acceptorConn, err := allocation.Accept()
assert.NoError(t, err)

expectedMsg := []byte{0xDE, 0xAD, 0xBE, 0xEF, byte(i)}
_, err = acceptorConn.Write(expectedMsg)
assert.NoError(t, err)

clientBuffer := make([]byte, len(expectedMsg))
_, err = acceptorConn.Read(clientBuffer)
assert.NoError(t, err)
assert.Equal(t, expectedMsg, clientBuffer)

<-ctx.Done()

assert.NoError(t, acceptorConn.Close())
}()
}

ctx, cancel := context.WithCancel(context.Background())
clientConns := []net.Conn{}
peerConns := []net.Conn{}
peerListeners := []net.Listener{}
for i := 0; i < 3; i += 1 {
// client -> server -> peer
peerListener, peerAddr := runPeerAcceptor(ctx, i)
time.Sleep(time.Second)
peerListeners = append(peerListeners, peerListener)
dialerConn := runClientDialer(peerAddr, i)
clientConns = append(clientConns, dialerConn)

// peer -> server -> client
runClientAcceptor(ctx, i)
peerConn := runPeerDialer(i)
peerConns = append(peerConns, peerConn)
}

cancel()

// Shutdown
for i := 0; i < 3; i += 1 {
assert.NoError(t, peerListeners[i].Close())
assert.NoError(t, peerConns[i].Close())
assert.NoError(t, clientConns[i].Close())
}
assert.NoError(t, allocation.Close())
assert.NoError(t, server.Close())
assert.NoError(t, clientConn.Close())
}

type channelBindFilterConn struct {
net.PacketConn

Expand Down
44 changes: 36 additions & 8 deletions internal/allocation/allocation_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type ManagerConfig struct {
LeveledLogger logging.LeveledLogger
AllocatePacketConn func(network string, requestedPort int) (net.PacketConn, net.Addr, error)
AllocateListener func(network string, requestedPort int) (net.Listener, net.Addr, error)
AllocateConn func(network string, laddr, raddr net.Addr) (net.Conn, error)
PermissionHandler func(sourceAddr net.Addr, peerIP net.IP) bool
EventHandler EventHandler

Expand All @@ -47,6 +48,7 @@ type Manager struct {

allocatePacketConn func(network string, requestedPort int) (net.PacketConn, net.Addr, error)
allocateListener func(network string, requestedPort int) (net.Listener, net.Addr, error)
allocateConn func(network string, laddr, raddr net.Addr) (net.Conn, error)
permissionHandler func(sourceAddr net.Addr, peerIP net.IP) bool
EventHandler EventHandler
}
Expand All @@ -58,6 +60,8 @@ func NewManager(config ManagerConfig) (*Manager, error) {
return nil, errAllocatePacketConnMustBeSet
case config.AllocateListener == nil:
return nil, errAllocateListenerMustBeSet
case config.AllocateConn == nil:
return nil, errAllocateConnMustBeSet
case config.LeveledLogger == nil:
return nil, errLeveledLoggerMustBeSet
}
Expand All @@ -72,6 +76,7 @@ func NewManager(config ManagerConfig) (*Manager, error) {
allocations: make(map[FiveTupleFingerprint]*Allocation, 64),
allocatePacketConn: config.AllocatePacketConn,
allocateListener: config.AllocateListener,
allocateConn: config.AllocateConn,
permissionHandler: config.PermissionHandler,
EventHandler: config.EventHandler,
tcpConnectionBindTimeout: tcpConnectionBindTimeout,
Expand Down Expand Up @@ -302,7 +307,22 @@ func (m *Manager) CreateTCPConnection( // nolint: cyclop
return 0, errInvalidPeerAddress
}

conn, err := net.DialTCP("tcp4", nil, &net.TCPAddr{IP: peerAddress.IP, Port: peerAddress.Port}) // nolint: noctx
relayAddr := allocation.RelayAddr
if allocation.RelayAddr == nil {
m.log.Warn("Failed to create TCP Connection: Relay address not available")

return 0, ErrTCPConnectionTimeoutOrFailure
}

remoteAddr := &net.TCPAddr{IP: peerAddress.IP, Port: peerAddress.Port}

m.lock.Lock()
if m.isDupeTCPConnection(allocation, remoteAddr) {
return 0, ErrDupeTCPConnection
}
m.lock.Unlock()

conn, err := m.allocateConn("tcp4", relayAddr, remoteAddr) // nolint: noctx
if err != nil {
m.log.Warnf("Failed to create TCP Connection: %v", err)

Expand Down Expand Up @@ -341,13 +361,8 @@ func (m *Manager) addTCPConnection(allocation *Allocation, conn net.Conn) (proto
return 0, ErrDupeTCPConnection
}

for i := range allocation.tcpConnections {
tcpAddr, ok := allocation.tcpConnections[i].RemoteAddr().(*net.TCPAddr)
if !ok {
return 0, ErrDupeTCPConnection
} else if tcpAddr.IP.Equal(newConnAddr.IP) && tcpAddr.Port == newConnAddr.Port {
return 0, ErrDupeTCPConnection
}
if m.isDupeTCPConnection(allocation, newConnAddr) {
return 0, ErrDupeTCPConnection
}

tcpConn := &tcpConnection{conn, atomic.Bool{}, nil}
Expand All @@ -362,6 +377,19 @@ func (m *Manager) addTCPConnection(allocation *Allocation, conn net.Conn) (proto
return connectionID, nil
}

func (m *Manager) isDupeTCPConnection(allocation *Allocation, remoteAddr *net.TCPAddr) bool {
for i := range allocation.tcpConnections {
tcpAddr, ok := allocation.tcpConnections[i].RemoteAddr().(*net.TCPAddr)
if !ok {
return true
} else if tcpAddr.IP.Equal(remoteAddr.IP) && tcpAddr.Port == remoteAddr.Port {
return true
}
}

return false
}

// GetTCPConnection returns the TCP Connection for the given ConnectionID.
func (m *Manager) GetTCPConnection(username string, connectionID proto.ConnectionID) net.Conn {
m.lock.Lock()
Expand Down
Loading
Loading