From bdd3fb675745f83d76feea4fde294c0b28cebd39 Mon Sep 17 00:00:00 2001 From: Andy Pan Date: Sat, 9 Nov 2024 10:58:01 +0800 Subject: [PATCH] test: add test for SO_BINDTODEVICE with TCP (#652) --- client_test.go | 2 +- os_unix_test.go | 120 +++++++++++++++++++++++++++++++-------------- os_windows_test.go | 2 +- 3 files changed, 86 insertions(+), 38 deletions(-) diff --git a/client_test.go b/client_test.go index 8793293f6..47a454a99 100644 --- a/client_test.go +++ b/client_test.go @@ -527,7 +527,7 @@ func startGnetClient(t *testing.T, cli *Client, network, addr string, multicore, } if netDial { var netConn net.Conn - netConn, err = NetDial(network, addr) + netConn, err = stdDial(network, addr) require.NoError(t, err) c, err = cli.EnrollContext(netConn, handler) } else { diff --git a/os_unix_test.go b/os_unix_test.go index de428be20..2ff97ed11 100644 --- a/os_unix_test.go +++ b/os_unix_test.go @@ -12,6 +12,7 @@ import ( "net" "regexp" "runtime" + "strings" "sync" "sync/atomic" "testing" @@ -27,7 +28,7 @@ import ( var ( SysClose = unix.Close - NetDial = net.Dial + stdDial = net.Dial ) // NOTE: TestServeMulticast can fail with "write: no buffer space available" on Wi-Fi interface. @@ -244,20 +245,31 @@ func getInterfaceIP(ifname string, ipv4 bool) (net.IP, error) { return nil, errors.New("no valid IP address found") } -type testBindToDeviceServer struct { +type testBindToDeviceServer[T interface{ *net.TCPAddr | *net.UDPAddr }] struct { BuiltinEventEngine tester *testing.T data []byte packets atomic.Int32 expectedPackets int32 network string - loopBackIP net.IP - eth0IP net.IP - broadcastIP net.IP - zone string + loopBackAddr T + eth0Addr T + broadcastAddr T +} + +func netDial[T *net.TCPAddr | *net.UDPAddr](network string, a T) (net.Conn, error) { + addr := any(a) + switch v := addr.(type) { + case *net.TCPAddr: + return net.DialTCP(network, nil, v) + case *net.UDPAddr: + return net.DialUDP(network, nil, v) + default: + return nil, errors.New("unsupported address type") + } } -func (s *testBindToDeviceServer) OnTraffic(c Conn) (action Action) { +func (s *testBindToDeviceServer[T]) OnTraffic(c Conn) (action Action) { b, err := c.Next(-1) assert.NoError(s.tester, err) assert.EqualValues(s.tester, s.data, b) @@ -267,30 +279,34 @@ func (s *testBindToDeviceServer) OnTraffic(c Conn) (action Action) { return } -func (s *testBindToDeviceServer) OnShutdown(_ Engine) { +func (s *testBindToDeviceServer[T]) OnShutdown(_ Engine) { assert.EqualValues(s.tester, s.expectedPackets, s.packets.Load()) } -func (s *testBindToDeviceServer) OnTick() (delay time.Duration, action Action) { +func (s *testBindToDeviceServer[T]) OnTick() (delay time.Duration, action Action) { // Send a packet to the loopback interface, it should never make its way to the server // because we've bound the server to eth0. - lp, err := findLoopbackInterface() - assert.NoError(s.tester, err) - c, err := net.DialUDP(s.network, nil, &net.UDPAddr{IP: s.loopBackIP, Port: 9999, Zone: lp.Name}) - assert.NoError(s.tester, err) - defer c.Close() - _, err = c.Write(s.data) - assert.NoError(s.tester, err) + c, err := netDial(s.network, s.loopBackAddr) + if strings.HasPrefix(s.network, "tcp") { + assert.ErrorContains(s.tester, err, "connection refused") + } else { + assert.NoError(s.tester, err) + defer c.Close() + _, err = c.Write(s.data) + assert.NoError(s.tester, err) + } - // Send a packet to the broadcast address, it should reach the server. - c6, err := net.DialUDP(s.network, nil, &net.UDPAddr{IP: s.broadcastIP, Port: 9999, Zone: s.zone}) - assert.NoError(s.tester, err) - defer c6.Close() - _, err = c6.Write(s.data) - assert.NoError(s.tester, err) + if s.broadcastAddr != nil { + // Send a packet to the broadcast address, it should reach the server. + c6, err := netDial(s.network, s.broadcastAddr) + assert.NoError(s.tester, err) + defer c6.Close() + _, err = c6.Write(s.data) + assert.NoError(s.tester, err) + } // Send a packet to the eth0 interface, it should reach the server. - c4, err := net.DialUDP(s.network, nil, &net.UDPAddr{IP: s.eth0IP, Port: 9999, Zone: s.zone}) + c4, err := netDial(s.network, s.eth0Addr) assert.NoError(s.tester, err) defer c4.Close() _, err = c4.Write(s.data) @@ -305,28 +321,44 @@ func (s *testBindToDeviceServer) OnTick() (delay time.Duration, action Action) { func TestBindToDevice(t *testing.T) { if runtime.GOOS != "linux" { - err := Run(&testBindToDeviceServer{}, "udp://:9999", WithBindToDevice("eth0")) + err := Run(&testBindToDeviceServer[*net.UDPAddr]{}, "tcp://:9999", WithBindToDevice("eth0")) assert.ErrorIs(t, err, errorx.ErrUnsupportedOp) return } + lp, err := findLoopbackInterface() + assert.NoError(t, err) dev, err := detectLinuxEthernetInterfaceName() assert.NoErrorf(t, err, "no testable Ethernet interface found") t.Logf("detected Ethernet interface: %s", dev) data := []byte("hello") t.Run("IPv4", func(t *testing.T) { - t.Run("UDP", func(t *testing.T) { - ip, err := getInterfaceIP(dev, true) + ip, err := getInterfaceIP(dev, true) + assert.NoError(t, err) + t.Run("TCP", func(t *testing.T) { + ts := &testBindToDeviceServer[*net.TCPAddr]{ + tester: t, + data: data, + expectedPackets: 1, + network: "tcp", + loopBackAddr: &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 9999, Zone: ""}, + eth0Addr: &net.TCPAddr{IP: ip, Port: 9999, Zone: ""}, + } + require.NoError(t, err) + err = Run(ts, "tcp://0.0.0.0:9999", + WithTicker(true), + WithBindToDevice(dev)) assert.NoError(t, err) - ts := &testBindToDeviceServer{ + }) + t.Run("UDP", func(t *testing.T) { + ts := &testBindToDeviceServer[*net.UDPAddr]{ tester: t, data: data, expectedPackets: 2, network: "udp", - loopBackIP: net.IPv4(127, 0, 0, 1), - eth0IP: ip, - broadcastIP: net.IPv4bcast, - zone: dev, + loopBackAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 9999, Zone: ""}, + eth0Addr: &net.UDPAddr{IP: ip, Port: 9999, Zone: ""}, + broadcastAddr: &net.UDPAddr{IP: net.IPv4bcast, Port: 9999, Zone: ""}, } require.NoError(t, err) err = Run(ts, "udp://0.0.0.0:9999", @@ -336,18 +368,34 @@ func TestBindToDevice(t *testing.T) { }) }) t.Run("IPv6", func(t *testing.T) { + t.Run("TCP", func(t *testing.T) { + ip, err := getInterfaceIP(dev, false) + assert.NoError(t, err) + ts := &testBindToDeviceServer[*net.TCPAddr]{ + tester: t, + data: data, + expectedPackets: 1, + network: "tcp6", + loopBackAddr: &net.TCPAddr{IP: net.IPv6loopback, Port: 9999, Zone: lp.Name}, + eth0Addr: &net.TCPAddr{IP: ip, Port: 9999, Zone: dev}, + } + require.NoError(t, err) + err = Run(ts, "tcp6://[::]:9999", + WithTicker(true), + WithBindToDevice(dev)) + assert.NoError(t, err) + }) t.Run("UDP", func(t *testing.T) { ip, err := getInterfaceIP(dev, false) assert.NoError(t, err) - ts := &testBindToDeviceServer{ + ts := &testBindToDeviceServer[*net.UDPAddr]{ tester: t, data: data, expectedPackets: 2, network: "udp6", - loopBackIP: net.IPv6loopback, - eth0IP: ip, - broadcastIP: net.IPv6linklocalallnodes, - zone: dev, + loopBackAddr: &net.UDPAddr{IP: net.IPv6loopback, Port: 9999, Zone: lp.Name}, + eth0Addr: &net.UDPAddr{IP: ip, Port: 9999, Zone: dev}, + broadcastAddr: &net.UDPAddr{IP: net.IPv6linklocalallnodes, Port: 9999, Zone: dev}, } require.NoError(t, err) err = Run(ts, "udp6://[::]:9999", diff --git a/os_windows_test.go b/os_windows_test.go index c9e8fa9b0..99ec4eeb3 100644 --- a/os_windows_test.go +++ b/os_windows_test.go @@ -12,7 +12,7 @@ func SysClose(fd int) error { return syscall.CloseHandle(syscall.Handle(fd)) } -func NetDial(network, addr string) (net.Conn, error) { +func stdDial(network, addr string) (net.Conn, error) { if network == "unix" { laddr, _ := net.ResolveUnixAddr(network, unixAddr(addr)) raddr, _ := net.ResolveUnixAddr(network, addr)