-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
202 additions
and
64 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,89 +1,227 @@ | ||
package udpmux | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"net" | ||
"sync" | ||
"testing" | ||
"time" | ||
|
||
"github.com/pion/stun" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
var _ net.PacketConn = dummyPacketConn{} | ||
|
||
type dummyPacketConn struct{} | ||
|
||
// Close implements net.PacketConn | ||
func (dummyPacketConn) Close() error { | ||
return nil | ||
} | ||
|
||
// LocalAddr implements net.PacketConn | ||
func (dummyPacketConn) LocalAddr() net.Addr { | ||
return nil | ||
} | ||
|
||
// ReadFrom implements net.PacketConn | ||
func (dummyPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { | ||
return 0, &net.UDPAddr{}, nil | ||
} | ||
|
||
// SetDeadline implements net.PacketConn | ||
func (dummyPacketConn) SetDeadline(t time.Time) error { | ||
return nil | ||
func getSTUNBindingRequest(ufrag string) *stun.Message { | ||
msg := stun.New() | ||
msg.SetType(stun.BindingRequest) | ||
uattr := stun.RawAttribute{ | ||
Type: stun.AttrUsername, | ||
Value: []byte(fmt.Sprintf("%s:%s", ufrag, ufrag)), // This is the format we expect in our connections | ||
} | ||
uattr.AddTo(msg) | ||
msg.Encode() | ||
return msg | ||
} | ||
|
||
// SetReadDeadline implements net.PacketConn | ||
func (dummyPacketConn) SetReadDeadline(t time.Time) error { | ||
return nil | ||
func setupMapping(t *testing.T, ufrag string, from net.PacketConn, m *UDPMux) { | ||
t.Helper() | ||
msg := getSTUNBindingRequest(ufrag) | ||
_, err := from.WriteTo(msg.Raw, m.GetListenAddresses()[0]) | ||
require.NoError(t, err) | ||
} | ||
|
||
// SetWriteDeadline implements net.PacketConn | ||
func (dummyPacketConn) SetWriteDeadline(t time.Time) error { | ||
return nil | ||
func newPacketConn(t *testing.T) net.PacketConn { | ||
t.Helper() | ||
udpPort0 := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0} | ||
c, err := net.ListenUDP("udp", udpPort0) | ||
require.NoError(t, err) | ||
t.Cleanup(func() { c.Close() }) | ||
return c | ||
} | ||
|
||
// WriteTo implements net.PacketConn | ||
func (dummyPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { | ||
return 0, nil | ||
func TestAccept(t *testing.T) { | ||
c := newPacketConn(t) | ||
defer c.Close() | ||
m := NewUDPMux(c) | ||
m.Start() | ||
defer m.Close() | ||
|
||
ufrags := []string{"a", "b", "c", "d"} | ||
conns := make([]net.PacketConn, len(ufrags)) | ||
for i, ufrag := range ufrags { | ||
conns[i] = newPacketConn(t) | ||
setupMapping(t, ufrag, conns[i], m) | ||
} | ||
for i, ufrag := range ufrags { | ||
c, err := m.Accept(context.Background()) | ||
require.NoError(t, err) | ||
require.Equal(t, c.Ufrag, ufrag) | ||
require.Equal(t, c.Addr, conns[i].LocalAddr()) | ||
} | ||
|
||
for i, ufrag := range ufrags { | ||
// should not be accepted | ||
setupMapping(t, ufrag, conns[i], m) | ||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) | ||
defer cancel() | ||
_, err := m.Accept(ctx) | ||
require.Error(t, err) | ||
|
||
// should not be accepted | ||
cc := newPacketConn(t) | ||
setupMapping(t, ufrag, cc, m) | ||
ctx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond) | ||
defer cancel() | ||
_, err = m.Accept(ctx) | ||
require.Error(t, err) | ||
} | ||
} | ||
|
||
func hasConn(m *UDPMux, ufrag string, isIPv6 bool) bool { | ||
m.mx.Lock() | ||
_, ok := m.ufragMap[ufragConnKey{ufrag: ufrag, isIPv6: isIPv6}] | ||
m.mx.Unlock() | ||
return ok | ||
func TestGetConn(t *testing.T) { | ||
c := newPacketConn(t) | ||
m := NewUDPMux(c) | ||
m.Start() | ||
defer m.Close() | ||
|
||
ufrags := []string{"a", "b", "c", "d"} | ||
conns := make([]net.PacketConn, len(ufrags)) | ||
for i, ufrag := range ufrags { | ||
conns[i] = newPacketConn(t) | ||
setupMapping(t, ufrag, conns[i], m) | ||
} | ||
for i, ufrag := range ufrags { | ||
c, err := m.Accept(context.Background()) | ||
require.NoError(t, err) | ||
require.Equal(t, c.Ufrag, ufrag) | ||
require.Equal(t, c.Addr, conns[i].LocalAddr()) | ||
} | ||
|
||
for i, ufrag := range ufrags { | ||
c, err := m.GetConn(ufrag, conns[i].LocalAddr()) | ||
require.NoError(t, err) | ||
msg := make([]byte, 100) | ||
_, _, err = c.ReadFrom(msg) | ||
require.NoError(t, err) | ||
} | ||
|
||
for i, ufrag := range ufrags { | ||
cc := newPacketConn(t) | ||
// setupMapping of cc to ufrags[0] and remove the stun binding request from the queue | ||
setupMapping(t, ufrag, cc, m) | ||
mc, err := m.GetConn(ufrag, cc.LocalAddr()) | ||
require.NoError(t, err) | ||
msg := make([]byte, 100) | ||
_, _, err = mc.ReadFrom(msg) | ||
require.NoError(t, err) | ||
|
||
// Write from new connection should provide the new address on ReadFrom | ||
_, err = cc.WriteTo([]byte("test1"), c.LocalAddr()) | ||
require.NoError(t, err) | ||
n, addr, err := mc.ReadFrom(msg) | ||
require.NoError(t, err) | ||
require.Equal(t, addr, cc.LocalAddr()) | ||
require.Equal(t, string(msg[:n]), "test1") | ||
|
||
// Write from original connection should provide the original address | ||
_, err = conns[i].WriteTo([]byte("test2"), c.LocalAddr()) | ||
require.NoError(t, err) | ||
n, addr, err = mc.ReadFrom(msg) | ||
require.NoError(t, err) | ||
require.Equal(t, addr, conns[i].LocalAddr()) | ||
require.Equal(t, string(msg[:n]), "test2") | ||
} | ||
} | ||
|
||
var ( | ||
addrV4 = net.UDPAddr{IP: net.IPv4zero, Port: 1234} | ||
addrV6 = net.UDPAddr{IP: net.IPv6zero, Port: 1234} | ||
) | ||
|
||
func TestUDPMux_GetConn(t *testing.T) { | ||
m := NewUDPMux(dummyPacketConn{}) | ||
require.False(t, hasConn(m, "test", false)) | ||
conn, err := m.GetConn("test", &addrV4) | ||
func TestRemoveConnByUfrag(t *testing.T) { | ||
c := newPacketConn(t) | ||
m := NewUDPMux(c) | ||
m.Start() | ||
defer m.Close() | ||
|
||
// Map each ufrag to two addresses | ||
ufrag := "a" | ||
count := 10 | ||
conns := make([]net.PacketConn, count) | ||
for i := 0; i < 10; i++ { | ||
conns[i] = newPacketConn(t) | ||
setupMapping(t, ufrag, conns[i], m) | ||
} | ||
mc, err := m.GetConn(ufrag, conns[0].LocalAddr()) | ||
require.NoError(t, err) | ||
require.NotNil(t, conn) | ||
|
||
require.False(t, hasConn(m, "test", true)) | ||
connv6, err := m.GetConn("test", &addrV6) | ||
require.NoError(t, err) | ||
require.NotNil(t, connv6) | ||
|
||
require.NotEqual(t, conn, connv6) | ||
} | ||
|
||
func TestUDPMux_RemoveConnectionOnClose(t *testing.T) { | ||
mux := NewUDPMux(dummyPacketConn{}) | ||
conn, err := mux.GetConn("test", &addrV4) | ||
for i := 0; i < 10; i++ { | ||
mc1, err := m.GetConn(ufrag, conns[i].LocalAddr()) | ||
require.NoError(t, err) | ||
require.Equal(t, mc1, mc) | ||
} | ||
|
||
// Now remove the ufrag | ||
m.RemoveConnByUfrag(ufrag) | ||
|
||
// All connections should now be associated with b | ||
ufrag = "b" | ||
for i := 0; i < 10; i++ { | ||
setupMapping(t, ufrag, conns[i], m) | ||
} | ||
mc, err = m.GetConn(ufrag, conns[0].LocalAddr()) | ||
require.NoError(t, err) | ||
require.NotNil(t, conn) | ||
|
||
require.True(t, hasConn(mux, "test", false)) | ||
|
||
err = conn.Close() | ||
for i := 0; i < 10; i++ { | ||
mc1, err := m.GetConn(ufrag, conns[i].LocalAddr()) | ||
require.NoError(t, err) | ||
require.Equal(t, mc1, mc) | ||
} | ||
|
||
// Should be different even if the address is the same | ||
mc1, err := m.GetConn("a", conns[0].LocalAddr()) | ||
require.NoError(t, err) | ||
require.NotEqual(t, mc1, mc) | ||
} | ||
|
||
require.False(t, hasConn(mux, "test", false)) | ||
func TestMuxedConnection(t *testing.T) { | ||
c := newPacketConn(t) | ||
m := NewUDPMux(c) | ||
m.Start() | ||
defer m.Close() | ||
|
||
msgCount := 3 | ||
connCount := 3 | ||
|
||
ufrags := []string{"a", "b", "c"} | ||
var mu sync.Mutex | ||
addrUfragMap := make(map[string]string) | ||
for _, ufrag := range ufrags { | ||
go func(ufrag string) { | ||
for i := 0; i < connCount; i++ { | ||
cc := newPacketConn(t) | ||
mu.Lock() | ||
addrUfragMap[cc.LocalAddr().String()] = ufrag | ||
mu.Unlock() | ||
setupMapping(t, ufrag, cc, m) | ||
for j := 0; j < msgCount; j++ { | ||
cc.WriteTo([]byte(ufrag), c.LocalAddr()) | ||
} | ||
} | ||
}(ufrag) | ||
} | ||
|
||
for _, ufrag := range ufrags { | ||
mc, err := m.GetConn(ufrag, c.LocalAddr()) // the address is irrelevant | ||
require.NoError(t, err) | ||
for i := 0; i < connCount; i++ { | ||
msg := make([]byte, 100) | ||
// Read the binding request | ||
_, addr1, err := mc.ReadFrom(msg) | ||
require.NoError(t, err) | ||
require.Equal(t, addrUfragMap[addr1.String()], ufrag) | ||
// Read individual msgs | ||
for i := 0; i < msgCount; i++ { | ||
n, addr2, err := mc.ReadFrom(msg) | ||
require.NoError(t, err) | ||
require.Equal(t, addr2, addr1) | ||
require.Equal(t, ufrag, string(msg[:n])) | ||
} | ||
delete(addrUfragMap, addr1.String()) | ||
} | ||
} | ||
require.Equal(t, len(addrUfragMap), 0) | ||
} |