Skip to content

Commit

Permalink
add tests for udpmux
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Oct 8, 2023
1 parent 2144049 commit 5579267
Showing 1 changed file with 202 additions and 64 deletions.
266 changes: 202 additions & 64 deletions p2p/transport/webrtc/udpmux/mux_test.go
Original file line number Diff line number Diff line change
@@ -1,89 +1,227 @@
package udpmux

import (
"context"
"fmt"
"net"
"sync"
"testing"
"time"

"github.com/pion/stun"
"github.com/stretchr/testify/require"
)

var _ net.PacketConn = dummyPacketConn{}

type dummyPacketConn struct{}

// Close implements net.PacketConn
func (dummyPacketConn) Close() error {
return nil
}

// LocalAddr implements net.PacketConn
func (dummyPacketConn) LocalAddr() net.Addr {
return nil
}

// ReadFrom implements net.PacketConn
func (dummyPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
return 0, &net.UDPAddr{}, nil
}

// SetDeadline implements net.PacketConn
func (dummyPacketConn) SetDeadline(t time.Time) error {
return nil
func getSTUNBindingRequest(ufrag string) *stun.Message {
msg := stun.New()
msg.SetType(stun.BindingRequest)
uattr := stun.RawAttribute{
Type: stun.AttrUsername,
Value: []byte(fmt.Sprintf("%s:%s", ufrag, ufrag)), // This is the format we expect in our connections
}
uattr.AddTo(msg)
msg.Encode()
return msg
}

// SetReadDeadline implements net.PacketConn
func (dummyPacketConn) SetReadDeadline(t time.Time) error {
return nil
func setupMapping(t *testing.T, ufrag string, from net.PacketConn, m *UDPMux) {
t.Helper()
msg := getSTUNBindingRequest(ufrag)
_, err := from.WriteTo(msg.Raw, m.GetListenAddresses()[0])
require.NoError(t, err)
}

// SetWriteDeadline implements net.PacketConn
func (dummyPacketConn) SetWriteDeadline(t time.Time) error {
return nil
func newPacketConn(t *testing.T) net.PacketConn {
t.Helper()
udpPort0 := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}
c, err := net.ListenUDP("udp", udpPort0)
require.NoError(t, err)
t.Cleanup(func() { c.Close() })
return c
}

// WriteTo implements net.PacketConn
func (dummyPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
return 0, nil
func TestAccept(t *testing.T) {
c := newPacketConn(t)
defer c.Close()
m := NewUDPMux(c)
m.Start()
defer m.Close()

ufrags := []string{"a", "b", "c", "d"}
conns := make([]net.PacketConn, len(ufrags))
for i, ufrag := range ufrags {
conns[i] = newPacketConn(t)
setupMapping(t, ufrag, conns[i], m)
}
for i, ufrag := range ufrags {
c, err := m.Accept(context.Background())
require.NoError(t, err)
require.Equal(t, c.Ufrag, ufrag)
require.Equal(t, c.Addr, conns[i].LocalAddr())
}

for i, ufrag := range ufrags {
// should not be accepted
setupMapping(t, ufrag, conns[i], m)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_, err := m.Accept(ctx)
require.Error(t, err)

// should not be accepted
cc := newPacketConn(t)
setupMapping(t, ufrag, cc, m)
ctx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_, err = m.Accept(ctx)
require.Error(t, err)
}
}

func hasConn(m *UDPMux, ufrag string, isIPv6 bool) bool {
m.mx.Lock()
_, ok := m.ufragMap[ufragConnKey{ufrag: ufrag, isIPv6: isIPv6}]
m.mx.Unlock()
return ok
func TestGetConn(t *testing.T) {
c := newPacketConn(t)
m := NewUDPMux(c)
m.Start()
defer m.Close()

ufrags := []string{"a", "b", "c", "d"}
conns := make([]net.PacketConn, len(ufrags))
for i, ufrag := range ufrags {
conns[i] = newPacketConn(t)
setupMapping(t, ufrag, conns[i], m)
}
for i, ufrag := range ufrags {
c, err := m.Accept(context.Background())
require.NoError(t, err)
require.Equal(t, c.Ufrag, ufrag)
require.Equal(t, c.Addr, conns[i].LocalAddr())
}

for i, ufrag := range ufrags {
c, err := m.GetConn(ufrag, conns[i].LocalAddr())
require.NoError(t, err)
msg := make([]byte, 100)
_, _, err = c.ReadFrom(msg)
require.NoError(t, err)
}

for i, ufrag := range ufrags {
cc := newPacketConn(t)
// setupMapping of cc to ufrags[0] and remove the stun binding request from the queue
setupMapping(t, ufrag, cc, m)
mc, err := m.GetConn(ufrag, cc.LocalAddr())
require.NoError(t, err)
msg := make([]byte, 100)
_, _, err = mc.ReadFrom(msg)
require.NoError(t, err)

// Write from new connection should provide the new address on ReadFrom
_, err = cc.WriteTo([]byte("test1"), c.LocalAddr())
require.NoError(t, err)
n, addr, err := mc.ReadFrom(msg)
require.NoError(t, err)
require.Equal(t, addr, cc.LocalAddr())
require.Equal(t, string(msg[:n]), "test1")

// Write from original connection should provide the original address
_, err = conns[i].WriteTo([]byte("test2"), c.LocalAddr())
require.NoError(t, err)
n, addr, err = mc.ReadFrom(msg)
require.NoError(t, err)
require.Equal(t, addr, conns[i].LocalAddr())
require.Equal(t, string(msg[:n]), "test2")
}
}

var (
addrV4 = net.UDPAddr{IP: net.IPv4zero, Port: 1234}
addrV6 = net.UDPAddr{IP: net.IPv6zero, Port: 1234}
)

func TestUDPMux_GetConn(t *testing.T) {
m := NewUDPMux(dummyPacketConn{})
require.False(t, hasConn(m, "test", false))
conn, err := m.GetConn("test", &addrV4)
func TestRemoveConnByUfrag(t *testing.T) {
c := newPacketConn(t)
m := NewUDPMux(c)
m.Start()
defer m.Close()

// Map each ufrag to two addresses
ufrag := "a"
count := 10
conns := make([]net.PacketConn, count)
for i := 0; i < 10; i++ {
conns[i] = newPacketConn(t)
setupMapping(t, ufrag, conns[i], m)
}
mc, err := m.GetConn(ufrag, conns[0].LocalAddr())
require.NoError(t, err)
require.NotNil(t, conn)

require.False(t, hasConn(m, "test", true))
connv6, err := m.GetConn("test", &addrV6)
require.NoError(t, err)
require.NotNil(t, connv6)

require.NotEqual(t, conn, connv6)
}

func TestUDPMux_RemoveConnectionOnClose(t *testing.T) {
mux := NewUDPMux(dummyPacketConn{})
conn, err := mux.GetConn("test", &addrV4)
for i := 0; i < 10; i++ {
mc1, err := m.GetConn(ufrag, conns[i].LocalAddr())
require.NoError(t, err)
require.Equal(t, mc1, mc)
}

// Now remove the ufrag
m.RemoveConnByUfrag(ufrag)

// All connections should now be associated with b
ufrag = "b"
for i := 0; i < 10; i++ {
setupMapping(t, ufrag, conns[i], m)
}
mc, err = m.GetConn(ufrag, conns[0].LocalAddr())
require.NoError(t, err)
require.NotNil(t, conn)

require.True(t, hasConn(mux, "test", false))

err = conn.Close()
for i := 0; i < 10; i++ {
mc1, err := m.GetConn(ufrag, conns[i].LocalAddr())
require.NoError(t, err)
require.Equal(t, mc1, mc)
}

// Should be different even if the address is the same
mc1, err := m.GetConn("a", conns[0].LocalAddr())
require.NoError(t, err)
require.NotEqual(t, mc1, mc)
}

require.False(t, hasConn(mux, "test", false))
func TestMuxedConnection(t *testing.T) {
c := newPacketConn(t)
m := NewUDPMux(c)
m.Start()
defer m.Close()

msgCount := 3
connCount := 3

ufrags := []string{"a", "b", "c"}
var mu sync.Mutex
addrUfragMap := make(map[string]string)
for _, ufrag := range ufrags {
go func(ufrag string) {
for i := 0; i < connCount; i++ {
cc := newPacketConn(t)
mu.Lock()
addrUfragMap[cc.LocalAddr().String()] = ufrag
mu.Unlock()
setupMapping(t, ufrag, cc, m)
for j := 0; j < msgCount; j++ {
cc.WriteTo([]byte(ufrag), c.LocalAddr())
}
}
}(ufrag)
}

for _, ufrag := range ufrags {
mc, err := m.GetConn(ufrag, c.LocalAddr()) // the address is irrelevant
require.NoError(t, err)
for i := 0; i < connCount; i++ {
msg := make([]byte, 100)
// Read the binding request
_, addr1, err := mc.ReadFrom(msg)
require.NoError(t, err)
require.Equal(t, addrUfragMap[addr1.String()], ufrag)
// Read individual msgs
for i := 0; i < msgCount; i++ {
n, addr2, err := mc.ReadFrom(msg)
require.NoError(t, err)
require.Equal(t, addr2, addr1)
require.Equal(t, ufrag, string(msg[:n]))
}
delete(addrUfragMap, addr1.String())
}
}
require.Equal(t, len(addrUfragMap), 0)
}

0 comments on commit 5579267

Please sign in to comment.