Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions vnet/chunk.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,8 @@ type chunkTCP struct {
destinationPort int
flags tcpFlag // control bits
userData []byte // only with PSH flag
// seq uint32 // always starts with 0
// ack uint32 // always starts with 0
seqNum uint32 // data sequence (vnet-internal)
ackNum uint32 // ACK for a seqNum (vnet-internal)
}

func newChunkTCP(srcAddr, dstAddr *net.TCPAddr, flags tcpFlag) *chunkTCP {
Expand Down Expand Up @@ -256,15 +256,20 @@ func (c *chunkTCP) Clone() Chunk {
timestamp: c.timestamp,
sourceIP: c.sourceIP,
destinationIP: c.destinationIP,
tag: c.tag,
duplicate: c.duplicate,
},
sourcePort: c.sourcePort,
destinationPort: c.destinationPort,
flags: c.flags,
userData: userData,
seqNum: c.seqNum,
ackNum: c.ackNum,
}
}

func (c *chunkTCP) Network() string {
return "tcp"
return tcp
}

func (c *chunkTCP) String() string {
Expand Down
2 changes: 1 addition & 1 deletion vnet/chunk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func TestChunk(t *testing.T) {
var chunk Chunk = newChunkTCP(src, dst, tcpSYN)
str := chunk.String()
log.Debugf("chunk: %s", str)
assert.Equal(t, "tcp", chunk.Network(), "should match")
assert.Equal(t, tcp, chunk.Network(), "should match")
assert.True(t, strings.Contains(str, src.Network()), "should include network type")
assert.True(t, strings.Contains(str, src.String()), "should include address")
assert.True(t, strings.Contains(str, dst.String()), "should include address")
Expand Down
2 changes: 1 addition & 1 deletion vnet/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ type timeoutError struct {
msg string
}

func newTimeoutError(msg string) error {
func newTimeoutError(msg string) error { // nolint:unparam
return &timeoutError{
msg: msg,
}
Expand Down
269 changes: 170 additions & 99 deletions vnet/nat.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ import (
)

var (
errNATRequriesMapping = errors.New("1:1 NAT requires more than one mapping")
errMismatchLengthIP = errors.New("length mismtach between mappedIPs and localIPs")
errNonUDPTranslationNotSupported = errors.New("non-udp translation is not supported yet")
errNoAssociatedLocalAddress = errors.New("no associated local address")
errNoNATBindingFound = errors.New("no NAT binding found")
errHasNoPermission = errors.New("has no permission")
errNATRequiresMapping = errors.New("1:1 NAT requires more than one mapping")
errMismatchLengthIP = errors.New("length mismatch between mappedIPs and localIPs")
errTranslationNotSupported = errors.New("translation is not supported for this protocol")
errNoAssociatedLocalAddress = errors.New("no associated local address")
errNoNATBindingFound = errors.New("no NAT binding found")
errHasNoPermission = errors.New("has no permission")
)

// EndpointDependencyType defines a type of behavioral dependendency on the
Expand Down Expand Up @@ -92,6 +92,7 @@ type networkAddressTranslator struct {
outboundMap map[string]*mapping // key: "<proto>:<local-ip>:<local-port>[:remote-ip[:remote-port]]
inboundMap map[string]*mapping // key: "<proto>:<mapped-ip>:<mapped-port>"
udpPortCounter int
tcpPortCounter int
mutex sync.RWMutex
log logging.LeveledLogger
}
Expand All @@ -107,7 +108,7 @@ func newNAT(config *natConfig) (*networkAddressTranslator, error) {
natType.MappingLifeTime = 0

if len(config.mappedIPs) == 0 {
return nil, errNATRequriesMapping
return nil, errNATRequiresMapping
}
if len(config.mappedIPs) != len(config.localIPs) {
return nil, errMismatchLengthIP
Expand Down Expand Up @@ -151,13 +152,70 @@ func (n *networkAddressTranslator) getPairedLocalIP(mappedIP net.IP) net.IP {
return nil
}

func (n *networkAddressTranslator) translateOutbound(from Chunk) (Chunk, error) { //nolint:cyclop
func (n *networkAddressTranslator) translateOutbound(from Chunk) (Chunk, error) { //nolint:cyclop,gocognit
n.mutex.Lock()
defer n.mutex.Unlock()

to := from.Clone()

if from.Network() == udp { //nolint:nestif
translateOutboundNAPT := func(proto string, portBase int, portCounter *int) (Chunk, error) {
var bound, filterKey string
switch n.natType.MappingBehavior {
case EndpointIndependent:
bound = ""
case EndpointAddrDependent:
bound = from.getDestinationIP().String()
case EndpointAddrPortDependent:
bound = from.DestinationAddr().String()
}

switch n.natType.FilteringBehavior {
case EndpointIndependent:
filterKey = ""
case EndpointAddrDependent:
filterKey = from.getDestinationIP().String()
case EndpointAddrPortDependent:
filterKey = from.DestinationAddr().String()
}

oKey := fmt.Sprintf("%s:%s:%s", proto, from.SourceAddr().String(), bound)

mapp := n.findOutboundMapping(oKey)
if mapp == nil {
mappedPort := portBase + *portCounter
(*portCounter)++

mapp = &mapping{
proto: from.SourceAddr().Network(),
local: from.SourceAddr().String(),
bound: bound,
mapped: fmt.Sprintf("%s:%d", n.mappedIPs[0].String(), mappedPort),
filters: map[string]struct{}{},
expires: time.Now().Add(n.natType.MappingLifeTime),
}

n.outboundMap[oKey] = mapp
iKey := fmt.Sprintf("%s:%s", proto, mapp.mapped)

n.log.Debugf("[%s] created a new NAT binding oKey=%s iKey=%s", n.name, oKey, iKey)

mapp.filters[filterKey] = struct{}{}
n.log.Debugf("[%s] permit access from %s to %s", n.name, filterKey, mapp.mapped)
n.inboundMap[iKey] = mapp
} else if _, ok := mapp.filters[filterKey]; !ok {
n.log.Debugf("[%s] permit access from %s to %s", n.name, filterKey, mapp.mapped)
mapp.filters[filterKey] = struct{}{}
}

if err := to.setSourceAddr(mapp.mapped); err != nil {
return nil, err
}

return to, nil
}

switch from.Network() {
case udp:
if n.natType.Mode == NATModeNAT1To1 {
// 1:1 NAT behavior
srcAddr := from.SourceAddr().(*net.UDPAddr) //nolint:forcetypeassert
Expand All @@ -172,132 +230,145 @@ func (n *networkAddressTranslator) translateOutbound(from Chunk) (Chunk, error)
return nil, err
}
} else {
// Normal (NAPT) behavior
var bound, filterKey string
switch n.natType.MappingBehavior {
case EndpointIndependent:
bound = ""
case EndpointAddrDependent:
bound = from.getDestinationIP().String()
case EndpointAddrPortDependent:
bound = from.DestinationAddr().String()
var err error
to, err = translateOutboundNAPT("udp", 0xC000, &n.udpPortCounter)
if err != nil {
return nil, err
}
}

switch n.natType.FilteringBehavior {
case EndpointIndependent:
filterKey = ""
case EndpointAddrDependent:
filterKey = from.getDestinationIP().String()
case EndpointAddrPortDependent:
filterKey = from.DestinationAddr().String()
}
n.log.Debugf("[%s] translate outbound chunk from %s to %s", n.name, from.String(), to.String())

oKey := fmt.Sprintf("udp:%s:%s", from.SourceAddr().String(), bound)

mapp := n.findOutboundMapping(oKey)
if mapp == nil {
// Create a new mapping
mappedPort := 0xC000 + n.udpPortCounter
n.udpPortCounter++

mapp = &mapping{
proto: from.SourceAddr().Network(),
local: from.SourceAddr().String(),
bound: bound,
mapped: fmt.Sprintf("%s:%d", n.mappedIPs[0].String(), mappedPort),
filters: map[string]struct{}{},
expires: time.Now().Add(n.natType.MappingLifeTime),
}

n.outboundMap[oKey] = mapp

iKey := fmt.Sprintf("udp:%s", mapp.mapped)

n.log.Debugf("[%s] created a new NAT binding oKey=%s iKey=%s",
n.name,
oKey,
iKey)

mapp.filters[filterKey] = struct{}{}
n.log.Debugf("[%s] permit access from %s to %s", n.name, filterKey, mapp.mapped)
n.inboundMap[iKey] = mapp
} else if _, ok := mapp.filters[filterKey]; !ok {
n.log.Debugf("[%s] permit access from %s to %s", n.name, filterKey, mapp.mapped)
mapp.filters[filterKey] = struct{}{}
}
return to, nil

if err := to.setSourceAddr(mapp.mapped); err != nil {
case tcp:
if n.natType.Mode == NATModeNAT1To1 {
srcAddr := from.SourceAddr().(*net.TCPAddr) //nolint:forcetypeassert
srcIP := n.getPairedMappedIP(srcAddr.IP)
if srcIP == nil {
n.log.Debugf("[%s] drop outbound chunk %s with not route", n.name, from.String())

return nil, nil // nolint:nilnil
}
srcPort := srcAddr.Port
if err := to.setSourceAddr(fmt.Sprintf("%s:%d", srcIP.String(), srcPort)); err != nil {
return nil, err
}
} else {
var err error
to, err = translateOutboundNAPT("tcp", 0x8000, &n.tcpPortCounter)
if err != nil {
return nil, err
}
}

n.log.Debugf("[%s] translate outbound chunk from %s to %s", n.name, from.String(), to.String())

return to, nil
}

return nil, errNonUDPTranslationNotSupported
default:
return nil, errTranslationNotSupported
}
}

func (n *networkAddressTranslator) translateInbound(from Chunk) (Chunk, error) { //nolint:cyclop
func (n *networkAddressTranslator) translateInbound(from Chunk) (Chunk, error) { //nolint:cyclop,gocognit
n.mutex.Lock()
defer n.mutex.Unlock()

to := from.Clone()

if from.Network() == udp { //nolint:nestif
translateInboundNAT1To1 := func(dstPort int) (Chunk, error) {
dstIP := n.getPairedLocalIP(from.getDestinationIP())
if dstIP == nil {
return nil, fmt.Errorf("drop %s as %w", from.String(), errNoAssociatedLocalAddress)
}
if err := to.setDestinationAddr(fmt.Sprintf("%s:%d", dstIP, dstPort)); err != nil {
return nil, err
}

return to, nil
}

translateInboundNAPT := func(proto string) (Chunk, error) {
iKey := fmt.Sprintf("%s:%s", proto, from.DestinationAddr().String())
mapp := n.findInboundMapping(iKey)
if mapp == nil {
return nil, fmt.Errorf("drop %s as %w", from.String(), errNoNATBindingFound)
}

var filterKey string
switch n.natType.FilteringBehavior {
case EndpointIndependent:
filterKey = ""
case EndpointAddrDependent:
filterKey = from.getSourceIP().String()
case EndpointAddrPortDependent:
filterKey = from.SourceAddr().String()
}

if _, ok := mapp.filters[filterKey]; !ok {
return nil, fmt.Errorf("drop %s as the remote %s %w", from.String(), filterKey, errHasNoPermission)
}

// See RFC 4847 Section 4.3. Mapping Refresh
// a) Inbound refresh may be useful for applications with no outgoing
// UDP traffic. However, allowing inbound refresh may allow an
// external attacker or misbehaving application to keep a mapping
// alive indefinitely. This may be a security risk. Also, if the
// process is repeated with different ports, over time, it could
// use up all the ports on the NAT.

if err := to.setDestinationAddr(mapp.local); err != nil {
return nil, err
}

return to, nil
}

switch from.Network() {
case udp:
if n.natType.Mode == NATModeNAT1To1 {
// 1:1 NAT behavior
dstAddr := from.DestinationAddr().(*net.UDPAddr) //nolint:forcetypeassert
dstIP := n.getPairedLocalIP(dstAddr.IP)
if dstIP == nil {
return nil, fmt.Errorf("drop %s as %w", from.String(), errNoAssociatedLocalAddress)
}
dstPort := from.DestinationAddr().(*net.UDPAddr).Port //nolint:forcetypeassert
if err := to.setDestinationAddr(fmt.Sprintf("%s:%d", dstIP, dstPort)); err != nil {
var err error
to, err = translateInboundNAT1To1(dstAddr.Port)
if err != nil {
return nil, err
}
} else {
// Normal (NAPT) behavior
iKey := fmt.Sprintf("udp:%s", from.DestinationAddr().String())
mapping := n.findInboundMapping(iKey)
if mapping == nil {
return nil, fmt.Errorf("drop %s as %w", from.String(), errNoNATBindingFound)
}

var filterKey string
switch n.natType.FilteringBehavior {
case EndpointIndependent:
filterKey = ""
case EndpointAddrDependent:
filterKey = from.getSourceIP().String()
case EndpointAddrPortDependent:
filterKey = from.SourceAddr().String()
var err error
to, err = translateInboundNAPT(udp)
if err != nil {
return nil, err
}
}

if _, ok := mapping.filters[filterKey]; !ok {
return nil, fmt.Errorf("drop %s as the remote %s %w", from.String(), filterKey, errHasNoPermission)
}
n.log.Debugf("[%s] translate inbound chunk from %s to %s", n.name, from.String(), to.String())

// See RFC 4847 Section 4.3. Mapping Refresh
// a) Inbound refresh may be useful for applications with no outgoing
// UDP traffic. However, allowing inbound refresh may allow an
// external attacker or misbehaving application to keep a mapping
// alive indefinitely. This may be a security risk. Also, if the
// process is repeated with different ports, over time, it could
// use up all the ports on the NAT.
return to, nil

if err := to.setDestinationAddr(mapping.local); err != nil {
case tcp:
if n.natType.Mode == NATModeNAT1To1 {
dstAddr := from.DestinationAddr().(*net.TCPAddr) //nolint:forcetypeassert
var err error
to, err = translateInboundNAT1To1(dstAddr.Port)
if err != nil {
return nil, err
}
} else {
var err error
to, err = translateInboundNAPT(tcp)
if err != nil {
return nil, err
}
}

n.log.Debugf("[%s] translate inbound chunk from %s to %s", n.name, from.String(), to.String())

return to, nil
}

return nil, errNonUDPTranslationNotSupported
default:
return nil, errTranslationNotSupported
}
}

// caller must hold the mutex.
Expand Down
Loading
Loading