diff --git a/vnet/chunk.go b/vnet/chunk.go index d1cb4c8e..f581936b 100644 --- a/vnet/chunk.go +++ b/vnet/chunk.go @@ -212,8 +212,8 @@ type chunkTCP struct { destinationPort int flags tcpFlag // control bits userData []byte // only with PSH flag - // seq uint32 // always starts with 0 - // ack uint32 // always starts with 0 + seqNum uint32 // data sequence (vnet-internal) + ackNum uint32 // ACK for a seqNum (vnet-internal) } func newChunkTCP(srcAddr, dstAddr *net.TCPAddr, flags tcpFlag) *chunkTCP { @@ -256,15 +256,20 @@ func (c *chunkTCP) Clone() Chunk { timestamp: c.timestamp, sourceIP: c.sourceIP, destinationIP: c.destinationIP, + tag: c.tag, + duplicate: c.duplicate, }, sourcePort: c.sourcePort, destinationPort: c.destinationPort, + flags: c.flags, userData: userData, + seqNum: c.seqNum, + ackNum: c.ackNum, } } func (c *chunkTCP) Network() string { - return "tcp" + return tcp } func (c *chunkTCP) String() string { diff --git a/vnet/chunk_test.go b/vnet/chunk_test.go index be22d717..3453a6c6 100644 --- a/vnet/chunk_test.go +++ b/vnet/chunk_test.go @@ -90,7 +90,7 @@ func TestChunk(t *testing.T) { var chunk Chunk = newChunkTCP(src, dst, tcpSYN) str := chunk.String() log.Debugf("chunk: %s", str) - assert.Equal(t, "tcp", chunk.Network(), "should match") + assert.Equal(t, tcp, chunk.Network(), "should match") assert.True(t, strings.Contains(str, src.Network()), "should include network type") assert.True(t, strings.Contains(str, src.String()), "should include address") assert.True(t, strings.Contains(str, dst.String()), "should include address") diff --git a/vnet/errors.go b/vnet/errors.go index 22c7c2d3..363f11b2 100644 --- a/vnet/errors.go +++ b/vnet/errors.go @@ -7,7 +7,7 @@ type timeoutError struct { msg string } -func newTimeoutError(msg string) error { +func newTimeoutError(msg string) error { // nolint:unparam return &timeoutError{ msg: msg, } diff --git a/vnet/nat.go b/vnet/nat.go index f4722af4..57d84002 100644 --- a/vnet/nat.go +++ b/vnet/nat.go @@ -14,12 
+14,12 @@ import ( ) var ( - errNATRequriesMapping = errors.New("1:1 NAT requires more than one mapping") - errMismatchLengthIP = errors.New("length mismtach between mappedIPs and localIPs") - errNonUDPTranslationNotSupported = errors.New("non-udp translation is not supported yet") - errNoAssociatedLocalAddress = errors.New("no associated local address") - errNoNATBindingFound = errors.New("no NAT binding found") - errHasNoPermission = errors.New("has no permission") + errNATRequiresMapping = errors.New("1:1 NAT requires more than one mapping") + errMismatchLengthIP = errors.New("length mismatch between mappedIPs and localIPs") + errTranslationNotSupported = errors.New("translation is not supported for this protocol") + errNoAssociatedLocalAddress = errors.New("no associated local address") + errNoNATBindingFound = errors.New("no NAT binding found") + errHasNoPermission = errors.New("has no permission") ) // EndpointDependencyType defines a type of behavioral dependendency on the @@ -92,6 +92,7 @@ type networkAddressTranslator struct { outboundMap map[string]*mapping // key: "::[:remote-ip[:remote-port]] inboundMap map[string]*mapping // key: "::" udpPortCounter int + tcpPortCounter int mutex sync.RWMutex log logging.LeveledLogger } @@ -107,7 +108,7 @@ func newNAT(config *natConfig) (*networkAddressTranslator, error) { natType.MappingLifeTime = 0 if len(config.mappedIPs) == 0 { - return nil, errNATRequriesMapping + return nil, errNATRequiresMapping } if len(config.mappedIPs) != len(config.localIPs) { return nil, errMismatchLengthIP @@ -151,13 +152,70 @@ func (n *networkAddressTranslator) getPairedLocalIP(mappedIP net.IP) net.IP { return nil } -func (n *networkAddressTranslator) translateOutbound(from Chunk) (Chunk, error) { //nolint:cyclop +func (n *networkAddressTranslator) translateOutbound(from Chunk) (Chunk, error) { //nolint:cyclop,gocognit n.mutex.Lock() defer n.mutex.Unlock() to := from.Clone() - if from.Network() == udp { //nolint:nestif + 
translateOutboundNAPT := func(proto string, portBase int, portCounter *int) (Chunk, error) { + var bound, filterKey string + switch n.natType.MappingBehavior { + case EndpointIndependent: + bound = "" + case EndpointAddrDependent: + bound = from.getDestinationIP().String() + case EndpointAddrPortDependent: + bound = from.DestinationAddr().String() + } + + switch n.natType.FilteringBehavior { + case EndpointIndependent: + filterKey = "" + case EndpointAddrDependent: + filterKey = from.getDestinationIP().String() + case EndpointAddrPortDependent: + filterKey = from.DestinationAddr().String() + } + + oKey := fmt.Sprintf("%s:%s:%s", proto, from.SourceAddr().String(), bound) + + mapp := n.findOutboundMapping(oKey) + if mapp == nil { + mappedPort := portBase + *portCounter + (*portCounter)++ + + mapp = &mapping{ + proto: from.SourceAddr().Network(), + local: from.SourceAddr().String(), + bound: bound, + mapped: fmt.Sprintf("%s:%d", n.mappedIPs[0].String(), mappedPort), + filters: map[string]struct{}{}, + expires: time.Now().Add(n.natType.MappingLifeTime), + } + + n.outboundMap[oKey] = mapp + iKey := fmt.Sprintf("%s:%s", proto, mapp.mapped) + + n.log.Debugf("[%s] created a new NAT binding oKey=%s iKey=%s", n.name, oKey, iKey) + + mapp.filters[filterKey] = struct{}{} + n.log.Debugf("[%s] permit access from %s to %s", n.name, filterKey, mapp.mapped) + n.inboundMap[iKey] = mapp + } else if _, ok := mapp.filters[filterKey]; !ok { + n.log.Debugf("[%s] permit access from %s to %s", n.name, filterKey, mapp.mapped) + mapp.filters[filterKey] = struct{}{} + } + + if err := to.setSourceAddr(mapp.mapped); err != nil { + return nil, err + } + + return to, nil + } + + switch from.Network() { + case udp: if n.natType.Mode == NATModeNAT1To1 { // 1:1 NAT behavior srcAddr := from.SourceAddr().(*net.UDPAddr) //nolint:forcetypeassert @@ -172,61 +230,34 @@ func (n *networkAddressTranslator) translateOutbound(from Chunk) (Chunk, error) return nil, err } } else { - // Normal (NAPT) behavior - 
var bound, filterKey string - switch n.natType.MappingBehavior { - case EndpointIndependent: - bound = "" - case EndpointAddrDependent: - bound = from.getDestinationIP().String() - case EndpointAddrPortDependent: - bound = from.DestinationAddr().String() + var err error + to, err = translateOutboundNAPT("udp", 0xC000, &n.udpPortCounter) + if err != nil { + return nil, err } + } - switch n.natType.FilteringBehavior { - case EndpointIndependent: - filterKey = "" - case EndpointAddrDependent: - filterKey = from.getDestinationIP().String() - case EndpointAddrPortDependent: - filterKey = from.DestinationAddr().String() - } + n.log.Debugf("[%s] translate outbound chunk from %s to %s", n.name, from.String(), to.String()) - oKey := fmt.Sprintf("udp:%s:%s", from.SourceAddr().String(), bound) - - mapp := n.findOutboundMapping(oKey) - if mapp == nil { - // Create a new mapping - mappedPort := 0xC000 + n.udpPortCounter - n.udpPortCounter++ - - mapp = &mapping{ - proto: from.SourceAddr().Network(), - local: from.SourceAddr().String(), - bound: bound, - mapped: fmt.Sprintf("%s:%d", n.mappedIPs[0].String(), mappedPort), - filters: map[string]struct{}{}, - expires: time.Now().Add(n.natType.MappingLifeTime), - } - - n.outboundMap[oKey] = mapp - - iKey := fmt.Sprintf("udp:%s", mapp.mapped) - - n.log.Debugf("[%s] created a new NAT binding oKey=%s iKey=%s", - n.name, - oKey, - iKey) - - mapp.filters[filterKey] = struct{}{} - n.log.Debugf("[%s] permit access from %s to %s", n.name, filterKey, mapp.mapped) - n.inboundMap[iKey] = mapp - } else if _, ok := mapp.filters[filterKey]; !ok { - n.log.Debugf("[%s] permit access from %s to %s", n.name, filterKey, mapp.mapped) - mapp.filters[filterKey] = struct{}{} - } + return to, nil - if err := to.setSourceAddr(mapp.mapped); err != nil { + case tcp: + if n.natType.Mode == NATModeNAT1To1 { + srcAddr := from.SourceAddr().(*net.TCPAddr) //nolint:forcetypeassert + srcIP := n.getPairedMappedIP(srcAddr.IP) + if srcIP == nil { + n.log.Debugf("[%s] 
drop outbound chunk %s with not route", n.name, from.String()) + + return nil, nil // nolint:nilnil + } + srcPort := srcAddr.Port + if err := to.setSourceAddr(fmt.Sprintf("%s:%d", srcIP.String(), srcPort)); err != nil { + return nil, err + } + } else { + var err error + to, err = translateOutboundNAPT("tcp", 0x8000, &n.tcpPortCounter) + if err != nil { return nil, err } } @@ -234,60 +265,99 @@ func (n *networkAddressTranslator) translateOutbound(from Chunk) (Chunk, error) n.log.Debugf("[%s] translate outbound chunk from %s to %s", n.name, from.String(), to.String()) return to, nil - } - return nil, errNonUDPTranslationNotSupported + default: + return nil, errTranslationNotSupported + } } -func (n *networkAddressTranslator) translateInbound(from Chunk) (Chunk, error) { //nolint:cyclop +func (n *networkAddressTranslator) translateInbound(from Chunk) (Chunk, error) { //nolint:cyclop,gocognit n.mutex.Lock() defer n.mutex.Unlock() to := from.Clone() - if from.Network() == udp { //nolint:nestif + translateInboundNAT1To1 := func(dstPort int) (Chunk, error) { + dstIP := n.getPairedLocalIP(from.getDestinationIP()) + if dstIP == nil { + return nil, fmt.Errorf("drop %s as %w", from.String(), errNoAssociatedLocalAddress) + } + if err := to.setDestinationAddr(fmt.Sprintf("%s:%d", dstIP, dstPort)); err != nil { + return nil, err + } + + return to, nil + } + + translateInboundNAPT := func(proto string) (Chunk, error) { + iKey := fmt.Sprintf("%s:%s", proto, from.DestinationAddr().String()) + mapp := n.findInboundMapping(iKey) + if mapp == nil { + return nil, fmt.Errorf("drop %s as %w", from.String(), errNoNATBindingFound) + } + + var filterKey string + switch n.natType.FilteringBehavior { + case EndpointIndependent: + filterKey = "" + case EndpointAddrDependent: + filterKey = from.getSourceIP().String() + case EndpointAddrPortDependent: + filterKey = from.SourceAddr().String() + } + + if _, ok := mapp.filters[filterKey]; !ok { + return nil, fmt.Errorf("drop %s as the remote %s 
%w", from.String(), filterKey, errHasNoPermission) + } + + // See RFC 4787 Section 4.3. Mapping Refresh + // a) Inbound refresh may be useful for applications with no outgoing + // UDP traffic. However, allowing inbound refresh may allow an + // external attacker or misbehaving application to keep a mapping + // alive indefinitely. This may be a security risk. Also, if the + // process is repeated with different ports, over time, it could + // use up all the ports on the NAT. + + if err := to.setDestinationAddr(mapp.local); err != nil { + return nil, err + } + + return to, nil + } + + switch from.Network() { + case udp: if n.natType.Mode == NATModeNAT1To1 { - // 1:1 NAT behavior dstAddr := from.DestinationAddr().(*net.UDPAddr) //nolint:forcetypeassert - dstIP := n.getPairedLocalIP(dstAddr.IP) - if dstIP == nil { - return nil, fmt.Errorf("drop %s as %w", from.String(), errNoAssociatedLocalAddress) - } - dstPort := from.DestinationAddr().(*net.UDPAddr).Port //nolint:forcetypeassert - if err := to.setDestinationAddr(fmt.Sprintf("%s:%d", dstIP, dstPort)); err != nil { + var err error + to, err = translateInboundNAT1To1(dstAddr.Port) + if err != nil { return nil, err } } else { - // Normal (NAPT) behavior - iKey := fmt.Sprintf("udp:%s", from.DestinationAddr().String()) - mapping := n.findInboundMapping(iKey) - if mapping == nil { - return nil, fmt.Errorf("drop %s as %w", from.String(), errNoNATBindingFound) - } - - var filterKey string - switch n.natType.FilteringBehavior { - case EndpointIndependent: - filterKey = "" - case EndpointAddrDependent: - filterKey = from.getSourceIP().String() - case EndpointAddrPortDependent: - filterKey = from.SourceAddr().String() + var err error + to, err = translateInboundNAPT(udp) + if err != nil { + return nil, err } + } - if _, ok := mapping.filters[filterKey]; !ok { - return nil, fmt.Errorf("drop %s as the remote %s %w", from.String(), filterKey, errHasNoPermission) - } + n.log.Debugf("[%s] translate inbound chunk from %s to %s", 
n.name, from.String(), to.String()) - // See RFC 4847 Section 4.3. Mapping Refresh - // a) Inbound refresh may be useful for applications with no outgoing - // UDP traffic. However, allowing inbound refresh may allow an - // external attacker or misbehaving application to keep a mapping - // alive indefinitely. This may be a security risk. Also, if the - // process is repeated with different ports, over time, it could - // use up all the ports on the NAT. + return to, nil - if err := to.setDestinationAddr(mapping.local); err != nil { + case tcp: + if n.natType.Mode == NATModeNAT1To1 { + dstAddr := from.DestinationAddr().(*net.TCPAddr) //nolint:forcetypeassert + var err error + to, err = translateInboundNAT1To1(dstAddr.Port) + if err != nil { + return nil, err + } + } else { + var err error + to, err = translateInboundNAPT(tcp) + if err != nil { return nil, err } } @@ -295,9 +365,10 @@ func (n *networkAddressTranslator) translateInbound(from Chunk) (Chunk, error) { n.log.Debugf("[%s] translate inbound chunk from %s to %s", n.name, from.String(), to.String()) return to, nil - } - return nil, errNonUDPTranslationNotSupported + default: + return nil, errTranslationNotSupported + } } // caller must hold the mutex. 
diff --git a/vnet/nat_test.go b/vnet/nat_test.go index 9a0b06b6..3e12bcdb 100644 --- a/vnet/nat_test.go +++ b/vnet/nat_test.go @@ -19,6 +19,70 @@ import ( const demoIP = "1.2.3.4" +type natTestProto struct { + name string + newAddr func(ip net.IP, port int) net.Addr + newChunk func(t *testing.T, src, dst net.Addr) Chunk +} + +func natTestProtos() []natTestProto { + return []natTestProto{ + { + name: "udp", + newAddr: func(ip net.IP, port int) net.Addr { + return &net.UDPAddr{IP: ip, Port: port} + }, + newChunk: func(t *testing.T, src, dst net.Addr) Chunk { + t.Helper() + srcAddr, ok := src.(*net.UDPAddr) + if !ok { + assert.FailNow(t, "expected *net.UDPAddr src, got %T", src) + } + dstAddr, ok := dst.(*net.UDPAddr) + if !ok { + assert.FailNow(t, "expected *net.UDPAddr dst, got %T", dst) + } + + return newChunkUDP(srcAddr, dstAddr) + }, + }, + { + name: "tcp", + newAddr: func(ip net.IP, port int) net.Addr { + return &net.TCPAddr{IP: ip, Port: port} + }, + newChunk: func(t *testing.T, src, dst net.Addr) Chunk { + t.Helper() + srcAddr, ok := src.(*net.TCPAddr) + if !ok { + assert.FailNow(t, "expected *net.TCPAddr src, got %T", src) + } + dstAddr, ok := dst.(*net.TCPAddr) + if !ok { + assert.FailNow(t, "expected *net.TCPAddr dst, got %T", dst) + } + + return newChunkTCP(srcAddr, dstAddr, tcpACK) + }, + }, + } +} + +func natAddrIPPort(t *testing.T, addr net.Addr) (net.IP, int) { + t.Helper() + + switch a := addr.(type) { + case *net.UDPAddr: + return a.IP, a.Port + case *net.TCPAddr: + return a.IP, a.Port + default: + assert.FailNow(t, "unexpected addr type %T", addr) + + return nil, 0 + } +} + func TestNATTypeDefaults(t *testing.T) { loggerFactory := logging.NewDefaultLoggerFactory() nat, err := newNAT(&natConfig{ @@ -39,502 +103,299 @@ func TestNATMappingBehavior(t *testing.T) { //nolint:maintidx loggerFactory := logging.NewDefaultLoggerFactory() log := loggerFactory.NewLogger("test") - t.Run("full-cone NAT", func(t *testing.T) { - nat, err := newNAT(&natConfig{ - 
natType: NATType{ - MappingBehavior: EndpointIndependent, - FilteringBehavior: EndpointIndependent, - Hairpinning: false, - MappingLifeTime: 30 * time.Second, - }, - mappedIPs: []net.IP{net.ParseIP(demoIP)}, - loggerFactory: loggerFactory, - }) - assert.NoError(t, err, "should succeed") - - src := &net.UDPAddr{ - IP: net.ParseIP("192.168.0.2"), - Port: 1234, - } - dst := &net.UDPAddr{ - IP: net.ParseIP("5.6.7.8"), - Port: 5678, - } - - oic := newChunkUDP(src, dst) - - oec, err := nat.translateOutbound(oic) - assert.Nil(t, err, "should succeed") - assert.Equal(t, 1, len(nat.outboundMap), "should match") - assert.Equal(t, 1, len(nat.inboundMap), "should match") - - log.Debugf("o-original : %s", oic.String()) - log.Debugf("o-translated: %s", oec.String()) - - //nolint:forcetypeassert - iec := newChunkUDP( - &net.UDPAddr{ - IP: dst.IP, - Port: dst.Port, - }, - &net.UDPAddr{ - IP: oec.(*chunkUDP).sourceIP, - Port: oec.(*chunkUDP).sourcePort, - }, - ) - - log.Debugf("i-original : %s", iec.String()) - - iic, err := nat.translateInbound(iec) - assert.Nil(t, err, "should succeed") - - log.Debugf("i-translated: %s", iic.String()) - - //nolint:forcetypeassert - assert.Equal(t, - oic.SourceAddr().String(), - iic.(*chunkUDP).DestinationAddr().String(), - "should match") - - // packet with dest addr that does not exist in the mapping table - // will be dropped - //nolint:forcetypeassert - iec = newChunkUDP( - &net.UDPAddr{ - IP: dst.IP, - Port: dst.Port, - }, - &net.UDPAddr{ - IP: oec.(*chunkUDP).sourceIP, - Port: oec.(*chunkUDP).sourcePort + 1, - }, - ) - - _, err = nat.translateInbound(iec) - log.Debug(err.Error()) - assert.NotNil(t, err, "should fail (dropped)") - - // packet from any addr will be accepted (full-cone) - //nolint:forcetypeassert - iec = newChunkUDP( - &net.UDPAddr{ - IP: dst.IP, - Port: 7777, - }, - &net.UDPAddr{ - IP: oec.(*chunkUDP).sourceIP, - Port: oec.(*chunkUDP).sourcePort, - }, - ) - - _, err = nat.translateInbound(iec) - assert.Nil(t, err, "should 
succeed") - }) - - t.Run("addr-restricted-cone NAT", func(t *testing.T) { - nat, err := newNAT(&natConfig{ - natType: NATType{ - MappingBehavior: EndpointIndependent, - FilteringBehavior: EndpointAddrDependent, - Hairpinning: false, - MappingLifeTime: 30 * time.Second, - }, - mappedIPs: []net.IP{net.ParseIP(demoIP)}, - loggerFactory: loggerFactory, - }) - assert.NoError(t, err, "should succeed") - - src := &net.UDPAddr{ - IP: net.ParseIP("192.168.0.2"), - Port: 1234, - } - dst := &net.UDPAddr{ - IP: net.ParseIP("5.6.7.8"), - Port: 5678, - } - - oic := newChunkUDP(src, dst) - log.Debugf("o-original : %s", oic.String()) - - oec, err := nat.translateOutbound(oic) - assert.Nil(t, err, "should succeed") - assert.Equal(t, 1, len(nat.outboundMap), "should match") - assert.Equal(t, 1, len(nat.inboundMap), "should match") - log.Debugf("o-translated: %s", oec.String()) - - // sending different (IP: 5.6.7.9) won't create a new mapping - oic2 := newChunkUDP(&net.UDPAddr{ - IP: net.ParseIP("192.168.0.2"), - Port: 1234, - }, &net.UDPAddr{ - IP: net.ParseIP("5.6.7.9"), - Port: 9000, - }) - oec2, err := nat.translateOutbound(oic2) - assert.Nil(t, err, "should succeed") - assert.Equal(t, 1, len(nat.outboundMap), "should match") - assert.Equal(t, 1, len(nat.inboundMap), "should match") - log.Debugf("o-translated: %s", oec2.String()) - - //nolint:forcetypeassert - iec := newChunkUDP( - &net.UDPAddr{ - IP: dst.IP, - Port: dst.Port, - }, - &net.UDPAddr{ - IP: oec.(*chunkUDP).sourceIP, - Port: oec.(*chunkUDP).sourcePort, - }, - ) - - log.Debugf("i-original : %s", iec.String()) - - iic, err := nat.translateInbound(iec) - if !assert.NoError(t, err, "should succeed") { - return - } - - log.Debugf("i-translated: %s", iic.String()) - - //nolint:forcetypeassert - assert.Equal(t, - oic.SourceAddr().String(), - iic.(*chunkUDP).DestinationAddr().String(), - "should match") - - // packet with dest addr that does not exist in the mapping table - // will be dropped - //nolint:forcetypeassert - iec 
= newChunkUDP( - &net.UDPAddr{ - IP: dst.IP, - Port: dst.Port, - }, - &net.UDPAddr{ - IP: oec.(*chunkUDP).sourceIP, - Port: oec.(*chunkUDP).sourcePort + 1, - }, - ) - - _, err = nat.translateInbound(iec) - log.Debug(err.Error()) - assert.NotNil(t, err, "should fail (dropped)") - - // packet from any port will be accepted (restricted-cone) - //nolint:forcetypeassert - iec = newChunkUDP( - &net.UDPAddr{ - IP: dst.IP, - Port: 7777, - }, - &net.UDPAddr{ - IP: oec.(*chunkUDP).sourceIP, - Port: oec.(*chunkUDP).sourcePort, - }, - ) - - _, err = nat.translateInbound(iec) - assert.Nil(t, err, "should succeed") - - // packet from different addr will be dropped (restricted-cone) - //nolint:forcetypeassert - iec = newChunkUDP( - &net.UDPAddr{ - IP: net.ParseIP("6.6.6.6"), - Port: dst.Port, - }, - &net.UDPAddr{ - IP: oec.(*chunkUDP).sourceIP, - Port: oec.(*chunkUDP).sourcePort, - }, - ) - - _, err = nat.translateInbound(iec) - log.Debug(err.Error()) - assert.NotNil(t, err, "should fail (dropped)") - }) - - t.Run("port-restricted-cone NAT", func(t *testing.T) { - nat, err := newNAT(&natConfig{ - natType: NATType{ - MappingBehavior: EndpointIndependent, - FilteringBehavior: EndpointAddrPortDependent, - Hairpinning: false, - MappingLifeTime: 30 * time.Second, - }, - mappedIPs: []net.IP{net.ParseIP(demoIP)}, - loggerFactory: loggerFactory, - }) - assert.NoError(t, err, "should succeed") - - src := &net.UDPAddr{ - IP: net.ParseIP("192.168.0.2"), - Port: 1234, - } - dst := &net.UDPAddr{ - IP: net.ParseIP("5.6.7.8"), - Port: 5678, - } - - oic := newChunkUDP(src, dst) - log.Debugf("o-original : %s", oic.String()) - - oec, err := nat.translateOutbound(oic) - assert.Nil(t, err, "should succeed") - assert.Equal(t, 1, len(nat.outboundMap), "should match") - assert.Equal(t, 1, len(nat.inboundMap), "should match") - - log.Debugf("o-translated: %s", oec.String()) - - // sending different (IP: 5.6.7.9) won't create a new mapping - oic2 := newChunkUDP(&net.UDPAddr{ - IP: 
net.ParseIP("192.168.0.2"), - Port: 1234, - }, &net.UDPAddr{ - IP: net.ParseIP("5.6.7.9"), - Port: 9000, - }) - oec2, err := nat.translateOutbound(oic2) - assert.Nil(t, err, "should succeed") - assert.Equal(t, 1, len(nat.outboundMap), "should match") - assert.Equal(t, 1, len(nat.inboundMap), "should match") - log.Debugf("o-translated: %s", oec2.String()) - - //nolint:forcetypeassert - iec := newChunkUDP( - &net.UDPAddr{ - IP: dst.IP, - Port: dst.Port, - }, - &net.UDPAddr{ - IP: oec.(*chunkUDP).sourceIP, - Port: oec.(*chunkUDP).sourcePort, - }, - ) - - log.Debugf("i-original : %s", iec.String()) - - iic, err := nat.translateInbound(iec) - assert.Nil(t, err, "should succeed") - - log.Debugf("i-translated: %s", iic.String()) - - //nolint:forcetypeassert - assert.Equal(t, - oic.SourceAddr().String(), - iic.(*chunkUDP).DestinationAddr().String(), - "should match") - - // packet with dest addr that does not exist in the mapping table - // will be dropped - //nolint:forcetypeassert - iec = newChunkUDP( - &net.UDPAddr{ - IP: dst.IP, - Port: dst.Port, - }, - &net.UDPAddr{ - IP: oec.(*chunkUDP).sourceIP, - Port: oec.(*chunkUDP).sourcePort + 1, - }, - ) - - _, err = nat.translateInbound(iec) - assert.NotNil(t, err, "should fail (dropped)") - - // packet from different port will be dropped (port-restricted-cone) - //nolint:forcetypeassert - iec = newChunkUDP( - &net.UDPAddr{ - IP: dst.IP, - Port: 7777, - }, - &net.UDPAddr{ - IP: oec.(*chunkUDP).sourceIP, - Port: oec.(*chunkUDP).sourcePort, - }, - ) - - _, err = nat.translateInbound(iec) - assert.NotNil(t, err, "should fail (dropped)") - - // packet from different addr will be dropped (restricted-cone) - //nolint:forcetypeassert - iec = newChunkUDP( - &net.UDPAddr{ - IP: net.ParseIP("6.6.6.6"), - Port: dst.Port, - }, - &net.UDPAddr{ - IP: oec.(*chunkUDP).sourceIP, - Port: oec.(*chunkUDP).sourcePort, - }, - ) - - _, err = nat.translateInbound(iec) - assert.NotNil(t, err, "should fail (dropped)") - }) - - t.Run("symmetric NAT 
addr dependent mapping", func(t *testing.T) { //nolint:dupl - nat, err := newNAT(&natConfig{ - natType: NATType{ - MappingBehavior: EndpointAddrDependent, - FilteringBehavior: EndpointAddrDependent, - Hairpinning: false, - MappingLifeTime: 30 * time.Second, - }, - mappedIPs: []net.IP{net.ParseIP(demoIP)}, - loggerFactory: loggerFactory, - }) - assert.NoError(t, err, "should succeed") + for _, proto := range natTestProtos() { + proto := proto + t.Run(proto.name, func(t *testing.T) { + t.Run("full-cone NAT", func(t *testing.T) { + nat, err := newNAT(&natConfig{ + natType: NATType{ + MappingBehavior: EndpointIndependent, + FilteringBehavior: EndpointIndependent, + Hairpinning: false, + MappingLifeTime: 30 * time.Second, + }, + mappedIPs: []net.IP{net.ParseIP(demoIP)}, + loggerFactory: loggerFactory, + }) + assert.NoError(t, err, "should succeed") + + srcIP := net.ParseIP("192.168.0.2") + srcPort := 1234 + dstIP := net.ParseIP("5.6.7.8") + dstPort := 5678 + oic := proto.newChunk(t, proto.newAddr(srcIP, srcPort), proto.newAddr(dstIP, dstPort)) + + oec, err := nat.translateOutbound(oic) + assert.Nil(t, err, "should succeed") + assert.Equal(t, 1, len(nat.outboundMap), "should match") + assert.Equal(t, 1, len(nat.inboundMap), "should match") + + log.Debugf("o-original : %s", oic.String()) + log.Debugf("o-translated: %s", oec.String()) + + oecIP, oecPort := natAddrIPPort(t, oec.SourceAddr()) + iec := proto.newChunk(t, proto.newAddr(dstIP, dstPort), proto.newAddr(oecIP, oecPort)) + + log.Debugf("i-original : %s", iec.String()) + + iic, err := nat.translateInbound(iec) + assert.Nil(t, err, "should succeed") + + log.Debugf("i-translated: %s", iic.String()) + + assert.Equal(t, oic.SourceAddr().String(), iic.DestinationAddr().String(), "should match") + + // packet with dest addr that does not exist in the mapping table + // will be dropped + iec = proto.newChunk(t, proto.newAddr(dstIP, dstPort), proto.newAddr(oecIP, oecPort+1)) + + _, err = nat.translateInbound(iec) + 
log.Debug(err.Error()) + assert.NotNil(t, err, "should fail (dropped)") + + // packet from any addr will be accepted (full-cone) + iec = proto.newChunk(t, proto.newAddr(dstIP, 7777), proto.newAddr(oecIP, oecPort)) + + _, err = nat.translateInbound(iec) + assert.Nil(t, err, "should succeed") + }) - oic1 := newChunkUDP( - &net.UDPAddr{ - IP: net.ParseIP("192.168.0.2"), - Port: 1234, - }, - &net.UDPAddr{ - IP: net.ParseIP("5.6.7.8"), - Port: 5678, - }, - ) + t.Run("addr-restricted-cone NAT", func(t *testing.T) { + nat, err := newNAT(&natConfig{ + natType: NATType{ + MappingBehavior: EndpointIndependent, + FilteringBehavior: EndpointAddrDependent, + Hairpinning: false, + MappingLifeTime: 30 * time.Second, + }, + mappedIPs: []net.IP{net.ParseIP(demoIP)}, + loggerFactory: loggerFactory, + }) + assert.NoError(t, err, "should succeed") + + srcIP := net.ParseIP("192.168.0.2") + srcPort := 1234 + dstIP := net.ParseIP("5.6.7.8") + dstPort := 5678 + oic := proto.newChunk(t, proto.newAddr(srcIP, srcPort), proto.newAddr(dstIP, dstPort)) + log.Debugf("o-original : %s", oic.String()) + + oec, err := nat.translateOutbound(oic) + assert.Nil(t, err, "should succeed") + assert.Equal(t, 1, len(nat.outboundMap), "should match") + assert.Equal(t, 1, len(nat.inboundMap), "should match") + log.Debugf("o-translated: %s", oec.String()) + + // sending different (IP: 5.6.7.9) won't create a new mapping + oic2 := proto.newChunk(t, + proto.newAddr(srcIP, srcPort), + proto.newAddr(net.ParseIP("5.6.7.9"), 9000), + ) + oec2, err := nat.translateOutbound(oic2) + assert.Nil(t, err, "should succeed") + assert.Equal(t, 1, len(nat.outboundMap), "should match") + assert.Equal(t, 1, len(nat.inboundMap), "should match") + log.Debugf("o-translated: %s", oec2.String()) + + oecIP, oecPort := natAddrIPPort(t, oec.SourceAddr()) + iec := proto.newChunk(t, proto.newAddr(dstIP, dstPort), proto.newAddr(oecIP, oecPort)) + + log.Debugf("i-original : %s", iec.String()) + + iic, err := nat.translateInbound(iec) + if 
!assert.NoError(t, err, "should succeed") { + return + } + + log.Debugf("i-translated: %s", iic.String()) + + assert.Equal(t, oic.SourceAddr().String(), iic.DestinationAddr().String(), "should match") + + // packet with dest addr that does not exist in the mapping table + // will be dropped + iec = proto.newChunk(t, proto.newAddr(dstIP, dstPort), proto.newAddr(oecIP, oecPort+1)) + + _, err = nat.translateInbound(iec) + log.Debug(err.Error()) + assert.NotNil(t, err, "should fail (dropped)") + + // packet from any port will be accepted (restricted-cone) + iec = proto.newChunk(t, proto.newAddr(dstIP, 7777), proto.newAddr(oecIP, oecPort)) + + _, err = nat.translateInbound(iec) + assert.Nil(t, err, "should succeed") + + // packet from different addr will be dropped (restricted-cone) + iec = proto.newChunk(t, proto.newAddr(net.ParseIP("6.6.6.6"), dstPort), proto.newAddr(oecIP, oecPort)) + + _, err = nat.translateInbound(iec) + log.Debug(err.Error()) + assert.NotNil(t, err, "should fail (dropped)") + }) - oic2 := newChunkUDP( - &net.UDPAddr{ - IP: net.ParseIP("192.168.0.2"), - Port: 1234, - }, - &net.UDPAddr{ - IP: net.ParseIP("5.6.7.100"), - Port: 5678, - }, - ) + t.Run("port-restricted-cone NAT", func(t *testing.T) { + nat, err := newNAT(&natConfig{ + natType: NATType{ + MappingBehavior: EndpointIndependent, + FilteringBehavior: EndpointAddrPortDependent, + Hairpinning: false, + MappingLifeTime: 30 * time.Second, + }, + mappedIPs: []net.IP{net.ParseIP(demoIP)}, + loggerFactory: loggerFactory, + }) + assert.NoError(t, err, "should succeed") + + srcIP := net.ParseIP("192.168.0.2") + srcPort := 1234 + dstIP := net.ParseIP("5.6.7.8") + dstPort := 5678 + oic := proto.newChunk(t, proto.newAddr(srcIP, srcPort), proto.newAddr(dstIP, dstPort)) + log.Debugf("o-original : %s", oic.String()) + + oec, err := nat.translateOutbound(oic) + assert.Nil(t, err, "should succeed") + assert.Equal(t, 1, len(nat.outboundMap), "should match") + assert.Equal(t, 1, len(nat.inboundMap), "should 
match") + + log.Debugf("o-translated: %s", oec.String()) + + // sending different (IP: 5.6.7.9) won't create a new mapping + oic2 := proto.newChunk(t, + proto.newAddr(srcIP, srcPort), + proto.newAddr(net.ParseIP("5.6.7.9"), 9000), + ) + oec2, err := nat.translateOutbound(oic2) + assert.Nil(t, err, "should succeed") + assert.Equal(t, 1, len(nat.outboundMap), "should match") + assert.Equal(t, 1, len(nat.inboundMap), "should match") + log.Debugf("o-translated: %s", oec2.String()) + + oecIP, oecPort := natAddrIPPort(t, oec.SourceAddr()) + iec := proto.newChunk(t, proto.newAddr(dstIP, dstPort), proto.newAddr(oecIP, oecPort)) + + log.Debugf("i-original : %s", iec.String()) + + iic, err := nat.translateInbound(iec) + assert.Nil(t, err, "should succeed") + + log.Debugf("i-translated: %s", iic.String()) + + assert.Equal(t, oic.SourceAddr().String(), iic.DestinationAddr().String(), "should match") + + // packet with dest addr that does not exist in the mapping table + // will be dropped + iec = proto.newChunk(t, proto.newAddr(dstIP, dstPort), proto.newAddr(oecIP, oecPort+1)) + + _, err = nat.translateInbound(iec) + assert.NotNil(t, err, "should fail (dropped)") + + // packet from different port will be dropped (port-restricted-cone) + iec = proto.newChunk(t, proto.newAddr(dstIP, 7777), proto.newAddr(oecIP, oecPort)) + + _, err = nat.translateInbound(iec) + assert.NotNil(t, err, "should fail (dropped)") + + // packet from different addr will be dropped (restricted-cone) + iec = proto.newChunk(t, proto.newAddr(net.ParseIP("6.6.6.6"), dstPort), proto.newAddr(oecIP, oecPort)) + + _, err = nat.translateInbound(iec) + assert.NotNil(t, err, "should fail (dropped)") + }) - oic3 := newChunkUDP( - &net.UDPAddr{ - IP: net.ParseIP("192.168.0.2"), - Port: 1234, - }, - &net.UDPAddr{ - IP: net.ParseIP("5.6.7.8"), - Port: 6000, - }, - ) - - log.Debugf("o-original : %s", oic1.String()) - log.Debugf("o-original : %s", oic2.String()) - log.Debugf("o-original : %s", oic3.String()) - - oec1, err 
:= nat.translateOutbound(oic1) - assert.Nil(t, err, "should succeed") - - oec2, err := nat.translateOutbound(oic2) - assert.Nil(t, err, "should succeed") - - oec3, err := nat.translateOutbound(oic3) - assert.Nil(t, err, "should succeed") - - assert.Equal(t, 2, len(nat.outboundMap), "should match") - assert.Equal(t, 2, len(nat.inboundMap), "should match") - - log.Debugf("o-translated: %s", oec1.String()) - log.Debugf("o-translated: %s", oec2.String()) - log.Debugf("o-translated: %s", oec3.String()) - - assert.NotEqual( - t, - oec1.(*chunkUDP).sourcePort, //nolint:forcetypeassert - oec2.(*chunkUDP).sourcePort, //nolint:forcetypeassert - "should not match", - ) - assert.Equal( - t, - oec1.(*chunkUDP).sourcePort, //nolint:forcetypeassert - oec3.(*chunkUDP).sourcePort, //nolint:forcetypeassert - "should match", - ) - }) + t.Run("symmetric NAT addr dependent mapping", func(t *testing.T) { //nolint:dupl + nat, err := newNAT(&natConfig{ + natType: NATType{ + MappingBehavior: EndpointAddrDependent, + FilteringBehavior: EndpointAddrDependent, + Hairpinning: false, + MappingLifeTime: 30 * time.Second, + }, + mappedIPs: []net.IP{net.ParseIP(demoIP)}, + loggerFactory: loggerFactory, + }) + assert.NoError(t, err, "should succeed") + + srcIP := net.ParseIP("192.168.0.2") + srcPort := 1234 + oic1 := proto.newChunk(t, proto.newAddr(srcIP, srcPort), proto.newAddr(net.ParseIP("5.6.7.8"), 5678)) + oic2 := proto.newChunk(t, proto.newAddr(srcIP, srcPort), proto.newAddr(net.ParseIP("5.6.7.100"), 5678)) + oic3 := proto.newChunk(t, proto.newAddr(srcIP, srcPort), proto.newAddr(net.ParseIP("5.6.7.8"), 6000)) + + log.Debugf("o-original : %s", oic1.String()) + log.Debugf("o-original : %s", oic2.String()) + log.Debugf("o-original : %s", oic3.String()) + + oec1, err := nat.translateOutbound(oic1) + assert.Nil(t, err, "should succeed") + + oec2, err := nat.translateOutbound(oic2) + assert.Nil(t, err, "should succeed") + + oec3, err := nat.translateOutbound(oic3) + assert.Nil(t, err, "should 
succeed") + + assert.Equal(t, 2, len(nat.outboundMap), "should match") + assert.Equal(t, 2, len(nat.inboundMap), "should match") + + log.Debugf("o-translated: %s", oec1.String()) + log.Debugf("o-translated: %s", oec2.String()) + log.Debugf("o-translated: %s", oec3.String()) + + _, p1 := natAddrIPPort(t, oec1.SourceAddr()) + _, p2 := natAddrIPPort(t, oec2.SourceAddr()) + _, p3 := natAddrIPPort(t, oec3.SourceAddr()) + assert.NotEqual(t, p1, p2, "should not match") + assert.Equal(t, p1, p3, "should match") + }) - t.Run("symmetric NAT port dependent mapping", func(t *testing.T) { //nolint:dupl - nat, err := newNAT(&natConfig{ - natType: NATType{ - MappingBehavior: EndpointAddrPortDependent, - FilteringBehavior: EndpointAddrPortDependent, - Hairpinning: false, - MappingLifeTime: 30 * time.Second, - }, - mappedIPs: []net.IP{net.ParseIP(demoIP)}, - loggerFactory: loggerFactory, + t.Run("symmetric NAT port dependent mapping", func(t *testing.T) { //nolint:dupl + nat, err := newNAT(&natConfig{ + natType: NATType{ + MappingBehavior: EndpointAddrPortDependent, + FilteringBehavior: EndpointAddrPortDependent, + Hairpinning: false, + MappingLifeTime: 30 * time.Second, + }, + mappedIPs: []net.IP{net.ParseIP(demoIP)}, + loggerFactory: loggerFactory, + }) + assert.NoError(t, err, "should succeed") + + srcIP := net.ParseIP("192.168.0.2") + srcPort := 1234 + oic1 := proto.newChunk(t, proto.newAddr(srcIP, srcPort), proto.newAddr(net.ParseIP("5.6.7.8"), 5678)) + oic2 := proto.newChunk(t, proto.newAddr(srcIP, srcPort), proto.newAddr(net.ParseIP("5.6.7.100"), 5678)) + oic3 := proto.newChunk(t, proto.newAddr(srcIP, srcPort), proto.newAddr(net.ParseIP("5.6.7.8"), 6000)) + + log.Debugf("o-original : %s", oic1.String()) + log.Debugf("o-original : %s", oic2.String()) + log.Debugf("o-original : %s", oic3.String()) + + oec1, err := nat.translateOutbound(oic1) + assert.Nil(t, err, "should succeed") + + oec2, err := nat.translateOutbound(oic2) + assert.Nil(t, err, "should succeed") + + oec3, err 
:= nat.translateOutbound(oic3) + assert.Nil(t, err, "should succeed") + + assert.Equal(t, 3, len(nat.outboundMap), "should match") + assert.Equal(t, 3, len(nat.inboundMap), "should match") + + log.Debugf("o-translated: %s", oec1.String()) + log.Debugf("o-translated: %s", oec2.String()) + log.Debugf("o-translated: %s", oec3.String()) + + _, p1 := natAddrIPPort(t, oec1.SourceAddr()) + _, p2 := natAddrIPPort(t, oec2.SourceAddr()) + _, p3 := natAddrIPPort(t, oec3.SourceAddr()) + assert.NotEqual(t, p1, p2, "should not match") + assert.NotEqual(t, p1, p3, "should match") + }) }) - assert.NoError(t, err, "should succeed") - - oic1 := newChunkUDP( - &net.UDPAddr{ - IP: net.ParseIP("192.168.0.2"), - Port: 1234, - }, - &net.UDPAddr{ - IP: net.ParseIP("5.6.7.8"), - Port: 5678, - }, - ) - - oic2 := newChunkUDP( - &net.UDPAddr{ - IP: net.ParseIP("192.168.0.2"), - Port: 1234, - }, - &net.UDPAddr{ - IP: net.ParseIP("5.6.7.100"), - Port: 5678, - }, - ) - - oic3 := newChunkUDP( - &net.UDPAddr{ - IP: net.ParseIP("192.168.0.2"), - Port: 1234, - }, - &net.UDPAddr{ - IP: net.ParseIP("5.6.7.8"), - Port: 6000, - }, - ) - - log.Debugf("o-original : %s", oic1.String()) - log.Debugf("o-original : %s", oic2.String()) - log.Debugf("o-original : %s", oic3.String()) - - oec1, err := nat.translateOutbound(oic1) - assert.Nil(t, err, "should succeed") - - oec2, err := nat.translateOutbound(oic2) - assert.Nil(t, err, "should succeed") - - oec3, err := nat.translateOutbound(oic3) - assert.Nil(t, err, "should succeed") - - assert.Equal(t, 3, len(nat.outboundMap), "should match") - assert.Equal(t, 3, len(nat.inboundMap), "should match") - - log.Debugf("o-translated: %s", oec1.String()) - log.Debugf("o-translated: %s", oec2.String()) - log.Debugf("o-translated: %s", oec3.String()) - - assert.NotEqual( - t, - oec1.(*chunkUDP).sourcePort, //nolint:forcetypeassert - oec2.(*chunkUDP).sourcePort, //nolint:forcetypeassert - "should not match", - ) - assert.NotEqual( - t, - oec1.(*chunkUDP).sourcePort, 
//nolint:forcetypeassert - oec3.(*chunkUDP).sourcePort, //nolint:forcetypeassert - "should match", - ) - }) + } } func TestNATMappingTimeout(t *testing.T) { @@ -542,278 +403,213 @@ func TestNATMappingTimeout(t *testing.T) { log := loggerFactory.NewLogger("test") t.Run("refresh on outbound", func(t *testing.T) { - nat, err := newNAT(&natConfig{ - natType: NATType{ - MappingBehavior: EndpointIndependent, - FilteringBehavior: EndpointIndependent, - Hairpinning: false, - MappingLifeTime: 100 * time.Millisecond, - }, - mappedIPs: []net.IP{net.ParseIP(demoIP)}, - loggerFactory: loggerFactory, - }) - assert.NoError(t, err, "should succeed") - - src := &net.UDPAddr{ - IP: net.ParseIP("192.168.0.2"), - Port: 1234, - } - dst := &net.UDPAddr{ - IP: net.ParseIP("5.6.7.8"), - Port: 5678, + for _, proto := range natTestProtos() { + proto := proto + t.Run(proto.name, func(t *testing.T) { + nat, err := newNAT(&natConfig{ + natType: NATType{ + MappingBehavior: EndpointIndependent, + FilteringBehavior: EndpointIndependent, + Hairpinning: false, + MappingLifeTime: 100 * time.Millisecond, + }, + mappedIPs: []net.IP{net.ParseIP(demoIP)}, + loggerFactory: loggerFactory, + }) + assert.NoError(t, err, "should succeed") + + srcIP := net.ParseIP("192.168.0.2") + srcPort := 1234 + dstIP := net.ParseIP("5.6.7.8") + dstPort := 5678 + oic := proto.newChunk(t, proto.newAddr(srcIP, srcPort), proto.newAddr(dstIP, dstPort)) + + oec, err := nat.translateOutbound(oic) + assert.Nil(t, err, "should succeed") + assert.Equal(t, 1, len(nat.outboundMap), "should match") + assert.Equal(t, 1, len(nat.inboundMap), "should match") + + log.Debugf("o-original : %s", oic.String()) + log.Debugf("o-translated: %s", oec.String()) + + mapped := oec.SourceAddr().String() + + time.Sleep(75 * time.Millisecond) + + // refresh + oec, err = nat.translateOutbound(oic) + assert.Nil(t, err, "should succeed") + assert.Equal(t, 1, len(nat.outboundMap), "should match") + assert.Equal(t, 1, len(nat.inboundMap), "should match") + 
+ log.Debugf("o-original : %s", oic.String()) + log.Debugf("o-translated: %s", oec.String()) + + assert.Equal(t, mapped, oec.SourceAddr().String(), "mapped addr should match") + + // sleep long enough for the mapping to expire + time.Sleep(125 * time.Millisecond) + + // refresh after expiration + oec, err = nat.translateOutbound(oic) + assert.Nil(t, err, "should succeed") + assert.Equal(t, 1, len(nat.outboundMap), "should match") + assert.Equal(t, 1, len(nat.inboundMap), "should match") + assert.NotEqual(t, mapped, oec.SourceAddr().String(), "mapped addr should not match") + }) } - - oic := newChunkUDP(src, dst) - - oec, err := nat.translateOutbound(oic) - assert.Nil(t, err, "should succeed") - assert.Equal(t, 1, len(nat.outboundMap), "should match") - assert.Equal(t, 1, len(nat.inboundMap), "should match") - - log.Debugf("o-original : %s", oic.String()) - log.Debugf("o-translated: %s", oec.String()) - - // record mapped addr - mapped := oec.(*chunkUDP).SourceAddr().String() //nolint:forcetypeassert - - time.Sleep(75 * time.Millisecond) - - // refresh - oec, err = nat.translateOutbound(oic) - assert.Nil(t, err, "should succeed") - assert.Equal(t, 1, len(nat.outboundMap), "should match") - assert.Equal(t, 1, len(nat.inboundMap), "should match") - - log.Debugf("o-original : %s", oic.String()) - log.Debugf("o-translated: %s", oec.String()) - - assert.Equal(t, mapped, oec.(*chunkUDP).SourceAddr().String(), "mapped addr should match") //nolint:forcetypeassert - - // sleep long enough for the mapping to expire - time.Sleep(125 * time.Millisecond) - - // refresh after expiration - oec, err = nat.translateOutbound(oic) - assert.Nil(t, err, "should succeed") - assert.Equal(t, 1, len(nat.outboundMap), "should match") - assert.Equal(t, 1, len(nat.inboundMap), "should match") - - log.Debugf("o-original : %s", oic.String()) - log.Debugf("o-translated: %s", oec.String()) - - assert.NotEqual( - t, - mapped, - oec.(*chunkUDP).SourceAddr().String(), //nolint:forcetypeassert - 
"mapped addr should not match", - ) }) t.Run("outbound detects timeout", func(t *testing.T) { - nat, err := newNAT(&natConfig{ - natType: NATType{ - MappingBehavior: EndpointIndependent, - FilteringBehavior: EndpointIndependent, - Hairpinning: false, - MappingLifeTime: 100 * time.Millisecond, - }, - mappedIPs: []net.IP{net.ParseIP(demoIP)}, - loggerFactory: loggerFactory, - }) - assert.NoError(t, err, "should succeed") - - src := &net.UDPAddr{ - IP: net.ParseIP("192.168.0.2"), - Port: 1234, - } - dst := &net.UDPAddr{ - IP: net.ParseIP("5.6.7.8"), - Port: 5678, + for _, proto := range natTestProtos() { + proto := proto + t.Run(proto.name, func(t *testing.T) { + nat, err := newNAT(&natConfig{ + natType: NATType{ + MappingBehavior: EndpointIndependent, + FilteringBehavior: EndpointIndependent, + Hairpinning: false, + MappingLifeTime: 100 * time.Millisecond, + }, + mappedIPs: []net.IP{net.ParseIP(demoIP)}, + loggerFactory: loggerFactory, + }) + assert.NoError(t, err, "should succeed") + + srcIP := net.ParseIP("192.168.0.2") + srcPort := 1234 + dstIP := net.ParseIP("5.6.7.8") + dstPort := 5678 + oic := proto.newChunk(t, proto.newAddr(srcIP, srcPort), proto.newAddr(dstIP, dstPort)) + + oec, err := nat.translateOutbound(oic) + assert.Nil(t, err, "should succeed") + assert.Equal(t, 1, len(nat.outboundMap), "should match") + assert.Equal(t, 1, len(nat.inboundMap), "should match") + + log.Debugf("o-original : %s", oic.String()) + log.Debugf("o-translated: %s", oec.String()) + + // sleep long enough for the mapping to expire + time.Sleep(125 * time.Millisecond) + + oecIP, oecPort := natAddrIPPort(t, oec.SourceAddr()) + iec := proto.newChunk(t, proto.newAddr(dstIP, dstPort), proto.newAddr(oecIP, oecPort)) + log.Debugf("i-original : %s", iec.String()) + + iic, err := nat.translateInbound(iec) + assert.NotNil(t, err, "should drop") + assert.Nil(t, iic, "should be nil") + assert.Empty(t, nat.outboundMap, "should have no binding") + assert.Empty(t, nat.inboundMap, "should have no 
binding") + }) } - - oic := newChunkUDP(src, dst) - - oec, err := nat.translateOutbound(oic) - assert.Nil(t, err, "should succeed") - assert.Equal(t, 1, len(nat.outboundMap), "should match") - assert.Equal(t, 1, len(nat.inboundMap), "should match") - - log.Debugf("o-original : %s", oic.String()) - log.Debugf("o-translated: %s", oec.String()) - - // sleep long enough for the mapping to expire - time.Sleep(125 * time.Millisecond) - - //nolint:forcetypeassert - iec := newChunkUDP( - &net.UDPAddr{ - IP: dst.IP, - Port: dst.Port, - }, - &net.UDPAddr{ - IP: oec.(*chunkUDP).sourceIP, - Port: oec.(*chunkUDP).sourcePort, - }, - ) - - log.Debugf("i-original : %s", iec.String()) - - _, err = nat.translateInbound(iec) - assert.NotNil(t, err, "should drop") - assert.Empty(t, nat.outboundMap, "should have no binding") - assert.Empty(t, nat.inboundMap, "should have no binding") }) } -func TestNAT1To1Behavior(t *testing.T) { +func TestNAT1To1Behavior(t *testing.T) { // nolint:cyclop loggerFactory := logging.NewDefaultLoggerFactory() log := loggerFactory.NewLogger("test") t.Run("1:1 NAT with one mapping", func(t *testing.T) { - nat, err := newNAT(&natConfig{ - natType: NATType{ - Mode: NATModeNAT1To1, - }, - mappedIPs: []net.IP{net.ParseIP(demoIP)}, - localIPs: []net.IP{net.ParseIP("10.0.0.1")}, - loggerFactory: loggerFactory, - }) - if !assert.NoError(t, err, "should succeed") { - return - } - - src := &net.UDPAddr{ - IP: net.ParseIP("10.0.0.1"), - Port: 1234, - } - dst := &net.UDPAddr{ - IP: net.ParseIP("5.6.7.8"), - Port: 5678, + for _, proto := range natTestProtos() { + proto := proto + t.Run(proto.name, func(t *testing.T) { + nat, err := newNAT(&natConfig{ + natType: NATType{ + Mode: NATModeNAT1To1, + }, + mappedIPs: []net.IP{net.ParseIP(demoIP)}, + localIPs: []net.IP{net.ParseIP("10.0.0.1")}, + loggerFactory: loggerFactory, + }) + if !assert.NoError(t, err, "should succeed") { + return + } + + srcIP := net.ParseIP("10.0.0.1") + srcPort := 1234 + dstIP := 
net.ParseIP("5.6.7.8") + dstPort := 5678 + oic := proto.newChunk(t, proto.newAddr(srcIP, srcPort), proto.newAddr(dstIP, dstPort)) + + oec, err := nat.translateOutbound(oic) + assert.Nil(t, err, "should succeed") + assert.Empty(t, nat.outboundMap, "should match") + assert.Empty(t, nat.inboundMap, "should match") + + log.Debugf("o-original : %s", oic.String()) + log.Debugf("o-translated: %s", oec.String()) + assert.Equal(t, "1.2.3.4:1234", oec.SourceAddr().String(), "should match") + + oecIP, oecPort := natAddrIPPort(t, oec.SourceAddr()) + iec := proto.newChunk(t, proto.newAddr(dstIP, dstPort), proto.newAddr(oecIP, oecPort)) + log.Debugf("i-original : %s", iec.String()) + + iic, err := nat.translateInbound(iec) + assert.Nil(t, err, "should succeed") + log.Debugf("i-translated: %s", iic.String()) + assert.Equal(t, oic.SourceAddr().String(), iic.DestinationAddr().String(), "should match") + }) } - - oic := newChunkUDP(src, dst) - - oec, err := nat.translateOutbound(oic) - assert.Nil(t, err, "should succeed") - assert.Empty(t, nat.outboundMap, "should match") - assert.Empty(t, nat.inboundMap, "should match") - - log.Debugf("o-original : %s", oic.String()) - log.Debugf("o-translated: %s", oec.String()) - - assert.Equal(t, "1.2.3.4:1234", oec.SourceAddr().String(), "should match") - - //nolint:forcetypeassert - iec := newChunkUDP( - &net.UDPAddr{ - IP: dst.IP, - Port: dst.Port, - }, - &net.UDPAddr{ - IP: oec.(*chunkUDP).sourceIP, - Port: oec.(*chunkUDP).sourcePort, - }, - ) - - log.Debugf("i-original : %s", iec.String()) - - iic, err := nat.translateInbound(iec) - assert.Nil(t, err, "should succeed") - - log.Debugf("i-translated: %s", iic.String()) - - assert.Equal(t, - oic.SourceAddr().String(), - iic.DestinationAddr().String(), - "should match") }) t.Run("1:1 NAT with more than one mapping", func(t *testing.T) { - nat, err := newNAT(&natConfig{ - natType: NATType{ - Mode: NATModeNAT1To1, - }, - mappedIPs: []net.IP{ - net.ParseIP(demoIP), - net.ParseIP("1.2.3.5"), - }, - 
localIPs: []net.IP{ - net.ParseIP("10.0.0.1"), - net.ParseIP("10.0.0.2"), - }, - loggerFactory: loggerFactory, - }) - if !assert.NoError(t, err, "should succeed") { - return - } - - // outbound translation - - before := newChunkUDP( - &net.UDPAddr{ - IP: net.ParseIP("10.0.0.1"), - Port: 1234, - }, - &net.UDPAddr{ - IP: net.ParseIP("5.6.7.8"), - Port: 5678, + for _, proto := range natTestProtos() { + proto := proto + t.Run(proto.name, func(t *testing.T) { + nat, err := newNAT(&natConfig{ + natType: NATType{ + Mode: NATModeNAT1To1, + }, + mappedIPs: []net.IP{ + net.ParseIP(demoIP), + net.ParseIP("1.2.3.5"), + }, + localIPs: []net.IP{ + net.ParseIP("10.0.0.1"), + net.ParseIP("10.0.0.2"), + }, + loggerFactory: loggerFactory, + }) + if !assert.NoError(t, err, "should succeed") { + return + } + + dstIP := net.ParseIP("5.6.7.8") + dstPort := 5678 + + // outbound translation + oic := proto.newChunk(t, proto.newAddr(net.ParseIP("10.0.0.1"), 1234), proto.newAddr(dstIP, dstPort)) + oec, err := nat.translateOutbound(oic) + if !assert.NoError(t, err, "should succeed") { + return + } + assert.Equal(t, "1.2.3.4:1234", oec.SourceAddr().String(), "should match") + + oic = proto.newChunk(t, proto.newAddr(net.ParseIP("10.0.0.2"), 1234), proto.newAddr(dstIP, dstPort)) + oec, err = nat.translateOutbound(oic) + if !assert.NoError(t, err, "should succeed") { + return + } + assert.Equal(t, "1.2.3.5:1234", oec.SourceAddr().String(), "should match") + + // inbound translation + iec := proto.newChunk(t, proto.newAddr(dstIP, dstPort), proto.newAddr(net.ParseIP(demoIP), 2525)) + iic, err := nat.translateInbound(iec) + if !assert.NoError(t, err, "should succeed") { + return + } + assert.Equal(t, "10.0.0.1:2525", iic.DestinationAddr().String(), "should match") + + iec = proto.newChunk(t, proto.newAddr(dstIP, dstPort), proto.newAddr(net.ParseIP("1.2.3.5"), 9847)) + iic, err = nat.translateInbound(iec) + if !assert.NoError(t, err, "should succeed") { + return + } + assert.Equal(t, "10.0.0.2:9847", 
iic.DestinationAddr().String(), "should match") }) - - after, err := nat.translateOutbound(before) - if !assert.NoError(t, err, "should succeed") { - return } - assert.Equal(t, "1.2.3.4:1234", after.SourceAddr().String(), "should match") - - before = newChunkUDP( - &net.UDPAddr{ - IP: net.ParseIP("10.0.0.2"), - Port: 1234, - }, - &net.UDPAddr{ - IP: net.ParseIP("5.6.7.8"), - Port: 5678, - }) - - after, err = nat.translateOutbound(before) - if !assert.NoError(t, err, "should succeed") { - return - } - assert.Equal(t, "1.2.3.5:1234", after.SourceAddr().String(), "should match") - - // inbound translation - - before = newChunkUDP( - &net.UDPAddr{ - IP: net.ParseIP("5.6.7.8"), - Port: 5678, - }, - &net.UDPAddr{ - IP: net.ParseIP(demoIP), - Port: 2525, - }) - - after, err = nat.translateInbound(before) - if !assert.NoError(t, err, "should succeed") { - return - } - assert.Equal(t, "10.0.0.1:2525", after.DestinationAddr().String(), "should match") - - before = newChunkUDP( - &net.UDPAddr{ - IP: net.ParseIP("5.6.7.8"), - Port: 5678, - }, - &net.UDPAddr{ - IP: net.ParseIP("1.2.3.5"), - Port: 9847, - }) - - after, err = nat.translateInbound(before) - if !assert.NoError(t, err, "should succeed") { - return - } - assert.Equal(t, "10.0.0.2:9847", after.DestinationAddr().String(), "should match") }) t.Run("1:1 NAT failure", func(t *testing.T) { @@ -858,35 +654,22 @@ func TestNAT1To1Behavior(t *testing.T) { }) assert.NoError(t, err, "should succeed") - before := newChunkUDP( - &net.UDPAddr{ - IP: net.ParseIP("10.0.0.2"), // no external mapping for this - Port: 1234, - }, - &net.UDPAddr{ - IP: net.ParseIP("5.6.7.8"), - Port: 5678, + for _, proto := range natTestProtos() { + proto := proto + t.Run(proto.name, func(t *testing.T) { + dstIP := net.ParseIP("5.6.7.8") + dstPort := 5678 + oic := proto.newChunk(t, proto.newAddr(net.ParseIP("10.0.0.2"), 1234), proto.newAddr(dstIP, dstPort)) + oec, err := nat.translateOutbound(oic) + if !assert.NoError(t, err, "should succeed") { + return 
+ } + assert.Nil(t, oec, "should be nil") + + iec := proto.newChunk(t, proto.newAddr(dstIP, dstPort), proto.newAddr(net.ParseIP("10.0.0.2"), 1234)) + _, err = nat.translateInbound(iec) + assert.Error(t, err, "should fail") }) - - after, err := nat.translateOutbound(before) - if !assert.NoError(t, err, "should succeed") { - return - } - if !assert.Nil(t, after, "should be nil") { - return } - - before = newChunkUDP( - &net.UDPAddr{ - IP: net.ParseIP("5.6.7.8"), - Port: 5678, - }, - &net.UDPAddr{ - IP: net.ParseIP("10.0.0.2"), // no local mapping for this - Port: 1234, - }) - - _, err = nat.translateInbound(before) - assert.Error(t, err, "should fail") }) } diff --git a/vnet/net.go b/vnet/net.go index 589fe4bb..14ed27f1 100644 --- a/vnet/net.go +++ b/vnet/net.go @@ -21,6 +21,8 @@ const ( lo0String = "lo0String" udp = "udp" udp4 = "udp4" + tcp = "tcp" + tcp4 = "tcp4" ) var ( @@ -48,11 +50,13 @@ func newMACAddress() net.HardwareAddr { // Net represents a local network stack equivalent to a set of layers from NIC // up to the transport (UDP / TCP) layer. type Net struct { - interfaces []*transport.Interface // read-only - staticIPs []net.IP // read-only - router *Router // read-only - udpConns *udpConnMap // read-only - mutex sync.RWMutex + interfaces []*transport.Interface // read-only + staticIPs []net.IP // read-only + router *Router // read-only + udpConns *udpConnMap // read-only + tcpListeners *tcpListenerMap // read-only + tcpConns *tcpConnMap // read-only + mutex sync.RWMutex } // Compile-time assertion. 
@@ -199,14 +203,57 @@ func (v *Net) RemoveAddress(ifName string, ip net.IP) error { return nil } -func (v *Net) onInboundChunk(c Chunk) { - v.mutex.Lock() - defer v.mutex.Unlock() +func (v *Net) onInboundChunk(chunk Chunk) { + switch chunk.Network() { + case udp: + v.mutex.Lock() + conn, ok := v.udpConns.find(chunk.DestinationAddr()) + v.mutex.Unlock() + if ok { + conn.onInboundChunk(chunk) + } + + return + + case tcp: + tcpChunk, ok := chunk.(*chunkTCP) + if !ok { + return + } + + // Lookups must be protected, but handlers may re-enter Net via write(). + v.mutex.Lock() + conn, connOK := v.tcpConns.findByChunk(tcpChunk) + if connOK { + v.mutex.Unlock() + conn.onInboundChunk(tcpChunk) - if c.Network() == udp { - if conn, ok := v.udpConns.find(c.DestinationAddr()); ok { - conn.onInboundChunk(c) + return } + + // New connection attempt (SYN) + if tcpChunk.flags&tcpSYN != 0 && tcpChunk.flags&tcpACK == 0 { + dstAddr := tcpChunk.DestinationAddr().(*net.TCPAddr) //nolint:forcetypeassert + l, ok := v.tcpListeners.find(dstAddr) + v.mutex.Unlock() + if ok { + l.onInboundSYN(tcpChunk) + + return + } + } else { + v.mutex.Unlock() + } + + // No listener/conn for this tuple; send RST back. + dstAddr := tcpChunk.DestinationAddr().(*net.TCPAddr) //nolint:forcetypeassert + srcAddr := tcpChunk.SourceAddr().(*net.TCPAddr) //nolint:forcetypeassert + _ = v.write(newChunkTCP(dstAddr, srcAddr, tcpRST)) + + return + + default: + return } } @@ -237,7 +284,7 @@ func (v *Net) _dialUDP(network string, locAddr, remAddr *net.UDPAddr) (transport if locAddr.Port == 0 { // choose randomly from the range between 5000 and 5999 - port, err := v.assignPort(locAddr.IP, 5000, 5999) + port, err := v.assignUDPPort(locAddr.IP, 5000, 5999) if err != nil { return nil, &net.OpError{ Op: "listen", @@ -300,20 +347,38 @@ func (v *Net) DialUDP(network string, locAddr, remAddr *net.UDPAddr) (transport. // Dial connects to the address on the named network. 
func (v *Net) Dial(network string, address string) (net.Conn, error) { - v.mutex.Lock() - defer v.mutex.Unlock() + switch network { + case udp, udp4: + remAddr, err := v.ResolveUDPAddr(network, address) + if err != nil { + return nil, err + } - remAddr, err := v.ResolveUDPAddr(network, address) - if err != nil { - return nil, err - } + v.mutex.Lock() + srcIP := v.determineSourceIP(nil, remAddr.IP) + v.mutex.Unlock() - // Determine source address - srcIP := v.determineSourceIP(nil, remAddr.IP) + locAddr := &net.UDPAddr{IP: srcIP, Port: 0} - locAddr := &net.UDPAddr{IP: srcIP, Port: 0} + return v.DialUDP(network, locAddr, remAddr) - return v._dialUDP(network, locAddr, remAddr) + case tcp, tcp4: + remAddr, err := v.ResolveTCPAddr(network, address) + if err != nil { + return nil, err + } + + v.mutex.Lock() + srcIP := v.determineSourceIP(nil, remAddr.IP) + v.mutex.Unlock() + + locAddr := &net.TCPAddr{IP: srcIP, Port: 0} + + return v.DialTCP(network, locAddr, remAddr) + + default: + return nil, fmt.Errorf("%w %s", errUnknownNetwork, network) + } } // ResolveIPAddr returns an address of IP end point. @@ -376,7 +441,7 @@ func (v *Net) ResolveUDPAddr(network, address string) (*net.UDPAddr, error) { // ResolveTCPAddr returns an address of TCP end point. 
func (v *Net) ResolveTCPAddr(network, address string) (*net.TCPAddr, error) { - if network != udp && network != "udp4" { + if network != tcp && network != tcp4 { return nil, fmt.Errorf("%w %s", errUnknownNetwork, network) } @@ -395,17 +460,18 @@ func (v *Net) ResolveTCPAddr(network, address string) (*net.TCPAddr, error) { return nil, errInvalidPortNumber } - udpAddr := &net.TCPAddr{ + tcpAddr := &net.TCPAddr{ IP: ipAddr.IP, Zone: ipAddr.Zone, Port: port, } - return udpAddr, nil + return tcpAddr, nil } func (v *Net) write(chunk Chunk) error { - if chunk.Network() == udp { //nolint:nestif + switch chunk.Network() { + case udp: if udp, ok := chunk.(*chunkUDP); ok { if chunk.getDestinationIP().IsLoopback() { if conn, ok := v.udpConns.find(udp.DestinationAddr()); ok { @@ -417,6 +483,20 @@ func (v *Net) write(chunk Chunk) error { } else { return errUnexpectedTypeSwitchFailure } + + case tcp: + if tcp, ok := chunk.(*chunkTCP); ok { + if chunk.getDestinationIP().IsLoopback() { + v.onInboundChunk(tcp) + + return nil + } + } else { + return errUnexpectedTypeSwitchFailure + } + + default: + return fmt.Errorf("%w: %s", errUnexpectedNetwork, chunk.Network()) } if v.router == nil { @@ -429,9 +509,15 @@ func (v *Net) write(chunk Chunk) error { } func (v *Net) onClosed(addr net.Addr) { - if addr.Network() == udp { + switch addr.Network() { + case udp: //nolint:errcheck v.udpConns.delete(addr) // #nosec + case tcp: + //nolint:errcheck + v.tcpConns.deleteByAddr(addr) // #nosec + default: + // do nothing } } @@ -560,7 +646,7 @@ func (v *Net) allocateLocalAddr(ip net.IP, port int) error { } // caller must hold the mutex. 
-func (v *Net) assignPort(ip net.IP, start, end int) (int, error) { +func (v *Net) assignUDPPort(ip net.IP, start, end int) (int, error) { // choose randomly from the range between start and end (inclusive) if end < start { return -1, errEndPortLessThanStart @@ -580,6 +666,52 @@ func (v *Net) assignPort(ip net.IP, start, end int) (int, error) { return -1, errPortSpaceExhausted } +// caller must hold the mutex. +func (v *Net) assignTCPListenerPort(ip net.IP, start, end int) (int, error) { + if end < start { + return -1, errEndPortLessThanStart + } + + space := end + 1 - start + offset := rand.Intn(space) //nolint:gosec + for i := 0; i < space; i++ { + port := ((offset + i) % space) + start + addr := &net.TCPAddr{IP: ip, Port: port} + if _, ok := v.tcpListeners.find(addr); ok { + continue + } + + return port, nil + } + + return -1, errPortSpaceExhausted +} + +// caller must hold the mutex. +func (v *Net) assignTCPPort(ip net.IP, start, end int) (int, error) { + if end < start { + return -1, errEndPortLessThanStart + } + + space := end + 1 - start + offset := rand.Intn(space) //nolint:gosec + for i := 0; i < space; i++ { + port := ((offset + i) % space) + start + // For simplicity, don't reuse a local port if any listener exists on that port. + if _, ok := v.tcpListeners.find(&net.TCPAddr{IP: ip, Port: port}); ok { + continue + } + // Also avoid if any connection is already using this local port. 
+ if v.tcpConns.portInUse(ip, port) { + continue + } + + return port, nil + } + + return -1, errPortSpaceExhausted +} + func (v *Net) getStaticIPs() []net.IP { return v.staticIPs } @@ -626,20 +758,128 @@ func NewNet(config *NetConfig) (*Net, error) { } return &Net{ - interfaces: []*transport.Interface{lo0, eth0}, - staticIPs: staticIPs, - udpConns: newUDPConnMap(), + interfaces: []*transport.Interface{lo0, eth0}, + staticIPs: staticIPs, + udpConns: newUDPConnMap(), + tcpListeners: newTCPListenerMap(), + tcpConns: newTCPConnMap(), }, nil } // DialTCP acts like Dial for TCP networks. -func (v *Net) DialTCP(string, *net.TCPAddr, *net.TCPAddr) (transport.TCPConn, error) { - return nil, transport.ErrNotSupported +func (v *Net) DialTCP(network string, locAddr, remAddr *net.TCPAddr) (transport.TCPConn, error) { //nolint:cyclop + if network != tcp && network != tcp4 { + return nil, fmt.Errorf("%w: %s", errUnexpectedNetwork, network) + } + + if remAddr == nil { + return nil, &net.OpError{Op: "dial", Net: network, Addr: nil, Err: errNoRemAddr} + } + if remAddr.IP == nil { + remAddr.IP = net.IPv4zero + } + + if locAddr == nil { + locAddr = &net.TCPAddr{IP: net.IPv4zero, Port: 0} + } else if locAddr.IP == nil { + locAddr.IP = net.IPv4zero + } + + v.mutex.Lock() + + // determine local IP if unspecified + locAddr.IP = v.determineSourceIP(locAddr.IP, remAddr.IP) + if locAddr.IP == nil { + v.mutex.Unlock() + + return nil, errLocAddr + } + + // validate address. do we have that address? 
+ if !v.hasIPAddr(locAddr.IP) { + v.mutex.Unlock() + + return nil, &net.OpError{Op: "dial", Net: network, Addr: locAddr, Err: fmt.Errorf("bind: %w", + errCantAssignRequestedAddr)} + } + + if locAddr.Port == 0 { + port, err := v.assignTCPPort(locAddr.IP, 5000, 5999) + if err != nil { + v.mutex.Unlock() + + return nil, &net.OpError{Op: "dial", Net: network, Addr: locAddr, Err: err} + } + locAddr.Port = port + } + + conn, err := newTCPConn(locAddr, remAddr, v, nil) + if err != nil { + v.mutex.Unlock() + + return nil, err + } + + if err := v.tcpConns.insert(conn); err != nil { + v.mutex.Unlock() + + return nil, &net.OpError{Op: "dial", Net: network, Addr: locAddr, Err: fmt.Errorf("bind: %w", err)} + } + v.mutex.Unlock() + + if err := conn.startClientHandshake(); err != nil { + _ = v.tcpConns.deleteConn(conn) + + return nil, err + } + + if err := conn.waitEstablished(); err != nil { + _ = conn.Close() + + return nil, err + } + + return conn, nil } // ListenTCP acts like Listen for TCP networks. 
-func (v *Net) ListenTCP(string, *net.TCPAddr) (transport.TCPListener, error) { - return nil, transport.ErrNotSupported +func (v *Net) ListenTCP(network string, locAddr *net.TCPAddr) (transport.TCPListener, error) { //nolint:cyclop + v.mutex.Lock() + defer v.mutex.Unlock() + + if network != tcp && network != tcp4 { + return nil, fmt.Errorf("%w: %s", errUnexpectedNetwork, network) + } + + if locAddr == nil { + locAddr = &net.TCPAddr{IP: net.IPv4zero, Port: 0} + } else if locAddr.IP == nil { + locAddr.IP = net.IPv4zero + } + + if !v.hasIPAddr(locAddr.IP) { + return nil, &net.OpError{Op: "listen", Net: network, Addr: locAddr, Err: fmt.Errorf("bind: %w", + errCantAssignRequestedAddr)} + } + + if locAddr.Port == 0 { + port, err := v.assignTCPListenerPort(locAddr.IP, 5000, 5999) + if err != nil { + return nil, &net.OpError{Op: "listen", Net: network, Addr: locAddr, Err: err} + } + locAddr.Port = port + } + + l, err := newTCPListener(locAddr, v) + if err != nil { + return nil, err + } + + if err := v.tcpListeners.insert(l); err != nil { + return nil, &net.OpError{Op: "listen", Net: network, Addr: locAddr, Err: fmt.Errorf("bind: %w", err)} + } + + return l, nil } // CreateDialer creates an instance of vnet.Dialer. 
diff --git a/vnet/net_test.go b/vnet/net_test.go index 0802d452..af35e65f 100644 --- a/vnet/net_test.go +++ b/vnet/net_test.go @@ -6,8 +6,10 @@ package vnet import ( "context" "fmt" + "io" "net" "testing" + "time" "github.com/pion/logging" "github.com/pion/transport/v4" @@ -176,7 +178,7 @@ func TestNetVirtual(t *testing.T) { //nolint:gocyclo,cyclop,maintidx } }) - t.Run("assignPort()", func(t *testing.T) { + t.Run("assignUDPPort()", func(t *testing.T) { nw, err := NewNet(&NetConfig{}) if !assert.NoError(t, err, "should succeed") { return @@ -201,11 +203,11 @@ func TestNetVirtual(t *testing.T) { //nolint:gocyclo,cyclop,maintidx }) // attempt to assign port with start > end should fail - _, err = nw.assignPort(net.ParseIP(addr), 3000, 2999) + _, err = nw.assignUDPPort(net.ParseIP(addr), 3000, 2999) assert.NotNil(t, err, "should fail") for i := 0; i < space; i++ { - port, err2 := nw.assignPort(net.ParseIP(addr), start, end) + port, err2 := nw.assignUDPPort(net.ParseIP(addr), start, end) assert.NoError(t, err2, "should succeed") log.Debugf("[%d] got port: %d", i, port) @@ -221,10 +223,50 @@ func TestNetVirtual(t *testing.T) { //nolint:gocyclo,cyclop,maintidx assert.Equal(t, space, nw.udpConns.size(), "should match") // attempt to assign again should fail - _, err = nw.assignPort(net.ParseIP(addr), start, end) + _, err = nw.assignUDPPort(net.ParseIP(addr), start, end) assert.NotNil(t, err, "should fail") }) + t.Run("Port allocation independent for UDP/TCP", func(t *testing.T) { + nw, err := NewNet(&NetConfig{}) + if !assert.NoError(t, err, "should succeed") { + return + } + + ip := net.ParseIP("127.0.0.1") + + // If TCP is already using a port, UDP allocation should still be able to pick it. 
+ tcpL, err := nw.ListenTCP(tcp, &net.TCPAddr{IP: ip, Port: 5010}) + if !assert.NoError(t, err, "should succeed") { + return + } + defer func() { _ = tcpL.Close() }() + + udpPort, err := nw.assignUDPPort(ip, 5010, 5010) + assert.NoError(t, err, "should succeed") + assert.Equal(t, 5010, udpPort, "should match") + + assert.NoError(t, tcpL.Close(), "should succeed") + + // If UDP is already using a port, TCP allocation should still be able to pick it. + udpC, err := nw.ListenPacket(udp, "127.0.0.1:5011") + if !assert.NoError(t, err, "should succeed") { + return + } + defer func() { _ = udpC.Close() }() + + tcpPort, err := nw.assignTCPPort(ip, 5011, 5011) + assert.NoError(t, err, "should succeed") + assert.Equal(t, 5011, tcpPort, "should match") + + // And explicit binds should also be independent. + tcpL2, err := nw.ListenTCP(tcp, &net.TCPAddr{IP: ip, Port: 5011}) + assert.NoError(t, err, "should succeed") + if tcpL2 != nil { + _ = tcpL2.Close() + } + }) + t.Run("determineSourceIP()", func(t *testing.T) { nw, err := NewNet(&NetConfig{}) if !assert.NoError(t, err, "should succeed") { @@ -283,6 +325,20 @@ func TestNetVirtual(t *testing.T) { //nolint:gocyclo,cyclop,maintidx assert.Equal(t, 1234, udpAddr.Port, "should match") }) + t.Run("ResolveTCPAddr", func(t *testing.T) { + nw, err := NewNet(&NetConfig{}) + if !assert.NoError(t, err, "should succeed") { + return + } + + tcpAddr, err := nw.ResolveTCPAddr(tcp, "localhost:1234") + if !assert.NoError(t, err, "should succeed") { + return + } + assert.Equal(t, "127.0.0.1", tcpAddr.IP.String(), "should match") + assert.Equal(t, 1234, tcpAddr.Port, "should match") + }) + t.Run("UDPLoopback", func(t *testing.T) { nw, err := NewNet(&NetConfig{}) if !assert.NoError(t, err, "should succeed") { @@ -309,6 +365,101 @@ func TestNetVirtual(t *testing.T) { //nolint:gocyclo,cyclop,maintidx assert.Empty(t, nw.udpConns.size(), "should match") }) + t.Run("TCPLoopback", func(t *testing.T) { + nw, err := NewNet(&NetConfig{}) + if 
!assert.NoError(t, err, "should succeed") { + return + } + + listener, err := nw.ListenTCP(tcp, &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + if !assert.NoError(t, err, "should succeed") { + return + } + defer func() { _ = listener.Close() }() + _ = listener.SetDeadline(time.Now().Add(2 * time.Second)) + + acceptedCh := make(chan net.Conn, 1) + go func() { + c, err2 := listener.Accept() + if err2 != nil { + close(acceptedCh) + + return + } + acceptedCh <- c + }() + + client, err := nw.Dial(tcp, listener.Addr().String()) + if !assert.NoError(t, err, "should succeed") { + return + } + defer func() { _ = client.Close() }() + + var server net.Conn + select { + case server = <-acceptedCh: + case <-time.After(2 * time.Second): + assert.Fail(t, "accept timed out") + + return + } + if !assert.NotNil(t, server, "should accept") { + return + } + defer func() { _ = server.Close() }() + + _ = client.SetDeadline(time.Now().Add(2 * time.Second)) + _ = server.SetDeadline(time.Now().Add(2 * time.Second)) + + msg := "PING!" 
+ + serverReadDone := make(chan error, 1) + go func() { + buf := make([]byte, len(msg)) + _, err2 := io.ReadFull(server, buf) + if err2 == nil && string(buf) != msg { + err2 = fmt.Errorf("unexpected payload: %q", string(buf)) // nolint:err113 + } + serverReadDone <- err2 + }() + + n, err := client.Write([]byte(msg)) + assert.NoError(t, err, "should succeed") + assert.Equal(t, len(msg), n, "should match") + + select { + case err2 := <-serverReadDone: + assert.NoError(t, err2, "should succeed") + case <-time.After(2 * time.Second): + assert.Fail(t, "server read timed out") + + return + } + + clientReadDone := make(chan error, 1) + go func() { + buf := make([]byte, len(msg)) + _, err2 := io.ReadFull(client, buf) + if err2 == nil && string(buf) != msg { + err2 = fmt.Errorf("unexpected payload: %q", string(buf)) // nolint:err113 + } + clientReadDone <- err2 + }() + + n, err = server.Write([]byte(msg)) + assert.NoError(t, err, "should succeed") + assert.Equal(t, len(msg), n, "should match") + + select { + case err2 := <-clientReadDone: + assert.NoError(t, err2, "should succeed") + case <-time.After(2 * time.Second): + assert.Fail(t, "client read timed out") + + return + } + }) + t.Run("ListenPacket random port", func(t *testing.T) { nw, err := NewNet(&NetConfig{}) if !assert.NoError(t, err, "should succeed") { @@ -407,6 +558,57 @@ func TestNetVirtual(t *testing.T) { //nolint:gocyclo,cyclop,maintidx assert.Empty(t, nw.udpConns.size(), "should match") }) + t.Run("Dial (TCP) lo0", func(t *testing.T) { + nw, err := NewNet(&NetConfig{}) + if !assert.NoError(t, err, "should succeed") { + return + } + + listener, err := nw.ListenTCP(tcp, &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1234}) + if !assert.NoError(t, err, "should succeed") { + return + } + defer func() { _ = listener.Close() }() + _ = listener.SetDeadline(time.Now().Add(2 * time.Second)) + + acceptedCh := make(chan net.Conn, 1) + go func() { + c, err2 := listener.Accept() + if err2 != nil { + close(acceptedCh) + 
+ return + } + acceptedCh <- c + }() + + conn, err := nw.Dial(tcp, "127.0.0.1:1234") + assert.NoError(t, err, "should succeed") + defer func() { _ = conn.Close() }() + + var server net.Conn + select { + case server = <-acceptedCh: + case <-time.After(2 * time.Second): + assert.Fail(t, "accept timed out") + + return + } + if server != nil { + _ = server.Close() + } + + laddr := conn.LocalAddr() + log.Debugf("laddr: %s", laddr.String()) + + raddr := conn.RemoteAddr() + log.Debugf("raddr: %s", raddr.String()) + + assert.Equal(t, "127.0.0.1", laddr.(*net.TCPAddr).IP.String(), "should match") //nolint:forcetypeassert + assert.True(t, laddr.(*net.TCPAddr).Port != 0, "should match") //nolint:forcetypeassert + assert.Equal(t, "127.0.0.1:1234", raddr.String(), "should match") + }) + t.Run("Dial (UDP) eth0", func(t *testing.T) { wan, err := NewRouter(&RouterConfig{ CIDR: "1.2.3.0/24", @@ -438,6 +640,80 @@ func TestNetVirtual(t *testing.T) { //nolint:gocyclo,cyclop,maintidx assert.Empty(t, nw.udpConns.size(), "should match") }) + t.Run("Dial (TCP) eth0", func(t *testing.T) { + wan, err := NewRouter(&RouterConfig{ + CIDR: "1.2.3.0/24", + LoggerFactory: loggerFactory, + }) + if !assert.NoError(t, err, "should succeed") { + return + } + + net1, err := NewNet(&NetConfig{}) + if !assert.NoError(t, err, "should succeed") { + return + } + assert.NoError(t, wan.AddNet(net1), "should succeed") + ip1, err := getIPAddr(net1) + assert.NoError(t, err, "should succeed") + + net2, err := NewNet(&NetConfig{}) + if !assert.NoError(t, err, "should succeed") { + return + } + assert.NoError(t, wan.AddNet(net2), "should succeed") + ip2, err := getIPAddr(net2) + assert.NoError(t, err, "should succeed") + + assert.NoError(t, wan.Start(), "should succeed") + defer func() { _ = wan.Stop() }() + + listener, err := net2.ListenTCP(tcp, &net.TCPAddr{IP: net.ParseIP(ip2), Port: 0}) + if !assert.NoError(t, err, "should succeed") { + return + } + defer func() { _ = listener.Close() }() + _ = 
listener.SetDeadline(time.Now().Add(2 * time.Second)) + + acceptedCh := make(chan net.Conn, 1) + go func() { + c, err2 := listener.Accept() + if err2 != nil { + close(acceptedCh) + + return + } + acceptedCh <- c + }() + + conn, err := net1.Dial(tcp, listener.Addr().String()) + assert.NoError(t, err, "should succeed") + defer func() { _ = conn.Close() }() + + var server net.Conn + select { + case server = <-acceptedCh: + case <-time.After(2 * time.Second): + assert.Fail(t, "accept timed out") + + return + } + if server != nil { + _ = server.Close() + } + + laddr := conn.LocalAddr() + log.Debugf("laddr: %s", laddr.String()) + + raddr := conn.RemoteAddr() + log.Debugf("raddr: %s", raddr.String()) + + assert.Equal(t, ip1, laddr.(*net.TCPAddr).IP.String(), "should match") //nolint:forcetypeassert + assert.True(t, laddr.(*net.TCPAddr).Port != 0, "should match") //nolint:forcetypeassert + listenerPort := listener.Addr().(*net.TCPAddr).Port // nolint:forcetypeassert + assert.Equal(t, fmt.Sprintf("%s:%d", ip2, listenerPort), raddr.String(), "should match") + }) + t.Run("DialUDP", func(t *testing.T) { nw, err := NewNet(&NetConfig{}) if !assert.NoError(t, err, "should succeed") { @@ -711,18 +987,67 @@ func TestNetVirtual(t *testing.T) { //nolint:gocyclo,cyclop,maintidx assert.Empty(t, nw.udpConns.size(), "should match") }) + t.Run("Dialer (TCP)", func(t *testing.T) { + nw, err := NewNet(&NetConfig{}) + if !assert.NoError(t, err, "should succeed") { + return + } + + listener, err := nw.ListenTCP(tcp, &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + if !assert.NoError(t, err, "should succeed") { + return + } + defer func() { _ = listener.Close() }() + _ = listener.SetDeadline(time.Now().Add(2 * time.Second)) + + acceptedCh := make(chan net.Conn, 1) + go func() { + c, err2 := listener.Accept() + if err2 != nil { + close(acceptedCh) + + return + } + acceptedCh <- c + }() + + dialer := nw.CreateDialer(&net.Dialer{ + LocalAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 
0}, + }) + + conn, err := dialer.Dial(tcp, listener.Addr().String()) + assert.NoError(t, err, "should succeed") + defer func() { _ = conn.Close() }() + + var server net.Conn + select { + case server = <-acceptedCh: + case <-time.After(2 * time.Second): + assert.Fail(t, "accept timed out") + + return + } + if server != nil { + _ = server.Close() + } + + laddr := conn.LocalAddr() + assert.Equal(t, "127.0.0.1", laddr.(*net.TCPAddr).IP.String(), "should match") //nolint:forcetypeassert + assert.True(t, laddr.(*net.TCPAddr).Port != 0, "should match") //nolint:forcetypeassert + }) + t.Run("Listen", func(t *testing.T) { nw, err := NewNet(&NetConfig{}) assert.Nil(t, err, "should succeed") listenConfig := nw.CreateListenConfig(&net.ListenConfig{}) - listener, err := listenConfig.Listen(context.Background(), "tcp4", "127.0.0.1:1234") + listener, err := listenConfig.Listen(context.Background(), tcp4, "127.0.0.1:1234") assert.NoError(t, err, "should succeed") laddr := listener.Addr() log.Debugf("laddr: %s", laddr.String()) - conn, err := net.Dial("tcp4", "127.0.0.1:1234") //nolint:noctx + conn, err := net.Dial(tcp4, "127.0.0.1:1234") //nolint:noctx assert.NoError(t, err, "should succeed") raddr := conn.RemoteAddr() diff --git a/vnet/tcp_conn.go b/vnet/tcp_conn.go new file mode 100644 index 00000000..5b06843f --- /dev/null +++ b/vnet/tcp_conn.go @@ -0,0 +1,534 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package vnet + +import ( + "errors" + "io" + "net" + "sync" + "time" + + "github.com/pion/transport/v4" +) + +type tcpState uint8 + +const ( + tcpStateInit tcpState = iota + tcpStateSynSent + tcpStateSynReceived + tcpStateEstablished + tcpStateClosed +) + +var ( + errConnectionReset = errors.New("connection reset") + errConnectionNotEstablished = errors.New("connection not established") +) + +// TCPConn implements transport.TCPConn. 
+type TCPConn struct { + locAddr *net.TCPAddr + remAddr *net.TCPAddr + obs connObserver + + mu sync.Mutex + state tcpState + inboundCh chan tcpSegment + curSeg *tcpSegment + curSegOffset int + readClosed bool // remote has closed + readChClosed bool + writeClosed bool + closed bool + readDeadline time.Time + writeDeadline time.Time + + nextSeq uint32 + pendingAcks map[uint32]chan struct{} + + // client connect flow + establishedCh chan struct{} + establishErr error + + // server side: notify listener once established + onEstablished func(*TCPConn) +} + +type tcpSegment struct { + seq uint32 + data []byte +} + +const tcpInboundQueueSize = 10 + +var _ transport.TCPConn = &TCPConn{} + +func newTCPConn(locAddr, remAddr *net.TCPAddr, obs connObserver, onEstablished func(*TCPConn)) (*TCPConn, error) { + if obs == nil { + return nil, errObsCannotBeNil + } + + conn := &TCPConn{ + locAddr: locAddr, + remAddr: remAddr, + obs: obs, + state: tcpStateInit, + inboundCh: make(chan tcpSegment, tcpInboundQueueSize), + establishedCh: make(chan struct{}), + onEstablished: onEstablished, + readClosed: false, + writeClosed: false, + closed: false, + nextSeq: 0, + pendingAcks: map[uint32]chan struct{}{}, + } + + return conn, nil +} + +func (c *TCPConn) startClientHandshake() error { + c.mu.Lock() + if c.closed { + c.mu.Unlock() + + return errUseClosedNetworkConn + } + if c.locAddr.IP == nil { + c.mu.Unlock() + + return errLocAddr + } + + c.state = tcpStateSynSent + c.mu.Unlock() + + src := &net.TCPAddr{IP: c.locAddr.IP, Port: c.locAddr.Port} + dst := &net.TCPAddr{IP: c.remAddr.IP, Port: c.remAddr.Port} + + syn := newChunkTCP(src, dst, tcpSYN) + + return c.obs.write(syn) +} + +func (c *TCPConn) waitEstablished() error { + <-c.establishedCh + c.mu.Lock() + defer c.mu.Unlock() + + return c.establishErr +} + +func (c *TCPConn) onInboundChunk(chunk *chunkTCP) { // nolint:cyclop,gocognit + c.mu.Lock() + defer c.mu.Unlock() + + if c.closed { + return + } + + // RST aborts connection 
immediately + if chunk.flags&tcpRST != 0 { + c.establishErr = &net.OpError{Op: "dial", Net: tcp, Addr: c.remAddr, Err: errConnectionReset} + c.closed = true + c.state = tcpStateClosed + c.readClosed = true + c.closeReadChLocked() + c.closePendingAcksLocked() + select { + case <-c.establishedCh: + default: + close(c.establishedCh) + } + + return + } + + // handshake + if c.state == tcpStateSynSent { + if chunk.flags&(tcpSYN|tcpACK) == (tcpSYN | tcpACK) { + c.state = tcpStateEstablished + // reply ACK + src := &net.TCPAddr{IP: c.locAddr.IP, Port: c.locAddr.Port} + dst := &net.TCPAddr{IP: c.remAddr.IP, Port: c.remAddr.Port} + ack := newChunkTCP(src, dst, tcpACK) + go func() { _ = c.obs.write(ack) }() + select { + case <-c.establishedCh: + default: + close(c.establishedCh) + } + + return + } + } + + if c.state == tcpStateSynReceived { + if chunk.flags&tcpACK != 0 && chunk.flags&tcpSYN == 0 { + c.state = tcpStateEstablished + cb := c.onEstablished + if cb != nil { + go cb(c) + } + // Do not return here; the first ACK may also carry data (PSH). + } + } + + if chunk.flags&tcpFIN != 0 { + c.readClosed = true + select { + case <-c.establishedCh: + default: + // if the other side closed before connect completed + c.establishErr = io.EOF + close(c.establishedCh) + } + c.closeReadChLocked() + + return + } + + // Data ACK + if chunk.flags&tcpACK != 0 && chunk.ackNum != 0 { + if ch, ok := c.pendingAcks[chunk.ackNum]; ok { + delete(c.pendingAcks, chunk.ackNum) + close(ch) + } + // ACK may accompany other flags (e.g. PSH), so don't return. 
+ } + + if chunk.flags&tcpPSH != 0 && len(chunk.userData) > 0 { + payload := make([]byte, len(chunk.userData)) + copy(payload, chunk.userData) + if !c.readChClosed { + seg := tcpSegment{seq: chunk.seqNum, data: payload} + select { + case c.inboundCh <- seg: + default: + // drop if the receive queue is full + } + } + + return + } +} + +func (c *TCPConn) closeReadChLocked() { + if c.readChClosed { + return + } + c.readChClosed = true + close(c.inboundCh) +} + +func (c *TCPConn) closePendingAcksLocked() { + for _, ch := range c.pendingAcks { + close(ch) + } + clear(c.pendingAcks) +} + +// Read reads data from the connection. +func (c *TCPConn) Read(b []byte) (int, error) { // nolint:gocognit,cyclop + for { + var ack *chunkTCP + c.mu.Lock() + if c.closed { + c.mu.Unlock() + + return 0, &net.OpError{Op: "read", Net: tcp, Addr: c.locAddr, Err: errUseClosedNetworkConn} + } + // Serve current segment if present. + if c.curSeg != nil { + remaining := c.curSeg.data[c.curSegOffset:] + n := copy(b, remaining) + c.curSegOffset += n + if c.curSegOffset >= len(c.curSeg.data) { + // ACK after the segment has been read. + src := &net.TCPAddr{IP: c.locAddr.IP, Port: c.locAddr.Port} + dst := &net.TCPAddr{IP: c.remAddr.IP, Port: c.remAddr.Port} + ack = newChunkTCP(src, dst, tcpACK) + ack.ackNum = c.curSeg.seq + c.curSeg = nil + c.curSegOffset = 0 + } + c.mu.Unlock() + if ack != nil { + _ = c.obs.write(ack) + } + + return n, nil + } + + inboundCh := c.inboundCh + deadline := c.readDeadline + c.mu.Unlock() + + // Wait for the next segment. 
+ if !deadline.IsZero() { // nolint:nestif + until := time.Until(deadline) + if until <= 0 { + return 0, &net.OpError{Op: "read", Net: tcp, Addr: c.locAddr, Err: newTimeoutError("i/o timeout")} + } + timer := time.NewTimer(until) + select { + case seg, ok := <-inboundCh: + if !timer.Stop() { + <-timer.C + } + if !ok { + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return 0, &net.OpError{Op: "read", Net: tcp, Addr: c.locAddr, Err: errUseClosedNetworkConn} + } + + return 0, io.EOF + } + c.mu.Lock() + c.curSeg = &seg + c.curSegOffset = 0 + c.mu.Unlock() + + continue + case <-timer.C: + return 0, &net.OpError{Op: "read", Net: tcp, Addr: c.locAddr, Err: newTimeoutError("i/o timeout")} + } + } + + seg, ok := <-inboundCh + if !ok { + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return 0, &net.OpError{Op: "read", Net: tcp, Addr: c.locAddr, Err: errUseClosedNetworkConn} + } + + return 0, io.EOF + } + c.mu.Lock() + c.curSeg = &seg + c.curSegOffset = 0 + c.mu.Unlock() + } +} + +// Write writes data to the connection. 
+func (c *TCPConn) Write(b []byte) (int, error) { // nolint:cyclop + c.mu.Lock() + if c.closed { + c.mu.Unlock() + + return 0, &net.OpError{Op: "write", Net: tcp, Addr: c.locAddr, Err: errUseClosedNetworkConn} + } + if c.writeClosed { + c.mu.Unlock() + + return 0, io.ErrClosedPipe + } + if c.state != tcpStateEstablished { + c.mu.Unlock() + + return 0, errConnectionNotEstablished + } + + seq := c.nextSeq + 1 + c.nextSeq = seq + ackCh := make(chan struct{}) + c.pendingAcks[seq] = ackCh + deadline := c.writeDeadline + + payload := make([]byte, len(b)) + copy(payload, b) + src := &net.TCPAddr{IP: c.locAddr.IP, Port: c.locAddr.Port} + dst := &net.TCPAddr{IP: c.remAddr.IP, Port: c.remAddr.Port} + chunk := newChunkTCP(src, dst, tcpPSH|tcpACK) + chunk.userData = payload + chunk.seqNum = seq + c.mu.Unlock() + + if err := c.obs.write(chunk); err != nil { + c.mu.Lock() + if ch, ok := c.pendingAcks[seq]; ok { + delete(c.pendingAcks, seq) + close(ch) + } + c.mu.Unlock() + + return 0, err + } + + if !deadline.IsZero() { + until := time.Until(deadline) + if until <= 0 { + return 0, &net.OpError{Op: "write", Net: tcp, Addr: c.locAddr, Err: newTimeoutError("i/o timeout")} + } + timer := time.NewTimer(until) + select { + case <-ackCh: + if !timer.Stop() { + <-timer.C + } + c.mu.Lock() + closed := c.closed + c.mu.Unlock() + if closed { + return 0, &net.OpError{Op: "write", Net: tcp, Addr: c.locAddr, Err: errUseClosedNetworkConn} + } + + return len(b), nil + case <-timer.C: + c.mu.Lock() + if ch, ok := c.pendingAcks[seq]; ok { + delete(c.pendingAcks, seq) + close(ch) + } + c.mu.Unlock() + + return 0, &net.OpError{Op: "write", Net: tcp, Addr: c.locAddr, Err: newTimeoutError("i/o timeout")} + } + } + + <-ackCh + c.mu.Lock() + closed := c.closed + c.mu.Unlock() + if closed { + return 0, &net.OpError{Op: "write", Net: tcp, Addr: c.locAddr, Err: errUseClosedNetworkConn} + } + + return len(b), nil +} + +// Close closes the connection. 
+func (c *TCPConn) Close() error { + c.mu.Lock() + if c.closed { + c.mu.Unlock() + + return errAlreadyClosed + } + c.closed = true + c.state = tcpStateClosed + c.readClosed = true + c.writeClosed = true + c.closeReadChLocked() + c.closePendingAcksLocked() + c.mu.Unlock() + + // Best-effort FIN + _ = c.CloseWrite() + if n, ok := c.obs.(*Net); ok { + _ = n.tcpConns.deleteConn(c) + } + + return nil +} + +// CloseRead closes the read side of the connection. +func (c *TCPConn) CloseRead() error { + c.mu.Lock() + defer c.mu.Unlock() + c.readClosed = true + c.closeReadChLocked() + + return nil +} + +// CloseWrite closes the write side of the connection. +func (c *TCPConn) CloseWrite() error { + c.mu.Lock() + if c.closed { + c.mu.Unlock() + + return errUseClosedNetworkConn + } + if c.writeClosed { + c.mu.Unlock() + + return nil + } + c.writeClosed = true + src := &net.TCPAddr{IP: c.locAddr.IP, Port: c.locAddr.Port} + dst := &net.TCPAddr{IP: c.remAddr.IP, Port: c.remAddr.Port} + fin := newChunkTCP(src, dst, tcpFIN|tcpACK) + c.mu.Unlock() + + return c.obs.write(fin) +} + +// LocalAddr returns the local network address. +func (c *TCPConn) LocalAddr() net.Addr { + return c.locAddr +} + +// RemoteAddr returns the remote network address. +func (c *TCPConn) RemoteAddr() net.Addr { + return c.remAddr +} + +// SetDeadline sets the read and write deadlines associated with the connection. +func (c *TCPConn) SetDeadline(t time.Time) error { + if err := c.SetReadDeadline(t); err != nil { + return err + } + + return c.SetWriteDeadline(t) +} + +// SetReadDeadline sets the deadline for future Read calls. +func (c *TCPConn) SetReadDeadline(t time.Time) error { + c.mu.Lock() + c.readDeadline = t + c.mu.Unlock() + + return nil +} + +// SetWriteDeadline sets the deadline for future Write calls. +func (c *TCPConn) SetWriteDeadline(t time.Time) error { + c.mu.Lock() + c.writeDeadline = t + c.mu.Unlock() + + return nil +} + +// ReadFrom reads data from r and writes it to the connection. 
+func (c *TCPConn) ReadFrom(r io.Reader) (int64, error) { + return io.Copy(c, r) +} + +// SetLinger sets the behavior of Close method on a connection with pending data. +func (c *TCPConn) SetLinger(int) error { + return transport.ErrNotSupported +} + +// SetKeepAlive enables or disables the keep-alive functionality for this connection. +func (c *TCPConn) SetKeepAlive(bool) error { + return transport.ErrNotSupported +} + +// SetKeepAlivePeriod sets the period between keep-alive messages for this connection. +func (c *TCPConn) SetKeepAlivePeriod(time.Duration) error { + return transport.ErrNotSupported +} + +// SetNoDelay enables or disables the Nagle's algorithm for this connection. +func (c *TCPConn) SetNoDelay(bool) error { + return transport.ErrNotSupported +} + +// SetWriteBuffer sets the size of the operating system's transmit buffer associated +// with the connection. +func (c *TCPConn) SetWriteBuffer(int) error { + return transport.ErrNotSupported +} + +// SetReadBuffer sets the size of the operating system's receive buffer associated +// with the connection. 
+func (c *TCPConn) SetReadBuffer(int) error { + return transport.ErrNotSupported +} diff --git a/vnet/tcp_conn_test.go b/vnet/tcp_conn_test.go new file mode 100644 index 00000000..a25a2a7b --- /dev/null +++ b/vnet/tcp_conn_test.go @@ -0,0 +1,606 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package vnet + +import ( + "bytes" + "errors" + "io" + "net" + "sync/atomic" + "testing" + "time" + + "github.com/pion/logging" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var errFailedToConvertToChunkTCP = errors.New("failed to convert chunk to chunkTCP") + +func newAckingEchoTCPObserver(connPtr **TCPConn) *dummyObserver { + return &dummyObserver{ + onWrite: func(c Chunk) error { + conn := *connPtr + if conn == nil { + return errors.New("tcp conn is nil") // nolint:err113 + } + + tc, ok := c.(*chunkTCP) + if !ok { + return errFailedToConvertToChunkTCP + } + + // Immediately ACK the sent segment as if the remote read it. + if tc.flags&tcpPSH != 0 && tc.seqNum != 0 { + dstAddr := tc.DestinationAddr().(*net.TCPAddr) //nolint:forcetypeassert + srcAddr := tc.SourceAddr().(*net.TCPAddr) //nolint:forcetypeassert + ack := newChunkTCP(dstAddr, srcAddr, tcpACK) + ack.ackNum = tc.seqNum + conn.onInboundChunk(ack) + } + + // Echo back payload as if it came from the remote. 
+ dstAddr := tc.DestinationAddr().(*net.TCPAddr) //nolint:forcetypeassert + srcAddr := tc.SourceAddr().(*net.TCPAddr) //nolint:forcetypeassert + echo := newChunkTCP(dstAddr, srcAddr, tcpPSH|tcpACK) + echo.userData = make([]byte, len(tc.userData)) + copy(echo.userData, tc.userData) + conn.onInboundChunk(echo) + + return nil + }, + onOnClosed: func(net.Addr) {}, + } +} + +func TestTCPConn(t *testing.T) { //nolint:cyclop,maintidx,gocyclo + log := logging.NewDefaultLoggerFactory().NewLogger("test") + + t.Run("ReadFrom Read", func(t *testing.T) { + var conn *TCPConn + data := []byte("Hello") + srcAddr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1234} + dstAddr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 5678} + + obs := newAckingEchoTCPObserver(&conn) + + var err error + conn, err = newTCPConn(srcAddr, dstAddr, obs, nil) + assert.NoError(t, err, "should succeed") + + conn.mu.Lock() + conn.state = tcpStateEstablished + conn.mu.Unlock() + + rcvdCh := make(chan struct{}) + doneCh := make(chan struct{}) + + go func() { + buf := make([]byte, 1500) + + for { + n, err2 := conn.Read(buf) + if err2 != nil { + log.Debug("conn closed. 
exiting the read loop") + + break + } + log.Debug("read data") + assert.Equal(t, len(data), n, "should match") + assert.Equal(t, string(data), string(buf[:n]), "should match") + rcvdCh <- struct{}{} + } + + close(doneCh) + }() + + n, err := conn.ReadFrom(bytes.NewReader(data)) + if !assert.NoError(t, err, "should succeed") { + return + } + assert.Equal(t, int64(len(data)), n, "should match") + + loop: + for { + select { + case <-rcvdCh: + log.Debug("closing conn..") + err2 := conn.Close() + assert.Nil(t, err2, "should succeed") + case <-doneCh: + break loop + } + } + }) + + t.Run("Write Read", func(t *testing.T) { + var conn *TCPConn + var err error + data := []byte("Hello") + srcAddr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1234} + dstAddr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 5678} + + obs := newAckingEchoTCPObserver(&conn) + + conn, err = newTCPConn(srcAddr, dstAddr, obs, nil) + assert.NoError(t, err, "should succeed") + + conn.mu.Lock() + conn.state = tcpStateEstablished + conn.mu.Unlock() + + rcvdCh := make(chan struct{}) + doneCh := make(chan struct{}) + + go func() { + buf := make([]byte, 1500) + + for { + n, err2 := conn.Read(buf) + if err2 != nil { + log.Debug("conn closed. 
exiting the read loop") + + break + } + log.Debug("read data") + assert.Equal(t, len(data), n, "should match") + assert.Equal(t, string(data), string(buf[:n]), "should match") + rcvdCh <- struct{}{} + } + + close(doneCh) + }() + + var n int + n, err = conn.Write(data) + if !assert.Nil(t, err, "should succeed") { + return + } + assert.Equal(t, len(data), n, "should match") + + loop: + for { + select { + case <-rcvdCh: + log.Debug("closing conn..") + err = conn.Close() + assert.Nil(t, err, "should succeed") + case <-doneCh: + break loop + } + } + }) + + deadlineTest := func(t *testing.T, readOnly bool) { + t.Helper() + + var conn *TCPConn + srcAddr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1234} + dstAddr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 5678} + + obs := &dummyObserver{ + onWrite: func(Chunk) error { return nil }, + onOnClosed: func(net.Addr) {}, + } + + var err error + conn, err = newTCPConn(srcAddr, dstAddr, obs, nil) + assert.NoError(t, err, "should succeed") + + conn.mu.Lock() + conn.state = tcpStateEstablished + conn.mu.Unlock() + + doneCh := make(chan struct{}) + + if readOnly { + err = conn.SetReadDeadline(time.Now().Add(50 * time.Millisecond)) + } else { + err = conn.SetDeadline(time.Now().Add(50 * time.Millisecond)) + } + assert.Nil(t, err, "should succeed") + + go func() { + buf := make([]byte, 1500) + _, err2 := conn.Read(buf) + assert.NotNil(t, err2, "should return error") + var ne *net.OpError + if errors.As(err2, &ne) { + assert.True(t, ne.Timeout(), "should be a timeout") + } else { + assert.True(t, false, "should be an net.OpError") + } + + assert.Nil(t, conn.Close(), "should succeed") + close(doneCh) + }() + + <-doneCh + } + + t.Run("SetReadDeadline", func(t *testing.T) { + deadlineTest(t, true) + }) + + t.Run("SetDeadline", func(t *testing.T) { + deadlineTest(t, false) + }) + + t.Run("SetWriteDeadline", func(t *testing.T) { + srcAddr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1234} + dstAddr := &net.TCPAddr{IP: 
net.ParseIP("127.0.0.1"), Port: 5678} + + obs := &dummyObserver{ + onWrite: func(Chunk) error { return nil }, + onOnClosed: func(net.Addr) {}, + } + + conn, err := newTCPConn(srcAddr, dstAddr, obs, nil) + assert.NoError(t, err, "should succeed") + + conn.mu.Lock() + conn.state = tcpStateEstablished + conn.mu.Unlock() + + err = conn.SetWriteDeadline(time.Now().Add(50 * time.Millisecond)) + assert.NoError(t, err, "should succeed") + + _, err = conn.Write([]byte("blocked")) + assert.Error(t, err, "should timeout") + var ne *net.OpError + if errors.As(err, &ne) { + assert.True(t, ne.Timeout(), "should be a timeout") + } else { + assert.True(t, false, "should be a net.OpError") + } + + assert.NoError(t, conn.Close(), "should succeed") + }) + + t.Run("Write blocks until peer reads", func(t *testing.T) { + msg := []byte("Hello") + addrA := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1234} + addrB := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 5678} + + var connA *TCPConn + var connB *TCPConn + + obsA := &dummyObserver{ + onWrite: func(c Chunk) error { + tc, ok := c.(*chunkTCP) + if !ok { + return errFailedToConvertToChunkTCP + } + // Deliver to peer. + connB.onInboundChunk(tc.Clone().(*chunkTCP)) //nolint:forcetypeassert + + return nil + }, + onOnClosed: func(net.Addr) {}, + } + + obsB := &dummyObserver{ + onWrite: func(c Chunk) error { + tc, ok := c.(*chunkTCP) + if !ok { + return errFailedToConvertToChunkTCP + } + // Deliver to peer. 
+ connA.onInboundChunk(tc.Clone().(*chunkTCP)) //nolint:forcetypeassert + + return nil + }, + onOnClosed: func(net.Addr) {}, + } + + var err error + connA, err = newTCPConn(addrA, addrB, obsA, nil) + assert.NoError(t, err, "should succeed") + connB, err = newTCPConn(addrB, addrA, obsB, nil) + assert.NoError(t, err, "should succeed") + + connA.mu.Lock() + connA.state = tcpStateEstablished + connA.mu.Unlock() + connB.mu.Lock() + connB.state = tcpStateEstablished + connB.mu.Unlock() + + writeDone := make(chan error, 1) + go func() { + _, err2 := connA.Write(msg) + writeDone <- err2 + }() + + // Should still be blocked (no read => no ACK). + select { + case err2 := <-writeDone: + assert.Fail(t, "Write returned before peer read", "%v", err2) + + return + case <-time.After(200 * time.Millisecond): + } + + _ = connB.SetReadDeadline(time.Now().Add(2 * time.Second)) + buf := make([]byte, len(msg)) + _, err = io.ReadFull(connB, buf) + assert.NoError(t, err, "should succeed") + assert.Equal(t, msg, buf, "should match") + + select { + case err2 := <-writeDone: + assert.NoError(t, err2, "should succeed") + case <-time.After(2 * time.Second): + assert.Fail(t, "Write did not unblock after peer read") + + return + } + + assert.NoError(t, connA.Close(), "should succeed") + assert.NoError(t, connB.Close(), "should succeed") + }) + + t.Run("RST", func(t *testing.T) { + srcAddr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1234} + dstAddr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 5678} + + obs := &dummyObserver{ + onWrite: func(Chunk) error { return nil }, + onOnClosed: func(net.Addr) {}, + } + + conn, err := newTCPConn(srcAddr, dstAddr, obs, nil) + assert.NoError(t, err, "should succeed") + + conn.mu.Lock() + conn.state = tcpStateEstablished + conn.mu.Unlock() + + writeDone := make(chan error, 1) + go func() { + _, err2 := conn.Write([]byte("data")) + writeDone <- err2 + }() + + select { + case err2 := <-writeDone: + assert.Fail(t, "Write returned before RST", "%v", 
err2) + + return + case <-time.After(100 * time.Millisecond): + } + + rst := newChunkTCP(dstAddr, srcAddr, tcpRST) + conn.onInboundChunk(rst) + + select { + case err2 := <-writeDone: + assert.Error(t, err2, "should error") + var ne *net.OpError + if errors.As(err2, &ne) { + assert.Equal(t, "write", ne.Op, "should match") + assert.Equal(t, errUseClosedNetworkConn, ne.Err, "should match") + } else { + assert.True(t, false, "should be a net.OpError") + } + case <-time.After(2 * time.Second): + assert.Fail(t, "Write did not unblock after RST") + + return + } + + buf := make([]byte, 10) + _, err = conn.Read(buf) + assert.Error(t, err, "should error") + _, err = conn.Write([]byte("x")) + assert.Error(t, err, "should error") + }) + + t.Run("ReadClosed (CloseRead)", func(t *testing.T) { + srcAddr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1234} + dstAddr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 5678} + + var conn *TCPConn + obs := &dummyObserver{ + onWrite: func(c Chunk) error { + tc, ok := c.(*chunkTCP) + if !ok { + return errFailedToConvertToChunkTCP + } + // ACK writes so Write doesn't block. + if tc.flags&tcpPSH != 0 && tc.seqNum != 0 { + dstAddr := tc.DestinationAddr().(*net.TCPAddr) //nolint:forcetypeassert + srcAddr := tc.SourceAddr().(*net.TCPAddr) //nolint:forcetypeassert + ack := newChunkTCP(dstAddr, srcAddr, tcpACK) + ack.ackNum = tc.seqNum + conn.onInboundChunk(ack) + } + + return nil + }, + onOnClosed: func(net.Addr) {}, + } + + var err error + conn, err = newTCPConn(srcAddr, dstAddr, obs, nil) + assert.NoError(t, err, "should succeed") + conn.mu.Lock() + conn.state = tcpStateEstablished + conn.mu.Unlock() + + assert.NoError(t, conn.CloseRead(), "should succeed") + + buf := make([]byte, 10) + _, err = conn.Read(buf) + assert.Equal(t, io.EOF, err, "should EOF") + + // Write side still usable until closed. 
+ _, err = conn.Write([]byte("ok")) + assert.NoError(t, err, "should succeed") + assert.NoError(t, conn.Close(), "should succeed") + }) + + t.Run("ReadClosed (remote FIN)", func(t *testing.T) { + srcAddr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1234} + dstAddr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 5678} + + obs := &dummyObserver{ + onWrite: func(Chunk) error { return nil }, + onOnClosed: func(net.Addr) {}, + } + + conn, err := newTCPConn(srcAddr, dstAddr, obs, nil) + assert.NoError(t, err, "should succeed") + conn.mu.Lock() + conn.state = tcpStateEstablished + conn.mu.Unlock() + + fin := newChunkTCP(dstAddr, srcAddr, tcpFIN|tcpACK) + conn.onInboundChunk(fin) + + buf := make([]byte, 10) + _, err = conn.Read(buf) + assert.Equal(t, io.EOF, err, "should EOF") + assert.NoError(t, conn.Close(), "should succeed") + }) + + t.Run("WriteClosed (CloseWrite)", func(t *testing.T) { + srcAddr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1234} + dstAddr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 5678} + + obs := &dummyObserver{ + onWrite: func(Chunk) error { return nil }, + onOnClosed: func(net.Addr) {}, + } + + conn, err := newTCPConn(srcAddr, dstAddr, obs, nil) + assert.NoError(t, err, "should succeed") + conn.mu.Lock() + conn.state = tcpStateEstablished + conn.mu.Unlock() + + assert.NoError(t, conn.CloseWrite(), "should succeed") + _, err = conn.Write([]byte("nope")) + assert.Equal(t, io.ErrClosedPipe, err, "should match") + assert.NoError(t, conn.Close(), "should succeed") + }) + + t.Run("Inbound during close", func(t *testing.T) { + var nClosed int32 + var conn *TCPConn + srcAddr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1234} + dstAddr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 5678} + + obs := &dummyObserver{ + onWrite: func(Chunk) error { return nil }, + onOnClosed: func(net.Addr) { + atomic.AddInt32(&nClosed, 1) + }, + } + + var err error + conn, err = newTCPConn(srcAddr, dstAddr, obs, nil) + assert.NoError(t, 
err, "should succeed") + + conn.mu.Lock() + conn.state = tcpStateEstablished + conn.mu.Unlock() + + fin := newChunkTCP(dstAddr, srcAddr, tcpFIN|tcpACK) + psh := newChunkTCP(dstAddr, srcAddr, tcpPSH|tcpACK) + psh.userData = []byte("x") + + for i := 0; i < 1000; i++ { // nolint:staticcheck // (false positive detection) + chDone := make(chan struct{}) + go func() { + time.Sleep(20 * time.Millisecond) + assert.NoError(t, conn.Close()) + close(chDone) + }() + tick := time.NewTicker(10 * time.Millisecond) + for { + defer tick.Stop() + select { + case <-chDone: + // TCPConn doesn't currently notify the observer via onClosed. + assert.Equal(t, int32(0), atomic.LoadInt32(&nClosed), "should not invoke onClosed") + + return + case <-tick.C: + conn.onInboundChunk(psh) + conn.onInboundChunk(fin) + } + } + } + }) +} + +func TestVNetTCPDialListen(t *testing.T) { + loggerFactory := logging.NewDefaultLoggerFactory() + + router, err := NewRouter(&RouterConfig{ + CIDR: "192.0.2.0/24", + LoggerFactory: loggerFactory, + }) + require.NoError(t, err) + require.NoError(t, router.Start()) + defer func() { _ = router.Stop() }() + + serverNet, err := NewNet(&NetConfig{}) + require.NoError(t, err) + clientNet, err := NewNet(&NetConfig{}) + require.NoError(t, err) + + require.NoError(t, router.AddNet(serverNet)) + require.NoError(t, router.AddNet(clientNet)) + + // Bind listener to server's assigned eth0 address so clients can route to it. 
+ eth0, err := serverNet.InterfaceByName("eth0") + require.NoError(t, err) + addrs, err := eth0.Addrs() + require.NoError(t, err) + require.NotEmpty(t, addrs) + serverIP := addrs[0].(*net.IPNet).IP //nolint:forcetypeassert + + ln, err := serverNet.ListenTCP(tcp4, &net.TCPAddr{IP: serverIP, Port: 0}) + require.NoError(t, err) + defer func() { _ = ln.Close() }() + + srvAddr := ln.Addr().(*net.TCPAddr) //nolint:forcetypeassert + + serverDone := make(chan struct{}) + go func() { + defer close(serverDone) + c, err2 := ln.AcceptTCP() + require.NoError(t, err2) + defer func() { _ = c.Close() }() + + buf := make([]byte, 5) + _, err2 = io.ReadFull(c, buf) + require.NoError(t, err2) + require.Equal(t, []byte("hello"), buf) + + _, err2 = c.Write([]byte("world")) + require.NoError(t, err2) + }() + + conn, err := clientNet.DialTCP(tcp4, nil, &net.TCPAddr{IP: srvAddr.IP, Port: srvAddr.Port}) + require.NoError(t, err) + defer func() { _ = conn.Close() }() + + _, err = conn.Write([]byte("hello")) + require.NoError(t, err) + + buf := make([]byte, 5) + _, err = io.ReadFull(conn, buf) + require.NoError(t, err) + require.Equal(t, []byte("world"), buf) + + select { + case <-serverDone: + case <-time.After(2 * time.Second): + require.FailNow(t, "server did not finish") + } +} diff --git a/vnet/tcp_listener.go b/vnet/tcp_listener.go new file mode 100644 index 00000000..36605bce --- /dev/null +++ b/vnet/tcp_listener.go @@ -0,0 +1,141 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package vnet + +import ( + "math" + "net" + "sync" + "time" + + "github.com/pion/transport/v4" +) + +// TCPListener implements transport.TCPListener. 
+type TCPListener struct { + locAddr *net.TCPAddr + obs *Net + + acceptCh chan *TCPConn + mu sync.Mutex + closed bool + timer *time.Timer +} + +var _ transport.TCPListener = &TCPListener{} + +func newTCPListener(locAddr *net.TCPAddr, obs *Net) (*TCPListener, error) { + if obs == nil { + return nil, errObsCannotBeNil + } + + return &TCPListener{ + locAddr: locAddr, + obs: obs, + acceptCh: make(chan *TCPConn, 64), + timer: time.NewTimer(time.Duration(math.MaxInt64)), + }, nil +} + +func (l *TCPListener) onInboundSYN(tcp *chunkTCP) { + l.mu.Lock() + if l.closed { + l.mu.Unlock() + + return + } + l.mu.Unlock() + + dst := tcp.DestinationAddr().(*net.TCPAddr) //nolint:forcetypeassert + src := tcp.SourceAddr().(*net.TCPAddr) //nolint:forcetypeassert + + // If listener is on 0.0.0.0, bind accepted conn to the destination IP. + loc := &net.TCPAddr{IP: dst.IP, Port: l.locAddr.Port} + rem := &net.TCPAddr{IP: src.IP, Port: src.Port} + + conn, err := newTCPConn(loc, rem, l.obs, func(c *TCPConn) { + l.mu.Lock() + defer l.mu.Unlock() + if l.closed { + _ = c.Close() + + return + } + l.acceptCh <- c + }) + if err != nil { + return + } + + conn.mu.Lock() + conn.state = tcpStateSynReceived + conn.mu.Unlock() + + // Register early so ACK/data finds the conn. + _ = l.obs.tcpConns.insert(conn) + + // Send SYN-ACK + synAck := newChunkTCP(loc, rem, tcpSYN|tcpACK) + _ = l.obs.write(synAck) +} + +// Accept waits for and returns the next connection to the listener. +func (l *TCPListener) Accept() (net.Conn, error) { + return l.AcceptTCP() +} + +// AcceptTCP waits for and returns the next TCP connection to the listener. 
+func (l *TCPListener) AcceptTCP() (transport.TCPConn, error) { + for { + l.mu.Lock() + if l.closed { + l.mu.Unlock() + + return nil, errUseClosedNetworkConn + } + l.mu.Unlock() + + select { + case c := <-l.acceptCh: + return c, nil + case <-l.timer.C: + return nil, &net.OpError{Op: "accept", Net: tcp, Addr: l.locAddr, Err: newTimeoutError("i/o timeout")} + } + } +} + +// Close closes the listener. Any blocked Accept operations will be unblocked and return errors. +func (l *TCPListener) Close() error { + l.mu.Lock() + if l.closed { + l.mu.Unlock() + + return errAlreadyClosed + } + l.closed = true + l.mu.Unlock() + + _ = l.obs.tcpListeners.delete(l.locAddr) + close(l.acceptCh) + + return nil +} + +// Addr returns the listener's network address. +func (l *TCPListener) Addr() net.Addr { + return l.locAddr +} + +// SetDeadline sets the deadline for future Accept calls. +func (l *TCPListener) SetDeadline(t time.Time) error { + var d time.Duration + if t.IsZero() { + d = time.Duration(math.MaxInt64) + } else { + d = time.Until(t) + } + l.timer.Reset(d) + + return nil +} diff --git a/vnet/tcp_map.go b/vnet/tcp_map.go new file mode 100644 index 00000000..3c7d4e07 --- /dev/null +++ b/vnet/tcp_map.go @@ -0,0 +1,239 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package vnet + +import ( + "errors" + "net" + "sync" +) + +var ( + errNoSuchTCPConn = errors.New("no such TCPConn") + errNoSuchTCPListener = errors.New("no such TCPListener") + errTCPConnAlreadyUsed = errors.New("tcp connection tuple already in use") +) + +type tcpListenerMap struct { + portMap map[int][]*TCPListener + mutex sync.RWMutex +} + +func newTCPListenerMap() *tcpListenerMap { + return &tcpListenerMap{portMap: map[int][]*TCPListener{}} +} + +func (m *tcpListenerMap) insert(listener *TCPListener) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + addr := listener.Addr().(*net.TCPAddr) //nolint:forcetypeassert + + listeners, ok := m.portMap[addr.Port] + if ok { + 
if addr.IP.IsUnspecified() { + return errAddressAlreadyInUse + } + for _, existing := range listeners { + eaddr := existing.Addr().(*net.TCPAddr) //nolint:forcetypeassert + if eaddr.IP.IsUnspecified() || eaddr.IP.Equal(addr.IP) { + return errAddressAlreadyInUse + } + } + listeners = append(listeners, listener) + } else { + listeners = []*TCPListener{listener} + } + + m.portMap[addr.Port] = listeners + + return nil +} + +func (m *tcpListenerMap) find(addr *net.TCPAddr) (*TCPListener, bool) { + m.mutex.RLock() + defer m.mutex.RUnlock() + + listeners, ok := m.portMap[addr.Port] + if !ok { + return nil, false + } + + if addr.IP.IsUnspecified() { + if len(listeners) == 0 { + return nil, false + } + + return listeners[0], true + } + + for _, l := range listeners { + eaddr := l.Addr().(*net.TCPAddr) //nolint:forcetypeassert + if eaddr.IP.IsUnspecified() || eaddr.IP.Equal(addr.IP) { + return l, true + } + } + + return nil, false +} + +func (m *tcpListenerMap) delete(addr net.Addr) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + tcpAddr := addr.(*net.TCPAddr) //nolint:forcetypeassert + listeners, ok := m.portMap[tcpAddr.Port] + if !ok { + return errNoSuchTCPListener + } + + if tcpAddr.IP.IsUnspecified() { + delete(m.portMap, tcpAddr.Port) + + return nil + } + + newListeners := []*TCPListener{} + for _, l := range listeners { + eaddr := l.Addr().(*net.TCPAddr) //nolint:forcetypeassert + if eaddr.IP.Equal(tcpAddr.IP) { + continue + } + newListeners = append(newListeners, l) + } + + if len(newListeners) == 0 { + delete(m.portMap, tcpAddr.Port) + } else { + m.portMap[tcpAddr.Port] = newListeners + } + + return nil +} + +type tcpConnMap struct { + portMap map[int][]*TCPConn + mutex sync.RWMutex +} + +func newTCPConnMap() *tcpConnMap { + return &tcpConnMap{portMap: map[int][]*TCPConn{}} +} + +func (m *tcpConnMap) insert(conn *TCPConn) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + laddr := conn.LocalAddr().(*net.TCPAddr) //nolint:forcetypeassert + raddr := 
conn.RemoteAddr().(*net.TCPAddr) //nolint:forcetypeassert + + conns := m.portMap[laddr.Port] + for _, existing := range conns { + eL := existing.LocalAddr().(*net.TCPAddr) //nolint:forcetypeassert + eR := existing.RemoteAddr().(*net.TCPAddr) //nolint:forcetypeassert + if eL.IP.Equal(laddr.IP) && eR.IP.Equal(raddr.IP) && eR.Port == raddr.Port { + return errTCPConnAlreadyUsed + } + } + + m.portMap[laddr.Port] = append(conns, conn) + + return nil +} + +func (m *tcpConnMap) findByChunk(tcp *chunkTCP) (*TCPConn, bool) { + m.mutex.RLock() + defer m.mutex.RUnlock() + + dst := tcp.DestinationAddr().(*net.TCPAddr) //nolint:forcetypeassert + src := tcp.SourceAddr().(*net.TCPAddr) //nolint:forcetypeassert + + conns, ok := m.portMap[dst.Port] + if !ok { + return nil, false + } + + for _, c := range conns { + laddr := c.LocalAddr().(*net.TCPAddr) //nolint:forcetypeassert + raddr := c.RemoteAddr().(*net.TCPAddr) //nolint:forcetypeassert + if (laddr.IP.IsUnspecified() || laddr.IP.Equal(dst.IP)) && raddr.IP.Equal(src.IP) && raddr.Port == src.Port { + return c, true + } + } + + return nil, false +} + +func (m *tcpConnMap) deleteConn(c *TCPConn) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + laddr := c.LocalAddr().(*net.TCPAddr) //nolint:forcetypeassert + conns, ok := m.portMap[laddr.Port] + if !ok { + return errNoSuchTCPConn + } + + newConns := []*TCPConn{} + for _, existing := range conns { + if existing == c { + continue + } + newConns = append(newConns, existing) + } + + if len(newConns) == 0 { + delete(m.portMap, laddr.Port) + } else { + m.portMap[laddr.Port] = newConns + } + + return nil +} + +func (m *tcpConnMap) deleteByAddr(addr net.Addr) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + tcpAddr := addr.(*net.TCPAddr) //nolint:forcetypeassert + conns, ok := m.portMap[tcpAddr.Port] + if !ok { + return errNoSuchTCPConn + } + + newConns := []*TCPConn{} + for _, c := range conns { + laddr := c.LocalAddr().(*net.TCPAddr) //nolint:forcetypeassert + if 
laddr.IP.Equal(tcpAddr.IP) { + continue + } + newConns = append(newConns, c) + } + + if len(newConns) == 0 { + delete(m.portMap, tcpAddr.Port) + } else { + m.portMap[tcpAddr.Port] = newConns + } + + return nil +} + +func (m *tcpConnMap) portInUse(ip net.IP, port int) bool { + m.mutex.RLock() + defer m.mutex.RUnlock() + + conns, ok := m.portMap[port] + if !ok { + return false + } + for _, c := range conns { + laddr := c.LocalAddr().(*net.TCPAddr) //nolint:forcetypeassert + if laddr.IP.IsUnspecified() || laddr.IP.Equal(ip) { + return true + } + } + + return false +} diff --git a/vnet/tcp_map_test.go b/vnet/tcp_map_test.go new file mode 100644 index 00000000..5309119b --- /dev/null +++ b/vnet/tcp_map_test.go @@ -0,0 +1,295 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package vnet + +import ( + "net" + "testing" + + "github.com/stretchr/testify/assert" +) + +func newTestTCPListener(ip string, port int) *TCPListener { + return &TCPListener{ + locAddr: &net.TCPAddr{IP: net.ParseIP(ip), Port: port}, + } +} + +func newTestTCPConn(locIP string, locPort int, remIP string, remPort int) *TCPConn { + return &TCPConn{ + locAddr: &net.TCPAddr{IP: net.ParseIP(locIP), Port: locPort}, + remAddr: &net.TCPAddr{IP: net.ParseIP(remIP), Port: remPort}, + } +} + +func findTCPConnByTuple(m *tcpConnMap, dstIP string, dstPort int, srcIP string, srcPort int) (*TCPConn, bool) { + c := newChunkTCP( + &net.TCPAddr{IP: net.ParseIP(srcIP), Port: srcPort}, + &net.TCPAddr{IP: net.ParseIP(dstIP), Port: dstPort}, + tcpACK, + ) + + return m.findByChunk(c) +} + +func TestTCPListenerMap(t *testing.T) { + t.Run("insert a TCPListener and remove it", func(t *testing.T) { + listenerMap := newTCPListenerMap() + + l1 := newTestTCPListener("127.0.0.1", 1234) + err := listenerMap.insert(l1) + assert.NoError(t, err, "should succeed") + + out, ok := listenerMap.find(l1.Addr().(*net.TCPAddr)) //nolint:forcetypeassert + assert.True(t, ok, "should succeed") + 
assert.Equal(t, l1, out, "should match") + assert.Equal(t, 1, len(listenerMap.portMap), "should match") + + err = listenerMap.delete(l1.Addr()) + assert.NoError(t, err, "should succeed") + assert.Empty(t, listenerMap.portMap, "should match") + + err = listenerMap.delete(l1.Addr()) + assert.Error(t, err, "should fail") + }) + + t.Run("insert a TCPListener on 0.0.0.0 and remove it", func(t *testing.T) { + listenerMap := newTCPListenerMap() + + l1 := newTestTCPListener("0.0.0.0", 1234) + err := listenerMap.insert(l1) + assert.NoError(t, err, "should succeed") + + out, ok := listenerMap.find(l1.Addr().(*net.TCPAddr)) //nolint:forcetypeassert + assert.True(t, ok, "should succeed") + assert.Equal(t, l1, out, "should match") + assert.Equal(t, 1, len(listenerMap.portMap), "should match") + + err = listenerMap.delete(l1.Addr()) + assert.NoError(t, err, "should succeed") + + err = listenerMap.delete(l1.Addr()) + assert.Error(t, err, "should fail") + }) + + t.Run("find TCPListener on 0.0.0.0 by specified IP", func(t *testing.T) { + listenerMap := newTCPListenerMap() + + l1 := newTestTCPListener("0.0.0.0", 1234) + err := listenerMap.insert(l1) + assert.NoError(t, err, "should succeed") + + out, ok := listenerMap.find(&net.TCPAddr{IP: net.ParseIP("192.168.0.1"), Port: 1234}) + assert.True(t, ok, "should succeed") + assert.Equal(t, l1, out, "should match") + assert.Equal(t, 1, len(listenerMap.portMap), "should match") + }) + + t.Run("insert many IPs with the same port", func(t *testing.T) { + listenerMap := newTCPListenerMap() + + l1 := newTestTCPListener("10.1.2.1", 5678) + err := listenerMap.insert(l1) + assert.NoError(t, err, "should succeed") + + l2 := newTestTCPListener("10.1.2.2", 5678) + err = listenerMap.insert(l2) + assert.NoError(t, err, "should succeed") + + out1, ok := listenerMap.find(&net.TCPAddr{IP: net.ParseIP("10.1.2.1"), Port: 5678}) + assert.True(t, ok, "should succeed") + assert.Equal(t, l1, out1, "should match") + + out2, ok := 
listenerMap.find(&net.TCPAddr{IP: net.ParseIP("10.1.2.2"), Port: 5678}) + assert.True(t, ok, "should succeed") + assert.Equal(t, l2, out2, "should match") + + assert.Equal(t, 1, len(listenerMap.portMap), "should match") + }) + + t.Run("already in-use when inserting 0.0.0.0", func(t *testing.T) { + listenerMap := newTCPListenerMap() + + l1 := newTestTCPListener("10.1.2.1", 5678) + err := listenerMap.insert(l1) + assert.NoError(t, err, "should succeed") + + l2 := newTestTCPListener("0.0.0.0", 5678) + err = listenerMap.insert(l2) + assert.Error(t, err, "should fail") + }) + + t.Run("already in-use when inserting a specified IP", func(t *testing.T) { + listenerMap := newTCPListenerMap() + + l1 := newTestTCPListener("0.0.0.0", 5678) + err := listenerMap.insert(l1) + assert.NoError(t, err, "should succeed") + + l2 := newTestTCPListener("192.168.0.1", 5678) + err = listenerMap.insert(l2) + assert.Error(t, err, "should fail") + }) + + t.Run("already in-use when inserting the same specified IP", func(t *testing.T) { + listenerMap := newTCPListenerMap() + + l1 := newTestTCPListener("192.168.0.1", 5678) + err := listenerMap.insert(l1) + assert.NoError(t, err, "should succeed") + + l2 := newTestTCPListener("192.168.0.1", 5678) + err = listenerMap.insert(l2) + assert.Error(t, err, "should fail") + }) + + t.Run("find failure 1", func(t *testing.T) { + listenerMap := newTCPListenerMap() + + l1 := newTestTCPListener("192.168.0.1", 5678) + err := listenerMap.insert(l1) + assert.NoError(t, err, "should succeed") + + _, ok := listenerMap.find(&net.TCPAddr{IP: net.ParseIP("192.168.0.2"), Port: 5678}) + assert.False(t, ok, "should fail") + }) + + t.Run("find failure 2", func(t *testing.T) { + listenerMap := newTCPListenerMap() + + l1 := newTestTCPListener("192.168.0.1", 5678) + err := listenerMap.insert(l1) + assert.NoError(t, err, "should succeed") + + _, ok := listenerMap.find(&net.TCPAddr{IP: net.ParseIP("192.168.0.1"), Port: 1234}) + assert.False(t, ok, "should fail") + }) + + 
t.Run("insert two TCPListeners on the same port, then remove them", func(t *testing.T) { + listenerMap := newTCPListenerMap() + + l1 := newTestTCPListener("192.168.0.1", 5678) + err := listenerMap.insert(l1) + assert.NoError(t, err, "should succeed") + + l2 := newTestTCPListener("192.168.0.2", 5678) + err = listenerMap.insert(l2) + assert.NoError(t, err, "should succeed") + + err = listenerMap.delete(l1.Addr()) + assert.NoError(t, err, "should succeed") + + err = listenerMap.delete(l2.Addr()) + assert.NoError(t, err, "should succeed") + }) +} + +func TestTCPConnMap(t *testing.T) { + t.Run("insert a TCPConn and remove it", func(t *testing.T) { + connMap := newTCPConnMap() + + c1 := newTestTCPConn("127.0.0.1", 1234, "127.0.0.1", 5678) + err := connMap.insert(c1) + assert.NoError(t, err, "should succeed") + + out, ok := findTCPConnByTuple(connMap, "127.0.0.1", 1234, "127.0.0.1", 5678) + assert.True(t, ok, "should succeed") + assert.Equal(t, c1, out, "should match") + assert.Equal(t, 1, len(connMap.portMap), "should match") + + err = connMap.deleteConn(c1) + assert.NoError(t, err, "should succeed") + assert.Empty(t, connMap.portMap, "should match") + + err = connMap.deleteConn(c1) + assert.Error(t, err, "should fail") + }) + + t.Run("insert a TCPConn on 0.0.0.0 and find it by specified IP", func(t *testing.T) { + connMap := newTCPConnMap() + + c1 := newTestTCPConn("0.0.0.0", 1234, "10.0.0.2", 5678) + err := connMap.insert(c1) + assert.NoError(t, err, "should succeed") + + out, ok := findTCPConnByTuple(connMap, "192.168.0.1", 1234, "10.0.0.2", 5678) + assert.True(t, ok, "should succeed") + assert.Equal(t, c1, out, "should match") + }) + + t.Run("insert many remote tuples with the same local port", func(t *testing.T) { + connMap := newTCPConnMap() + + c1 := newTestTCPConn("10.1.2.1", 5678, "10.1.2.100", 1111) + err := connMap.insert(c1) + assert.NoError(t, err, "should succeed") + + c2 := newTestTCPConn("10.1.2.1", 5678, "10.1.2.101", 2222) + err = connMap.insert(c2) + 
assert.NoError(t, err, "should succeed") + + out1, ok := findTCPConnByTuple(connMap, "10.1.2.1", 5678, "10.1.2.100", 1111) + assert.True(t, ok, "should succeed") + assert.Equal(t, c1, out1, "should match") + + out2, ok := findTCPConnByTuple(connMap, "10.1.2.1", 5678, "10.1.2.101", 2222) + assert.True(t, ok, "should succeed") + assert.Equal(t, c2, out2, "should match") + + assert.Equal(t, 1, len(connMap.portMap), "should match") + }) + + t.Run("already in-use when inserting the same tuple", func(t *testing.T) { + connMap := newTCPConnMap() + + c1 := newTestTCPConn("192.168.0.1", 5678, "192.168.0.2", 9999) + err := connMap.insert(c1) + assert.NoError(t, err, "should succeed") + + c2 := newTestTCPConn("192.168.0.1", 5678, "192.168.0.2", 9999) + err = connMap.insert(c2) + assert.Error(t, err, "should fail") + }) + + t.Run("find failure 1 (remote mismatch)", func(t *testing.T) { + connMap := newTCPConnMap() + + c1 := newTestTCPConn("192.168.0.1", 5678, "192.168.0.2", 9999) + err := connMap.insert(c1) + assert.NoError(t, err, "should succeed") + + _, ok := findTCPConnByTuple(connMap, "192.168.0.1", 5678, "192.168.0.3", 9999) + assert.False(t, ok, "should fail") + }) + + t.Run("find failure 2 (port mismatch)", func(t *testing.T) { + connMap := newTCPConnMap() + + c1 := newTestTCPConn("192.168.0.1", 5678, "192.168.0.2", 9999) + err := connMap.insert(c1) + assert.NoError(t, err, "should succeed") + + _, ok := findTCPConnByTuple(connMap, "192.168.0.1", 1234, "192.168.0.2", 9999) + assert.False(t, ok, "should fail") + }) + + t.Run("deleteByAddr removes only matching local IP", func(t *testing.T) { + connMap := newTCPConnMap() + + c1 := newTestTCPConn("192.168.0.1", 5678, "192.168.0.2", 1111) + c2 := newTestTCPConn("192.168.0.2", 5678, "192.168.0.3", 2222) + + assert.NoError(t, connMap.insert(c1)) + assert.NoError(t, connMap.insert(c2)) + + assert.NoError(t, connMap.deleteByAddr(&net.TCPAddr{IP: net.ParseIP("192.168.0.1"), Port: 5678})) + + _, ok1 := 
findTCPConnByTuple(connMap, "192.168.0.1", 5678, "192.168.0.2", 1111) + assert.False(t, ok1, "c1 should be removed") + + out2, ok2 := findTCPConnByTuple(connMap, "192.168.0.2", 5678, "192.168.0.3", 2222) + assert.True(t, ok2, "c2 should remain") + assert.Equal(t, c2, out2, "should match") + }) +} diff --git a/vnet/udpproxy_direct_test.go b/vnet/udpproxy_direct_test.go index ffc8baeb..43906c16 100644 --- a/vnet/udpproxy_direct_test.go +++ b/vnet/udpproxy_direct_test.go @@ -311,7 +311,7 @@ func TestUDPProxyDirectDeliverBadCase(t *testing.T) { //nolint:cyclop } // BadCase: Invalid address, error and ignore. - tcpAddr, err := net.ResolveTCPAddr("tcp4", "192.168.1.10:8000") + tcpAddr, err := net.ResolveTCPAddr(tcp4, "192.168.1.10:8000") if err != nil { return err } @@ -321,7 +321,7 @@ func TestUDPProxyDirectDeliverBadCase(t *testing.T) { //nolint:cyclop } // BadCase: Invalid target address, ignore. - udpAddr, err := net.ResolveUDPAddr("udp4", "10.0.0.12:5788") + udpAddr, err := net.ResolveUDPAddr(udp4, "10.0.0.12:5788") if err != nil { return err }