Skip to content

Commit

Permalink
Add a callback to handle allocation requests
Browse files Browse the repository at this point in the history
  • Loading branch information
rg0now committed Apr 24, 2023
1 parent f880e55 commit 70c67df
Show file tree
Hide file tree
Showing 6 changed files with 148 additions and 21 deletions.
30 changes: 30 additions & 0 deletions internal/allocation/allocation_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@ import (
"time"

"github.com/pion/logging"
"github.com/pion/stun"
)

// ManagerConfig a bag of config params for Manager.
type ManagerConfig struct {
LeveledLogger logging.LeveledLogger
AllocatePacketConn func(network string, requestedPort int) (net.PacketConn, net.Addr, error)
AllocateConn func(network string, requestedPort int) (net.Conn, net.Addr, error)
AllocationHandler func(clientAddr net.Addr) (alternateServer net.Addr, errorCode stun.ErrorCode)
PermissionHandler func(sourceAddr net.Addr, peerIP net.IP) bool
}

Expand All @@ -35,6 +37,7 @@ type Manager struct {

allocatePacketConn func(network string, requestedPort int) (net.PacketConn, net.Addr, error)
allocateConn func(network string, requestedPort int) (net.Conn, net.Addr, error)
allocationHandler func(clientAddr net.Addr) (alternateServer net.Addr, errorCode stun.ErrorCode)
permissionHandler func(sourceAddr net.Addr, peerIP net.IP) bool
}

Expand All @@ -54,6 +57,7 @@ func NewManager(config ManagerConfig) (*Manager, error) {
allocations: make(map[string]*Allocation, 64),
allocatePacketConn: config.AllocatePacketConn,
allocateConn: config.AllocateConn,
allocationHandler: config.AllocationHandler,
permissionHandler: config.PermissionHandler,
}, nil
}
Expand Down Expand Up @@ -85,6 +89,32 @@ func (m *Manager) Close() error {
return nil
}

// HandleAllocation calls the allocation handler callback to decide whether to admit a client request.
func (m *Manager) HandleAllocation(clientAddr net.Addr) (net.IP, int, stun.ErrorCode) {
if m.allocationHandler == nil {
return nil, 0, 0
}

altServer, errorCode := m.allocationHandler(clientAddr)

var altIP net.IP
var altPort int
switch addr := altServer.(type) {
case *net.UDPAddr:
altIP = addr.IP
altPort = addr.Port
case *net.TCPAddr:
altIP = addr.IP
altPort = addr.Port
default:
m.log.Warnf("received unknown alternate server address from allocation handler: %s",
altServer.String())
return nil, 0, stun.CodeServerError
}

return altIP, altPort, errorCode
}

// CreateAllocation creates a new allocation and starts relaying
func (m *Manager) CreateAllocation(fiveTuple *FiveTuple, turnSocket net.PacketConn, requestedPort int, lifetime time.Duration) (*Allocation, error) {
switch {
Expand Down
1 change: 1 addition & 0 deletions internal/server/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,5 @@ var (
errShortWrite = errors.New("packet write smaller than packet")
errNoSuchChannelBind = errors.New("no such channel bind")
errFailedWriteSocket = errors.New("failed writing to socket")
errFailedToSetAlternateServer = errors.New("cannot add ALTERNATE-SERVER attribute")
)
26 changes: 26 additions & 0 deletions internal/server/turn.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,32 @@ func handleAllocateRequest(r Request, m *stun.Message) error {
// with a 300 (Try Alternate) error if it wishes to redirect the
// client to a different server. The use of this error code and
// attribute follow the specification in [RFC5389].
if altIP, altPort, errorCode := r.AllocationManager.HandleAllocation(r.SrcAddr); errorCode != 0 {
r.Log.Debugf("allocation handler sending error code %d to client %s", errorCode, r.SrcAddr.String())

msg := buildMsg(m.TransactionID, stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse))
if err = errorCode.AddTo(m); err != nil {
return buildAndSendErr(r.Conn, r.SrcAddr, err, msg...)
}

if errorCode == stun.CodeTryAlternate && altIP != nil && altPort != 0 {
addr := &stun.AlternateServer{
IP: altIP,
Port: altPort,
}
if err = addr.AddTo(m); err != nil {
return buildAndSendErr(r.Conn, r.SrcAddr, errFailedToSetAlternateServer, msg...)
}

r.Log.Debugf("redirecting client to %s:%d", addr.IP.String(), addr.Port)
}

return buildAndSend(r.Conn, r.SrcAddr, msg...)
}

// If all the checks pass, the server creates the allocation. The
// 5-tuple is set to the 5-tuple from the Allocate request, while the
// list of permissions and the list of channels are initially empty.
lifetimeDuration := allocationLifeTime(m)
a, err := r.AllocationManager.CreateAllocation(
fiveTuple,
Expand Down
15 changes: 6 additions & 9 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func NewServer(config ServerConfig) (*Server, error) {
}

for _, cfg := range s.packetConnConfigs {
am, err := s.createAllocationManager(cfg.RelayAddressGenerator, cfg.PermissionHandler)
am, err := s.createAllocationManager(cfg.RelayAddressGenerator, cfg.AllocationHandler, cfg.PermissionHandler)
if err != nil {
return nil, fmt.Errorf("failed to create AllocationManager: %w", err)
}
Expand All @@ -84,7 +84,7 @@ func NewServer(config ServerConfig) (*Server, error) {
}

for _, cfg := range s.listenerConfigs {
am, err := s.createAllocationManager(cfg.RelayAddressGenerator, cfg.PermissionHandler)
am, err := s.createAllocationManager(cfg.RelayAddressGenerator, cfg.AllocationHandler, cfg.PermissionHandler)
if err != nil {
return nil, fmt.Errorf("failed to create AllocationManager: %w", err)
}
Expand All @@ -101,7 +101,7 @@ func NewServer(config ServerConfig) (*Server, error) {
return s, nil
}

// AllocationCount returns the number of active allocations. It can be used to drain the server before closing
// AllocationCount returns the number of active allocations. It can be used to drain the server before closing.
func (s *Server) AllocationCount() int {
allocs := 0
for _, am := range s.allocationManagers {
Expand Down Expand Up @@ -156,15 +156,12 @@ func (s *Server) readListener(l net.Listener, am *allocation.Manager) {
}
}

func (s *Server) createAllocationManager(addrGenerator RelayAddressGenerator, handler PermissionHandler) (*allocation.Manager, error) {
if handler == nil {
handler = DefaultPermissionHandler
}

func (s *Server) createAllocationManager(addrGenerator RelayAddressGenerator, allocationHandler AllocationHandler, permissionHandler PermissionHandler) (*allocation.Manager, error) {
am, err := allocation.NewManager(allocation.ManagerConfig{
AllocatePacketConn: addrGenerator.AllocatePacketConn,
AllocateConn: addrGenerator.AllocateConn,
PermissionHandler: handler,
AllocationHandler: allocationHandler,
PermissionHandler: permissionHandler,
LeveledLogger: s.log,
})
if err != nil {
Expand Down
35 changes: 23 additions & 12 deletions server_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"time"

"github.com/pion/logging"
"github.com/pion/stun"
)

// RelayAddressGenerator is used to generate a RelayAddress when creating an allocation.
Expand All @@ -26,6 +27,15 @@ type RelayAddressGenerator interface {
AllocateConn(network string, requestedPort int) (net.Conn, net.Addr, error)
}

// AllocationHandler is a callback used to handle incoming allocation requests, allowing users to
// customize Pion TURN with custom behavior. If the returned error code is nonzero then the request
// is rejected with the given error code. This is useful to, e.g., return an "Allocation Quota
// Reached" when the number of allocations from the client address surpasses a limit. If the error
// code is "Try Alternate" then the reject response will also contain an ALTERNATE-SERVER attribute
// with the returned alternate server address. This is useful to redirect the client to another
// TURN server.
type AllocationHandler func(clientAddr net.Addr) (alternateServer net.Addr, errorCode stun.ErrorCode)

// PermissionHandler is a callback to filter incoming CreatePermission and ChannelBindRequest
// requests based on the client IP address and port and the peer IP address the client intends to
// connect to. If the client is behind a NAT then the filter acts on the server reflexive
Expand All @@ -34,11 +44,6 @@ type RelayAddressGenerator interface {
// of NATs that comply with [RFC4787], see https://tools.ietf.org/html/rfc5766#section-2.3.
type PermissionHandler func(clientAddr net.Addr, peerIP net.IP) (ok bool)

// DefaultPermissionHandler is convince function that grants permission to all peers
func DefaultPermissionHandler(net.Addr, net.IP) (ok bool) {
return true
}

// PacketConnConfig is a single net.PacketConn to listen/write on. This will be used for UDP listeners
type PacketConnConfig struct {
PacketConn net.PacketConn
Expand All @@ -47,9 +52,12 @@ type PacketConnConfig struct {
// creates the net.PacketConn and returns the IP/Port it is available at
RelayAddressGenerator RelayAddressGenerator

// PermissionHandler is a callback to filter peer addresses. Can be set as nil, in which
// case the DefaultPermissionHandler is automatically instantiated to admit all peer
// connections
// AllocationHandler is a callback to filter client addresses or redirect clients to an
// alternate server.
AllocationHandler AllocationHandler

// PermissionHandler is a callback to filter peer addresses. Specifying no permission
// handler will admit all peer connections.
PermissionHandler PermissionHandler
}

Expand All @@ -72,9 +80,12 @@ type ListenerConfig struct {
// creates the net.PacketConn and returns the IP/Port it is available at
RelayAddressGenerator RelayAddressGenerator

// PermissionHandler is a callback to filter peer addresses. Can be set as nil, in which
// case the DefaultPermissionHandler is automatically instantiated to admit all peer
// connections
// AllocationHandler is a callback to filter client addresses or redirect clients to an
// alternate server.
AllocationHandler AllocationHandler

// PermissionHandler is a callback to filter peer addresses. Specifying no permission
// handler will admit all peer connections.
PermissionHandler PermissionHandler
}

Expand Down Expand Up @@ -114,7 +125,7 @@ type ServerConfig struct {
// Realm sets the realm for this server
Realm string

// AuthHandler is a callback used to handle incoming auth requests, allowing users to customize Pion TURN with custom behavior
// AuthHandler is a callback used to handle incoming auth requests, allowing users to customize Pion TURN with custom behavior.
AuthHandler AuthHandler

// ChannelBindTimeout sets the lifetime of channel binding. Defaults to 10 minutes.
Expand Down
62 changes: 62 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"time"

"github.com/pion/logging"
"github.com/pion/stun"
"github.com/pion/transport/v2/test"
"github.com/pion/transport/v2/vnet"
"github.com/pion/turn/v2/internal/proto"
Expand All @@ -27,6 +28,7 @@ func TestServer(t *testing.T) {
defer report()

loggerFactory := logging.NewDefaultLoggerFactory()
loggerFactory.DefaultLogLevel = logging.LogLevelTrace

credMap := map[string][]byte{
"user": GenerateAuthKey("user", "pion.ly", "pass"),
Expand Down Expand Up @@ -119,6 +121,66 @@ func TestServer(t *testing.T) {
assert.NoError(t, server.Close())
})

t.Run("redirect", func(t *testing.T) {
udpListener, err := net.ListenPacket("udp4", "0.0.0.0:3478")
assert.NoError(t, err)

server, err := NewServer(ServerConfig{
AuthHandler: func(username, realm string, srcAddr net.Addr) (key []byte, ok bool) {
if pw, ok := credMap[username]; ok {
return pw, true
}
return nil, false
},
PacketConnConfigs: []PacketConnConfig{
{
PacketConn: udpListener,
AllocationHandler: func(clientAddr net.Addr) (alternateServer net.Addr, errorCode stun.ErrorCode) {
return &net.UDPAddr{
IP: net.ParseIP("1.2.3.4"),
Port: 8743,
}, stun.CodeTryAlternate
},
RelayAddressGenerator: &RelayAddressGeneratorStatic{
RelayAddress: net.ParseIP("127.0.0.1"),
Address: "0.0.0.0",
},
},
},
Realm: "pion.ly",
LoggerFactory: loggerFactory,
})
assert.NoError(t, err)

conn, err := net.ListenPacket("udp4", "0.0.0.0:0")
assert.NoError(t, err)

serverAddr, err := net.ResolveUDPAddr("udp4", "127.0.0.1:3478")
assert.NoError(t, err)

client, err := NewClient(&ClientConfig{
Conn: conn,
STUNServerAddr: serverAddr,
TURNServerAddr: serverAddr,
Username: "user",
Password: "pass",
Realm: "pion.ly",
LoggerFactory: loggerFactory,
})
assert.NoError(t, err)
assert.NoError(t, client.Listen())

_, err = client.Allocate()
assert.Error(t, err, "should return error")

fmt.Printf("%#v\n", err)

client.Close()
assert.NoError(t, conn.Close())

assert.NoError(t, server.Close())
})

t.Run("Filter on client address and peer IP", func(t *testing.T) {
udpListener, err := net.ListenPacket("udp4", "0.0.0.0:3478")
assert.NoError(t, err)
Expand Down

0 comments on commit 70c67df

Please sign in to comment.