diff --git a/internal/server/server.go b/internal/server/server.go index 9cc19199..0824fcef 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -27,7 +27,8 @@ type Request struct { NonceHash *NonceHash // User Configuration - AuthHandler func(username string, realm string, srcAddr net.Addr) (key []byte, ok bool) + AuthHandler func(username string, realm string, srcAddr net.Addr) (key []byte, ok bool) + QuotaHandler func(username string, realm string, srcAddr net.Addr) (ok bool) Log logging.LeveledLogger Realm string diff --git a/internal/server/turn.go b/internal/server/turn.go index 5367b82d..d312ade4 100644 --- a/internal/server/turn.go +++ b/internal/server/turn.go @@ -127,6 +127,10 @@ func handleAllocateRequest(r Request, m *stun.Message) error { // server is free to define this allocation quota any way it wishes, // but SHOULD define it based on the username used to authenticate // the request, and not on the client's transport address. + if r.QuotaHandler != nil && !r.QuotaHandler(usernameAttr.String(), realmAttr.String(), r.SrcAddr) { + quotaReachedMsg := buildMsg(m.TransactionID, stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), &stun.ErrorCodeAttribute{Code: stun.CodeAllocQuotaReached}) + return buildAndSend(r.Conn, r.SrcAddr, quotaReachedMsg...) + } // 8. Also at any point, the server MAY choose to reject the request // with a 300 (Try Alternate) error if it wishes to redirect the diff --git a/server.go b/server.go index c5ca1ed3..32543bc4 100644 --- a/server.go +++ b/server.go @@ -24,6 +24,7 @@ const ( type Server struct { log logging.LeveledLogger authHandler AuthHandler + quotaHandler QuotaHandler realm string channelBindTimeout time.Duration nonceHash *server.NonceHash @@ -61,6 +62,7 @@ func NewServer(config ServerConfig) (*Server, error) { s := &Server{ log: loggerFactory.NewLogger("turn"), authHandler: config.AuthHandler, + quotaHandler: config.QuotaHandler, realm: config.Realm, channelBindTimeout: config.ChannelBindTimeout, packetConnConfigs: config.PacketConnConfigs, @@ -224,6 +226,7 @@ func (s *Server) readLoop(p net.PacketConn, allocationManager *allocation.Manage Buff: buf[:n], Log: s.log, AuthHandler: s.authHandler, + QuotaHandler: s.quotaHandler, Realm: s.realm, AllocationManager: allocationManager, ChannelBindTimeout: s.channelBindTimeout, diff --git a/server_config.go b/server_config.go index 791ed8aa..7a070290 100644 --- a/server_config.go +++ b/server_config.go @@ -181,6 +181,9 @@ func genericEventHandler(handlers EventHandlers) allocation.EventHandler { } } +// QuotaHandler is a callback allows allocations to be rejected when a per-user quota is exceeded. If the callback returns true the allocation request is accepted, otherwise it is rejected and a 486 (Allocation Quota Reached) error is returned to the user. +type QuotaHandler func(username, realm string, srcAddr net.Addr) (ok bool) + // ServerConfig configures the Pion TURN Server type ServerConfig struct { // PacketConnConfigs and ListenerConfigs are a list of all the turn listeners @@ -197,6 +200,9 @@ type ServerConfig struct { // AuthHandler is a callback used to handle incoming auth requests, allowing users to customize Pion TURN with custom behavior AuthHandler AuthHandler + // AuthHandler is a callback used to handle incoming auth requests, allowing users to customize Pion TURN with custom behavior + QuotaHandler QuotaHandler + // EventHandlers is a set of callbacks for tracking allocation lifecycle. EventHandlers EventHandlers diff --git a/server_test.go b/server_test.go index d34dce46..896be267 100644 --- a/server_test.go +++ b/server_test.go @@ -1072,6 +1072,58 @@ func TestSTUNOnly(t *testing.T) { assert.Equal(t, err.Error(), "Allocate error response (error 400: )") } +func TestQuotaReached(t *testing.T) { + serverAddr, err := net.ResolveUDPAddr("udp4", "0.0.0.0:3478") + assert.NoError(t, err) + + serverConn, err := net.ListenPacket(serverAddr.Network(), serverAddr.String()) + assert.NoError(t, err) + + defer serverConn.Close() //nolint:errcheck + + credMap := map[string][]byte{"user": GenerateAuthKey("user", "pion.ly", "pass")} + server, err := NewServer(ServerConfig{ + AuthHandler: func(username, _ string, _ net.Addr) (key []byte, ok bool) { + if pw, ok := credMap[username]; ok { + return pw, true + } + return nil, false + }, + QuotaHandler: func(_, _ string, _ net.Addr) (ok bool) { return false }, + Realm: "pion.ly", + PacketConnConfigs: []PacketConnConfig{{ + PacketConn: serverConn, + RelayAddressGenerator: &RelayAddressGeneratorStatic{ + RelayAddress: net.ParseIP("127.0.0.1"), + Address: "0.0.0.0", + }, + }}, + LoggerFactory: logging.NewDefaultLoggerFactory(), + }) + assert.NoError(t, err) + + defer server.Close() //nolint:errcheck + + conn, err := net.ListenPacket("udp4", "0.0.0.0:0") + assert.NoError(t, err) + + client, err := NewClient(&ClientConfig{ + Conn: conn, + STUNServerAddr: "127.0.0.1:3478", + TURNServerAddr: "127.0.0.1:3478", + Username: "user", + Password: "pass", + Realm: "pion.ly", + LoggerFactory: logging.NewDefaultLoggerFactory(), + }) + assert.NoError(t, err) + assert.NoError(t, client.Listen()) + defer client.Close() + + _, err = client.Allocate() + assert.Equal(t, err.Error(), "Allocate error response (error 486: )") +} + func RunBenchmarkServer(b *testing.B, clientNum int) { loggerFactory := logging.NewDefaultLoggerFactory() credMap := map[string][]byte{