diff --git a/client_test.go b/client_test.go index 592fcb2e..866cfb93 100644 --- a/client_test.go +++ b/client_test.go @@ -430,6 +430,9 @@ func TestTCPClientDial(t *testing.T) { inboundTCPConn, inboundErr := remotePeerConn.Accept() assert.NoError(t, inboundErr) + remoteAddr := inboundTCPConn.RemoteAddr() + assert.Equal(t, allocation.Addr().String(), remoteAddr.String(), "peer conn: remote address = relay address") + _, inboundErr = inboundTCPConn.Write(expectedMsg) assert.NoError(t, inboundErr) @@ -450,6 +453,9 @@ func TestTCPClientDial(t *testing.T) { connectionBindConn, err := allocation.Dial("tcp4", remotePeerAddr.String()) assert.NoError(t, err) + localAddr := connectionBindConn.LocalAddr() + assert.Equal(t, allocation.Addr().String(), localAddr.String(), "client conn: local address = relay address") + channelBindConnBuffer := make([]byte, len(expectedMsg)) _, err = connectionBindConn.Read(channelBindConnBuffer) assert.NoError(t, err) @@ -544,6 +550,154 @@ func TestTCPClientAccept(t *testing.T) { assert.NoError(t, server.Close()) } +func TestTCPClientMultipleConns(t *testing.T) { + tcpListener, err := net.Listen("tcp4", "0.0.0.0:3478") //nolint: gosec,noctx + require.NoError(t, err) + + server, err := NewServer(ServerConfig{ + AuthHandler: func(ra *RequestAttributes) (key []byte, ok bool) { + return GenerateAuthKey(ra.Username, ra.Realm, "pass"), true + }, + ListenerConfigs: []ListenerConfig{ + { + Listener: tcpListener, + RelayAddressGenerator: &RelayAddressGeneratorStatic{ + RelayAddress: net.ParseIP("127.0.0.1"), + Address: "0.0.0.0", + }, + }, + }, + Realm: "pion.ly", + }) + require.NoError(t, err) + + clientConn, err := net.Dial("tcp", testAddr) // nolint: noctx + require.NoError(t, err) + + client, err := NewClient(&ClientConfig{ + Conn: NewSTUNConn(clientConn), + STUNServerAddr: testAddr, + TURNServerAddr: testAddr, + Username: "foo", + Password: "pass", + }) + require.NoError(t, err) + require.NoError(t, client.Listen()) + + allocation, err := client.AllocateTCP() + assert.NoError(t, err) + + runPeerDialer := func(i int) net.Conn { + relayAddr, ok := allocation.Addr().(*net.TCPAddr) + assert.True(t, ok) + + expectedMsg := []byte{0xDE, 0xAD, 0xBE, 0xEF, byte(i)} + peerConn, peerErr := net.DialTCP("tcp4", nil, relayAddr) + assert.NoError(t, peerErr) + + peerBuffer := make([]byte, len(expectedMsg)) + _, peerErr = peerConn.Read(peerBuffer) + assert.NoError(t, peerErr) + assert.Equal(t, expectedMsg, peerBuffer) + + _, peerErr = peerConn.Write(expectedMsg) + assert.NoError(t, peerErr) + + return peerConn + } + + runPeerAcceptor := func(ctx context.Context, i int) (net.Listener, net.Addr) { + remotePeerListener, err := net.Listen("tcp4", "127.0.0.1:0") // nolint: noctx + assert.NoError(t, err) + + expectedMsg := []byte{0xDE, 0xAD, 0xBE, 0xEF, byte(i)} + go func() { + peerConn, peerErr := remotePeerListener.Accept() + assert.NoError(t, peerErr) + + peerBuffer := make([]byte, len(expectedMsg)) + _, peerErr = peerConn.Read(peerBuffer) + assert.NoError(t, peerErr) + assert.Equal(t, expectedMsg, peerBuffer) + + _, peerErr = peerConn.Write(expectedMsg) + assert.NoError(t, peerErr) + + <-ctx.Done() + + assert.NoError(t, peerConn.Close()) + }() + + return remotePeerListener, remotePeerListener.Addr() + } + + runClientDialer := func(remotePeerAddr net.Addr, i int) net.Conn { + dialerConn, err := allocation.Dial("tcp4", remotePeerAddr.String()) + assert.NoError(t, err) + + expectedMsg := []byte{0xDE, 0xAD, 0xBE, 0xEF, byte(i)} + _, err = dialerConn.Write(expectedMsg) + assert.NoError(t, err) + + clientBuffer := make([]byte, len(expectedMsg)) + _, err = dialerConn.Read(clientBuffer) + assert.NoError(t, err) + assert.Equal(t, expectedMsg, clientBuffer) + + return dialerConn + } + + runClientAcceptor := func(ctx context.Context, i int) { + go func() { + acceptorConn, err := allocation.Accept() + assert.NoError(t, err) + + expectedMsg := []byte{0xDE, 0xAD, 0xBE, 0xEF, byte(i)} + _, err = acceptorConn.Write(expectedMsg) + assert.NoError(t, err) + + clientBuffer := make([]byte, len(expectedMsg)) + _, err = acceptorConn.Read(clientBuffer) + assert.NoError(t, err) + assert.Equal(t, expectedMsg, clientBuffer) + + <-ctx.Done() + + assert.NoError(t, acceptorConn.Close()) + }() + } + + ctx, cancel := context.WithCancel(context.Background()) + clientConns := []net.Conn{} + peerConns := []net.Conn{} + peerListeners := []net.Listener{} + for i := 0; i < 3; i += 1 { + // client -> server -> peer + peerListener, peerAddr := runPeerAcceptor(ctx, i) + time.Sleep(time.Second) + peerListeners = append(peerListeners, peerListener) + dialerConn := runClientDialer(peerAddr, i) + clientConns = append(clientConns, dialerConn) + + // peer -> server -> client + runClientAcceptor(ctx, i) + peerConn := runPeerDialer(i) + peerConns = append(peerConns, peerConn) + } + + cancel() + + // Shutdown + for i := 0; i < 3; i += 1 { + assert.NoError(t, peerListeners[i].Close()) + assert.NoError(t, peerConns[i].Close()) + assert.NoError(t, clientConns[i].Close()) + } + assert.NoError(t, allocation.Close()) + assert.NoError(t, server.Close()) + assert.NoError(t, clientConn.Close()) +} + type channelBindFilterConn struct { net.PacketConn diff --git a/internal/allocation/allocation_manager.go b/internal/allocation/allocation_manager.go index a6f49622..d7c39626 100644 --- a/internal/allocation/allocation_manager.go +++ b/internal/allocation/allocation_manager.go @@ -25,6 +25,7 @@ type ManagerConfig struct { LeveledLogger logging.LeveledLogger AllocatePacketConn func(network string, requestedPort int) (net.PacketConn, net.Addr, error) AllocateListener func(network string, requestedPort int) (net.Listener, net.Addr, error) + AllocateConn func(network string, laddr, raddr net.Addr) (net.Conn, error) PermissionHandler func(sourceAddr net.Addr, peerIP net.IP) bool EventHandler EventHandler @@ -47,6 +48,7 @@ type Manager struct { allocatePacketConn func(network string, requestedPort int) (net.PacketConn, net.Addr, error) allocateListener func(network string, requestedPort int) (net.Listener, net.Addr, error) + allocateConn func(network string, laddr, raddr net.Addr) (net.Conn, error) permissionHandler func(sourceAddr net.Addr, peerIP net.IP) bool EventHandler EventHandler } @@ -58,6 +60,8 @@ func NewManager(config ManagerConfig) (*Manager, error) { return nil, errAllocatePacketConnMustBeSet case config.AllocateListener == nil: return nil, errAllocateListenerMustBeSet + case config.AllocateConn == nil: + return nil, errAllocateConnMustBeSet case config.LeveledLogger == nil: return nil, errLeveledLoggerMustBeSet } @@ -72,6 +76,7 @@ func NewManager(config ManagerConfig) (*Manager, error) { allocations: make(map[FiveTupleFingerprint]*Allocation, 64), allocatePacketConn: config.AllocatePacketConn, allocateListener: config.AllocateListener, + allocateConn: config.AllocateConn, permissionHandler: config.PermissionHandler, EventHandler: config.EventHandler, tcpConnectionBindTimeout: tcpConnectionBindTimeout, @@ -302,7 +307,22 @@ func (m *Manager) CreateTCPConnection( // nolint: cyclop return 0, errInvalidPeerAddress } - conn, err := net.DialTCP("tcp4", nil, &net.TCPAddr{IP: peerAddress.IP, Port: peerAddress.Port}) // nolint: noctx + relayAddr := allocation.RelayAddr + if allocation.RelayAddr == nil { + m.log.Warn("Failed to create TCP Connection: Relay address not available") + + return 0, ErrTCPConnectionTimeoutOrFailure + } + + remoteAddr := &net.TCPAddr{IP: peerAddress.IP, Port: peerAddress.Port} + + m.lock.Lock() + if m.isDupeTCPConnection(allocation, remoteAddr) { + return 0, ErrDupeTCPConnection + } + m.lock.Unlock() + + conn, err := m.allocateConn("tcp4", relayAddr, remoteAddr) // nolint: noctx if err != nil { m.log.Warnf("Failed to create TCP Connection: %v", err) @@ -341,13 +361,8 @@ func (m *Manager) addTCPConnection(allocation *Allocation, conn net.Conn) (proto return 0, ErrDupeTCPConnection } - for i := range allocation.tcpConnections { - tcpAddr, ok := allocation.tcpConnections[i].RemoteAddr().(*net.TCPAddr) - if !ok { - return 0, ErrDupeTCPConnection - } else if tcpAddr.IP.Equal(newConnAddr.IP) && tcpAddr.Port == newConnAddr.Port { - return 0, ErrDupeTCPConnection - } + if m.isDupeTCPConnection(allocation, newConnAddr) { + return 0, ErrDupeTCPConnection } tcpConn := &tcpConnection{conn, atomic.Bool{}, nil} @@ -362,6 +377,19 @@ func (m *Manager) addTCPConnection(allocation *Allocation, conn net.Conn) (proto return connectionID, nil } +func (m *Manager) isDupeTCPConnection(allocation *Allocation, remoteAddr *net.TCPAddr) bool { + for i := range allocation.tcpConnections { + tcpAddr, ok := allocation.tcpConnections[i].RemoteAddr().(*net.TCPAddr) + if !ok { + return true + } else if tcpAddr.IP.Equal(remoteAddr.IP) && tcpAddr.Port == remoteAddr.Port { + return true + } + } + + return false +} + // GetTCPConnection returns the TCP Connection for the given ConnectionID. func (m *Manager) GetTCPConnection(username string, connectionID proto.ConnectionID) net.Conn { m.lock.Lock() diff --git a/internal/allocation/allocation_manager_test.go b/internal/allocation/allocation_manager_test.go index 9373014b..56dd9d7a 100644 --- a/internal/allocation/allocation_manager_test.go +++ b/internal/allocation/allocation_manager_test.go @@ -12,10 +12,12 @@ import ( "math/rand" "net" "strings" + "sync" "testing" "time" "github.com/pion/logging" + "github.com/pion/transport/v4/reuseport" "github.com/pion/turn/v4/internal/proto" "github.com/stretchr/testify/assert" ) @@ -37,6 +39,11 @@ func TestNewManagerValidation(t *testing.T) { cfg.AllocateListener = func(string, int) (net.Listener, net.Addr, error) { return nil, nil, nil } manager, err = NewManager(cfg) assert.Nil(t, manager) + assert.ErrorIs(t, err, errAllocateConnMustBeSet) + + cfg.AllocateConn = func(network string, laddr, raddr net.Addr) (net.Conn, error) { return nil, nil } //nolint:nilnil + manager, err = NewManager(cfg) + assert.Nil(t, manager) assert.ErrorIs(t, err, errLeveledLoggerMustBeSet) cfg.LeveledLogger = loggerFactory.NewLogger("test") @@ -212,6 +219,14 @@ func newTestManager() (*Manager, error) { return conn, conn.LocalAddr(), nil }, AllocateListener: func(string, int) (net.Listener, net.Addr, error) { return nil, nil, nil }, + AllocateConn: func(network string, laddr, raddr net.Addr) (net.Conn, error) { + dialer := net.Dialer{ + LocalAddr: laddr, + Control: reuseport.Control, + } + + return dialer.Dial(network, raddr.String()) + }, } return NewManager(config) @@ -236,6 +251,78 @@ func TestGetRandomEvenPort(t *testing.T) { } func TestCreateTCPConnection(t *testing.T) { + lns := make([]net.Listener, 3) + mu := sync.Mutex{} // make the race detector happy + acceptedConns := make([]net.Conn, 3) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var err error + for i := 0; i < 3; i++ { + lns[i], err = net.Listen("tcp", "127.0.0.1:0") // nolint: noctx + assert.NoError(t, err) + + go func(j int) { + conn, connErr := lns[j].Accept() + assert.NoError(t, connErr) + + mu.Lock() + acceptedConns[j] = conn + mu.Unlock() + + if j == 2 { + cancel() + } + }(i) + } + + manager, err := newTestManager() + assert.NoError(t, err) + + turnSocket, err := net.ListenPacket("udp4", "0.0.0.0:0") // nolint: noctx + assert.NoError(t, err) + + fiveTuple := randomFiveTuple() + allocation, err := manager.CreateAllocation(fiveTuple, turnSocket, proto.ProtoTCP, 0, proto.DefaultLifetime, "", "") + assert.NoError(t, err) + allocation.RelayAddr = &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: rand.Intn(60999-32768+1) + 32768} //nolint:gosec + + for i := 0; i < 3; i++ { + addr, ok := lns[i].Addr().(*net.TCPAddr) + assert.True(t, ok) + peer := proto.PeerAddress{IP: addr.IP, Port: addr.Port} + + connectionID, err := manager.CreateTCPConnection(allocation, peer) + assert.NoError(t, err) + assert.NotZero(t, connectionID) + + conn := manager.GetTCPConnection("", connectionID) + assert.NotNil(t, conn) + + laddr, ok := conn.LocalAddr().(*net.TCPAddr) + assert.True(t, ok) + relayAddr, ok := allocation.RelayAddr.(*net.TCPAddr) + assert.True(t, ok) + assert.Equal(t, laddr.IP.String(), relayAddr.IP.String()) + assert.Equal(t, laddr.IP.String(), relayAddr.IP.String()) + + assert.NoError(t, conn.Close()) + } + + <-ctx.Done() + + assert.NoError(t, turnSocket.Close()) + + mu.Lock() + defer mu.Unlock() + for i := 0; i < 3; i++ { + assert.NoError(t, acceptedConns[i].Close()) + assert.NoError(t, lns[i].Close()) + } +} + +func TestCreateTCPConnectionDuplicateTCPConn(t *testing.T) { ln, err := net.Listen("tcp", "127.0.0.1:0") // nolint: noctx assert.NoError(t, err) @@ -261,9 +348,10 @@ func TestCreateTCPConnection(t *testing.T) { assert.NoError(t, err) fiveTuple := randomFiveTuple() - allocation, err := manager.CreateAllocation(fiveTuple, turnSocket, proto.ProtoUDP, 0, proto.DefaultLifetime, "", "") + allocation, err := manager.CreateAllocation(fiveTuple, turnSocket, proto.ProtoTCP, 0, proto.DefaultLifetime, "", "") assert.NoError(t, err) + allocation.RelayAddr = &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: rand.Intn(60999-32768+1) + 32768} //nolint:gosec connectionID, err := manager.CreateTCPConnection(allocation, peer) assert.NoError(t, err) assert.NotZero(t, connectionID) @@ -271,14 +359,6 @@ func TestCreateTCPConnection(t *testing.T) { _, err = manager.CreateTCPConnection(allocation, peer) assert.ErrorIs(t, err, ErrDupeTCPConnection) - assert.Nil(t, manager.GetTCPConnection("bad-username", connectionID)) - c1 := manager.GetTCPConnection("", connectionID) - - assert.Nil(t, manager.GetTCPConnection("", connectionID)) - - assert.NotNil(t, c1) - assert.NoError(t, c1.Close()) - <-ctx.Done() assert.NoError(t, acceptErr) assert.NoError(t, acceptedConn.Close()) @@ -294,8 +374,9 @@ func TestCreateTCPConnectionInvalidPeerAddress(t *testing.T) { assert.NoError(t, err) fiveTuple := randomFiveTuple() - allocation, err := manager.CreateAllocation(fiveTuple, turnSocket, proto.ProtoUDP, 0, proto.DefaultLifetime, "", "") + allocation, err := manager.CreateAllocation(fiveTuple, turnSocket, proto.ProtoTCP, 0, proto.DefaultLifetime, "", "") assert.NoError(t, err) + allocation.RelayAddr = &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: rand.Intn(60999-32768+1) + 32768} //nolint:gosec _, err = manager.CreateTCPConnection(allocation, proto.PeerAddress{IP: nil, Port: 1234}) assert.ErrorIs(t, err, errInvalidPeerAddress) @@ -314,8 +395,9 @@ func TestCreateTCPConnectionInvalid(t *testing.T) { assert.NoError(t, err) fiveTuple := randomFiveTuple() - allocation, err := manager.CreateAllocation(fiveTuple, turnSocket, proto.ProtoUDP, 0, proto.DefaultLifetime, "", "") + allocation, err := manager.CreateAllocation(fiveTuple, turnSocket, proto.ProtoTCP, 0, proto.DefaultLifetime, "", "") assert.NoError(t, err) + allocation.RelayAddr = &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: rand.Intn(60999-32768+1) + 32768} //nolint:gosec peerAddress := proto.PeerAddress{IP: net.ParseIP("127.0.0.1"), Port: 5000} @@ -353,8 +435,9 @@ func TestCreateTCPConnectionTimeout(t *testing.T) { assert.NoError(t, err) fiveTuple := randomFiveTuple() - allocation, err := manager.CreateAllocation(fiveTuple, turnSocket, proto.ProtoUDP, 0, proto.DefaultLifetime, "", "") + allocation, err := manager.CreateAllocation(fiveTuple, turnSocket, proto.ProtoTCP, 0, proto.DefaultLifetime, "", "") assert.NoError(t, err) + allocation.RelayAddr = &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: rand.Intn(60999-32768+1) + 32768} //nolint:gosec connectionID, err := manager.CreateTCPConnection(allocation, peer) assert.NoError(t, err) diff --git a/internal/allocation/allocation_test.go b/internal/allocation/allocation_test.go index 75bf8108..a99248f6 100644 --- a/internal/allocation/allocation_test.go +++ b/internal/allocation/allocation_test.go @@ -18,6 +18,7 @@ import ( "github.com/pion/logging" "github.com/pion/stun/v3" + "github.com/pion/transport/v4/reuseport" "github.com/pion/turn/v4/internal/ipnet" "github.com/pion/turn/v4/internal/proto" "github.com/stretchr/testify/assert" @@ -321,11 +322,20 @@ func TestTCPRelay_E2E(t *testing.T) { return nil, nil, nil }, AllocateListener: func(string, int) (net.Listener, net.Addr, error) { - ln, listenerErr := net.Listen("tcp4", "127.0.0.1:0") // nolint: noctx + ln, listenerErr := (&net.ListenConfig{Control: reuseport.Control}). + Listen(context.TODO(), "tcp4", "127.0.0.1:0") assert.NoError(t, listenerErr) return ln, ln.Addr(), nil }, + AllocateConn: func(network string, laddr, raddr net.Addr) (net.Conn, error) { + dialer := net.Dialer{ + LocalAddr: laddr, + Control: reuseport.Control, + } + + return dialer.Dial(network, raddr.String()) + }, }) assert.NoError(t, err) diff --git a/internal/allocation/errors.go b/internal/allocation/errors.go index 8866900b..04ed97d6 100644 --- a/internal/allocation/errors.go +++ b/internal/allocation/errors.go @@ -11,6 +11,7 @@ var ( errAllocatePacketConnMustBeSet = errors.New("AllocatePacketConn must be set") errAllocateListenerMustBeSet = errors.New("AllocateListener must be set") + errAllocateConnMustBeSet = errors.New("AllocateConn must be set") errLeveledLoggerMustBeSet = errors.New("LeveledLogger must be set") errSameChannelDifferentPeer = errors.New("you cannot use the same channel number with different peer") errNilFiveTuple = errors.New("allocations must not be created with nil FivTuple") diff --git a/internal/server/turn_test.go b/internal/server/turn_test.go index 4cbf1899..d40e28fd 100644 --- a/internal/server/turn_test.go +++ b/internal/server/turn_test.go @@ -72,6 +72,9 @@ func TestAllocationLifeTime(t *testing.T) { AllocateListener: func(string, int) (net.Listener, net.Addr, error) { return nil, nil, nil }, + AllocateConn: func(network string, laddr, raddr net.Addr) (net.Conn, error) { + return nil, nil //nolint:nilnil + }, LeveledLogger: logger, }) assert.NoError(t, err) @@ -131,6 +134,9 @@ func TestRequestedTransport(t *testing.T) { AllocateListener: func(string, int) (net.Listener, net.Addr, error) { return nil, nil, nil }, + AllocateConn: func(network string, laddr, raddr net.Addr) (net.Conn, error) { + return nil, nil //nolint:nilnil + }, LeveledLogger: logger, }) assert.NoError(t, err) @@ -183,6 +189,9 @@ func TestConnectRequest(t *testing.T) { AllocateListener: func(string, int) (net.Listener, net.Addr, error) { return nil, nil, nil }, + AllocateConn: func(network string, _, raddr net.Addr) (net.Conn, error) { + return net.Dial("tcp4", raddr.String()) // nolint: noctx + }, LeveledLogger: logger, }) assert.NoError(t, err) @@ -266,6 +275,9 @@ func TestConnectionBindRequest(t *testing.T) { AllocateListener: func(string, int) (net.Listener, net.Addr, error) { return nil, nil, nil }, + AllocateConn: func(network string, _, raddr net.Addr) (net.Conn, error) { + return net.Dial("tcp4", raddr.String()) // nolint: noctx + }, LeveledLogger: logger, }) assert.NoError(t, err) diff --git a/relay_address_generator_none.go b/relay_address_generator_none.go index dd67e549..e521c2bf 100644 --- a/relay_address_generator_none.go +++ b/relay_address_generator_none.go @@ -4,11 +4,13 @@ package turn import ( + "context" "fmt" "net" "strconv" "github.com/pion/transport/v4" + "github.com/pion/transport/v4/reuseport" "github.com/pion/transport/v4/stdnet" ) @@ -68,10 +70,27 @@ func (r *RelayAddressGeneratorNone) AllocateListener(network string, requestedPo return nil, nil, err } - ln, err := r.Net.ListenTCP(network, tcpAddr) // nolint: noctx + listenConfig := r.Net.CreateListenConfig(&net.ListenConfig{ + // Enable SO_REUSEADDR and SO_REUSEPORT where needed to let multiple connnections + // bind to the same relay address. + Control: reuseport.Control, + }) + ln, err := listenConfig.Listen(context.TODO(), network, tcpAddr.String()) if err != nil { return nil, nil, err } return ln, ln.Addr(), nil } + +// AllocateConn creates a new outgoing TCP connection bound to the relay address to send traffic to a peer. +func (r *RelayAddressGeneratorNone) AllocateConn(network string, laddr, raddr net.Addr) (net.Conn, error) { + dialer := r.Net.CreateDialer(&net.Dialer{ + LocalAddr: laddr, + // Enable SO_REUSEADDR and SO_REUSEPORT where needed to let multiple connnections + // bind to the same relay address. + Control: reuseport.Control, + }) + + return dialer.Dial(network, raddr.String()) +} diff --git a/relay_address_generator_range.go b/relay_address_generator_range.go index 17f01de9..e06700cc 100644 --- a/relay_address_generator_range.go +++ b/relay_address_generator_range.go @@ -4,11 +4,13 @@ package turn import ( + "context" "fmt" "net" "github.com/pion/randutil" "github.com/pion/transport/v4" + "github.com/pion/transport/v4/reuseport" "github.com/pion/transport/v4/stdnet" ) @@ -123,13 +125,18 @@ func (r *RelayAddressGeneratorPortRange) AllocateListener( // nolint: cyclop } } + listenConfig := r.Net.CreateListenConfig(&net.ListenConfig{ + // Enable SO_REUSEADDR and SO_REUSEPORT where needed to let multiple connnections + // bind to the same relay address. + Control: reuseport.Control, + }) listen := func(port int) (net.Listener, net.Addr, error) { tcpAddr, err := r.Net.ResolveTCPAddr(network, fmt.Sprintf("%s:%d", r.Address, port)) if err != nil { return nil, nil, err } - ln, err := r.Net.ListenTCP(network, tcpAddr) // nolint: noctx + ln, err := listenConfig.Listen(context.TODO(), network, tcpAddr.String()) if err != nil { return nil, nil, err } @@ -162,3 +169,15 @@ func (r *RelayAddressGeneratorPortRange) AllocateListener( // nolint: cyclop return nil, nil, errMaxRetriesExceeded } + +// AllocateConn creates a new outgoing TCP connection bound to the relay address to send traffic to a peer. +func (r *RelayAddressGeneratorPortRange) AllocateConn(network string, laddr, raddr net.Addr) (net.Conn, error) { + dialer := r.Net.CreateDialer(&net.Dialer{ + LocalAddr: laddr, + // Enable SO_REUSEADDR and SO_REUSEPORT where needed to let multiple connnections + // bind to the same relay address. + Control: reuseport.Control, + }) + + return dialer.Dial(network, raddr.String()) +} diff --git a/relay_address_generator_static.go b/relay_address_generator_static.go index 4b53d728..3614430a 100644 --- a/relay_address_generator_static.go +++ b/relay_address_generator_static.go @@ -4,11 +4,13 @@ package turn import ( + "context" "fmt" "net" "strconv" "github.com/pion/transport/v4" + "github.com/pion/transport/v4/reuseport" "github.com/pion/transport/v4/stdnet" ) @@ -82,7 +84,12 @@ func (r *RelayAddressGeneratorStatic) AllocateListener(network string, requested return nil, nil, err } - ln, err := r.Net.ListenTCP(network, tcpAddr) // nolint: noctx + listenConfig := r.Net.CreateListenConfig(&net.ListenConfig{ + // Enable SO_REUSEADDR and SO_REUSEPORT where needed to let multiple connnections + // bind to the same relay address. + Control: reuseport.Control, + }) + ln, err := listenConfig.Listen(context.TODO(), network, tcpAddr.String()) if err != nil { return nil, nil, err } @@ -99,3 +106,15 @@ func (r *RelayAddressGeneratorStatic) AllocateListener(network string, requested return ln, relayAddr, nil } + +// AllocateConn creates a new outgoing TCP connection bound to the relay address to send traffic to a peer. +func (r *RelayAddressGeneratorStatic) AllocateConn(network string, laddr, raddr net.Addr) (net.Conn, error) { + dialer := r.Net.CreateDialer(&net.Dialer{ + LocalAddr: laddr, + // Enable SO_REUSEADDR and SO_REUSEPORT where needed to let multiple connnections + // bind to the same relay address. + Control: reuseport.Control, + }) + + return dialer.Dial(network, raddr.String()) +} diff --git a/server.go b/server.go index 10a369e5..c17c7d75 100644 --- a/server.go +++ b/server.go @@ -205,6 +205,10 @@ func (n *nilAddressGenerator) AllocateListener(string, int) (net.Listener, net.A return nil, nil, errRelayAddressGeneratorNil } +func (n *nilAddressGenerator) AllocateConn(network string, laddr, raddr net.Addr) (net.Conn, error) { + return nil, errRelayAddressGeneratorNil +} + func (s *Server) createAllocationManager( addrGenerator RelayAddressGenerator, handler PermissionHandler, @@ -219,6 +223,7 @@ func (s *Server) createAllocationManager( am, err := allocation.NewManager(allocation.ManagerConfig{ AllocatePacketConn: addrGenerator.AllocatePacketConn, AllocateListener: addrGenerator.AllocateListener, + AllocateConn: addrGenerator.AllocateConn, PermissionHandler: handler, EventHandler: s.eventHandler, LeveledLogger: s.log, diff --git a/server_config.go b/server_config.go index 8b551cb1..c0f4102b 100644 --- a/server_config.go +++ b/server_config.go @@ -26,6 +26,9 @@ type RelayAddressGenerator interface { // Allocate a Listener (TCP) RelayAddress AllocateListener(network string, requestedPort int) (net.Listener, net.Addr, error) + + // Allocate a Conn (TCP) RelayAddress + AllocateConn(network string, laddr, raddr net.Addr) (net.Conn, error) } // PermissionHandler is a callback to filter incoming CreatePermission and ChannelBindRequest diff --git a/server_test.go b/server_test.go index b5b11053..21017f0c 100644 --- a/server_test.go +++ b/server_test.go @@ -11,12 +11,12 @@ import ( "fmt" "net" "sync/atomic" - "syscall" "testing" "time" "github.com/pion/logging" "github.com/pion/stun/v3" + "github.com/pion/transport/v4/reuseport" "github.com/pion/transport/v4/test" "github.com/pion/transport/v4/vnet" "github.com/pion/turn/v4/internal/allocation" @@ -302,13 +302,7 @@ func TestServer(t *testing.T) { //nolint:maintidx assert.NoError(t, err) // make sure we can reuse the client port - dialer := &net.Dialer{ - Control: func(_, _ string, conn syscall.RawConn) error { - return conn.Control(func(descriptor uintptr) { - _ = syscall.SetsockoptInt(Handle(descriptor), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) - }) - }, - } + dialer := &net.Dialer{Control: reuseport.Control} conn, err := dialer.Dial("tcp", testAddr) assert.NoError(t, err) @@ -365,6 +359,10 @@ func TestServer(t *testing.T) { //nolint:maintidx udpListener, err := net.ListenPacket("udp4", testAddr) // nolint: noctx assert.NoError(t, err) + // Enforce correct client IP and port + clientConn, err := net.ListenPacket("udp4", "127.0.0.1:0") // nolint: noctx + assert.NoError(t, err) + server, err := NewServer(ServerConfig{ AuthHandler: func(ra *RequestAttributes) (key []byte, ok bool) { if pw, ok := credMap[ra.Username]; ok { @@ -381,7 +379,7 @@ func TestServer(t *testing.T) { //nolint:maintidx Address: "127.0.0.1", }, PermissionHandler: func(src net.Addr, peer net.IP) bool { - return src.String() == "127.0.0.1:54321" && + return src.String() == clientConn.LocalAddr().String() && peer.Equal(net.IPv4(127, 0, 0, 4)) }, }, @@ -391,14 +389,10 @@ func TestServer(t *testing.T) { //nolint:maintidx }) assert.NoError(t, err) - // Enforce correct client IP and port - conn, err := net.ListenPacket("udp4", "127.0.0.1:54321") // nolint: noctx - assert.NoError(t, err) - client, err := NewClient(&ClientConfig{ STUNServerAddr: testAddr, TURNServerAddr: testAddr, - Conn: conn, + Conn: clientConn, Username: "user", Password: "pass", Realm: "pion.ly", @@ -456,7 +450,7 @@ func TestServer(t *testing.T) { //nolint:maintidx assert.NoError(t, relayConn.Close()) client.Close() - assert.NoError(t, conn.Close()) + assert.NoError(t, clientConn.Close()) // Enforce filtered source address conn2, err := net.ListenPacket("udp4", "127.0.0.1:12321") // nolint: noctx