diff --git a/dhcpv4/nclient4/client.go b/dhcpv4/nclient4/client.go index 02bd1f63..433857ec 100644 --- a/dhcpv4/nclient4/client.go +++ b/dhcpv4/nclient4/client.go @@ -214,7 +214,7 @@ func new(iface string, conn net.PacketConn, ifaceHWAddr net.HardwareAddr, opts . if iface == `` { return nil, ErrNoConn } - c.conn, err = NewRawUDPConn(iface, ClientPort) // broadcast + c.conn, err = NewRawUDPConn(iface, &net.UDPAddr{Port: ClientPort}, UDPBroadcast) // broadcast if err != nil { return nil, fmt.Errorf("unable to open a broadcasting socket: %w", err) } @@ -349,15 +349,9 @@ func WithLogger(newLogger Logger) ClientOpt { // srcAddr is both: // * The source address of outgoing frames. // * The address to be listened for incoming frames. -func WithUnicast(srcAddr *net.UDPAddr) ClientOpt { +func WithUnicast(iface string, srcAddr *net.UDPAddr) ClientOpt { return func(c *Client) (err error) { - if srcAddr == nil { - srcAddr = &net.UDPAddr{Port: ClientPort} - } - c.conn, err = net.ListenUDP("udp4", srcAddr) - if err != nil { - err = fmt.Errorf("unable to start listening UDP port: %w", err) - } + c.conn, err = NewRawUDPConn(iface, srcAddr, UDPUnicast) return } } diff --git a/dhcpv4/nclient4/conn_unix.go b/dhcpv4/nclient4/conn_unix.go index bece7523..fb9ff9ea 100644 --- a/dhcpv4/nclient4/conn_unix.go +++ b/dhcpv4/nclient4/conn_unix.go @@ -10,12 +10,29 @@ package nclient4 import ( "errors" + "fmt" "io" "net" + "github.com/mdlayher/arp" "github.com/mdlayher/ethernet" "github.com/mdlayher/raw" "github.com/u-root/uio/uio" + "github.com/vishvananda/netlink" +) + +// UDPConnType indicates the type of the udp conn. +type UDPConnType int + +const ( + // UDPBroadcast specifies the type of udp conn as broadcast. + // + // All the packets will be broadcasted. + UDPBroadcast UDPConnType = 0 + + // UDPUnicast specifies the type of udp conn as unicast. + // All the packets will be sent to a unicast MAC address. + UDPUnicast UDPConnType = 1 ) var ( @@ -28,13 +45,16 @@ var ( var ( // ErrUDPAddrIsRequired is an error used when a passed argument is not of type "*net.UDPAddr". ErrUDPAddrIsRequired = errors.New("must supply UDPAddr") + + // ErrHWAddrNotFound is an error used when getting MAC address failed. + ErrHWAddrNotFound = errors.New("hardware address not found") ) -// NewRawUDPConn returns a UDP connection bound to the interface and port -// given based on a raw packet socket. All packets are broadcasted. +// NewRawUDPConn returns a UDP connection bound to the interface and udp address +// given based on a raw packet socket. // // The interface can be completely unconfigured. -func NewRawUDPConn(iface string, port int) (net.PacketConn, error) { +func NewRawUDPConn(iface string, addr *net.UDPAddr, typ UDPConnType) (net.PacketConn, error) { ifc, err := net.InterfaceByName(iface) if err != nil { return nil, err @@ -43,7 +63,12 @@ func NewRawUDPConn(iface string, port int) (net.PacketConn, error) { if err != nil { return nil, err } - return NewBroadcastUDPConn(rawConn, &net.UDPAddr{Port: port}), nil + + if typ == UDPUnicast { + return NewUnicastRawUDPConn(rawConn, addr), nil + } + + return NewBroadcastUDPConn(rawConn, addr), nil } // BroadcastRawUDPConn uses a raw socket to send UDP packets to the broadcast @@ -157,3 +182,76 @@ func (upc *BroadcastRawUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { // Broadcasting is not always right, but hell, what the ARP do I know. return upc.PacketConn.WriteTo(packet, &raw.Addr{HardwareAddr: BroadcastMac}) } + +// UnicastRawUDPConn inherits from BroadcastRawUDPConn and override the WriteTo method +type UnicastRawUDPConn struct { + *BroadcastRawUDPConn +} + +// NewUnicastRawUDPConn returns a PacketConn which sending the packets to a unicast MAC address. +func NewUnicastRawUDPConn(rawPacketConn net.PacketConn, boundAddr *net.UDPAddr) net.PacketConn { + return &UnicastRawUDPConn{ + BroadcastRawUDPConn: NewBroadcastUDPConn(rawPacketConn, boundAddr).(*BroadcastRawUDPConn), + } +} + +// WriteTo implements net.PacketConn.WriteTo. +// +// WriteTo try to get the MAC address of destination IP address before +// unicast all packets at the raw socket level. +func (upc *UnicastRawUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { + udpAddr, ok := addr.(*net.UDPAddr) + if !ok { + return 0, ErrUDPAddrIsRequired + } + + // Using the boundAddr is not quite right here, but it works. + packet := udp4pkt(b, udpAddr, upc.boundAddr) + dstMac, err := getHwAddr(udpAddr.IP) + if err != nil { + return 0, ErrHWAddrNotFound + } + + return upc.PacketConn.WriteTo(packet, &raw.Addr{HardwareAddr: dstMac}) +} + +// getHwAddr from local arp cache. If no existing, try to get it by arp protocol. +func getHwAddr(ip net.IP) (net.HardwareAddr, error) { + neighList, err := netlink.NeighListExecute(netlink.Ndmsg{ + Family: netlink.FAMILY_V4, + State: netlink.NUD_REACHABLE, + }) + if err != nil { + return nil, err + } + + for _, neigh := range neighList { + if ip.Equal(neigh.IP) && neigh.HardwareAddr != nil { + return neigh.HardwareAddr, nil + } + } + + return arpResolve(ip) +} + +func arpResolve(dest net.IP) (net.HardwareAddr, error) { + // auto match the interface based on routes + routes, err := netlink.RouteGet(dest) + if err != nil { + return nil, err + } + if len(routes) == 0 { + return nil, fmt.Errorf("no route to %s found", dest.String()) + } + ifc, err := net.InterfaceByIndex(routes[0].LinkIndex) + if err != nil { + return nil, err + } + + c, err := arp.Dial(ifc) + if err != nil { + return nil, err + } + + return c.Resolve(dest) +} diff --git a/dhcpv4/nclient4/conn_unix_test.go b/dhcpv4/nclient4/conn_unix_test.go new file mode 100644 index 00000000..65ea0a16 --- /dev/null +++ b/dhcpv4/nclient4/conn_unix_test.go @@ -0,0 +1,84 @@ +package nclient4 + +import ( + "net" + "testing" + + "github.com/vishvananda/netlink" +) + +const ( + linkName = "neigh0" + ipStr = "10.99.0.1" + macStr = "aa:bb:cc:dd:00:01" +) + +func TestGetHwAddrFromLocalCache(t *testing.T) { + mac, err := net.ParseMAC(macStr) + if err != nil { + t.Fatal(err) + } + ip := net.ParseIP(ipStr) + + if err := addNeigh(ip, mac); err != nil { + t.Fatal(err) + } + defer func() { + if err := delNeigh(ip, mac); err != nil { + t.Fatal(err) + } + }() + + ifc, err := net.InterfaceByName(linkName) + if err != nil { + t.Fatal(err) + } + + if hw, err := getHwAddr(ifc, ip); err != nil && hw != nil && hw.String() == macStr { + t.Fatal(err) + } +} + +func TestGetHwAddrOfLoopback(t *testing.T) { + lo, err := net.InterfaceByName("lo") + if err != nil { + t.Fatalf("get the loopback interface failed, err: %s", err.Error()) + } + if _, err := getHwAddr(lo, net.ParseIP("127.0.0.1")); err != nil { + t.Fatal(err) + } +} + +func addNeigh(ip net.IP, mac net.HardwareAddr) error { + dummy := netlink.Dummy{LinkAttrs: netlink.LinkAttrs{Name: linkName}} + if err := netlink.LinkAdd(&dummy); err != nil { + return err + } + newlink, err := netlink.LinkByName(dummy.Name) + if err != nil { + return err + } + dummy.Index = newlink.Attrs().Index + + return netlink.NeighAdd(&netlink.Neigh{ + LinkIndex: dummy.Index, + State: netlink.NUD_REACHABLE, + IP: ip, + HardwareAddr: mac, + }) +} + +func delNeigh(ip net.IP, mac net.HardwareAddr) error { + dummy, err := netlink.LinkByName(linkName) + if err != nil { + return err + } + + return netlink.NeighDel(&netlink.Neigh{ + LinkIndex: dummy.Attrs().Index, + State: netlink.NUD_REACHABLE, + IP: ip, + HardwareAddr: mac, + }) +} + diff --git a/go.mod b/go.mod index d53b0e47..4f6283f7 100644 --- a/go.mod +++ b/go.mod @@ -6,12 +6,14 @@ require ( github.com/fanliao/go-promise v0.0.0-20141029170127-1890db352a72 github.com/hugelgupf/socketpair v0.0.0-20190730060125-05d35a94e714 github.com/jsimonetti/rtnetlink v0.0.0-20201110080708-d2c240429e6c + github.com/mdlayher/arp v0.0.0-20191213142603-f72070a231fc github.com/mdlayher/ethernet v0.0.0-20190606142754-0394541c37b7 github.com/mdlayher/netlink v1.1.1 github.com/mdlayher/raw v0.0.0-20191009151244-50f2db8cc065 github.com/smartystreets/goconvey v1.6.4 // indirect github.com/stretchr/testify v1.6.1 github.com/u-root/uio v0.0.0-20210528114334-82958018845c + github.com/vishvananda/netlink v1.1.0 golang.org/x/net v0.0.0-20201110031124-69a78807bb2b golang.org/x/sys v0.0.0-20210525143221-35b2ab0089ea ) diff --git a/go.sum b/go.sum index a89fe9d1..c8c49f84 100644 --- a/go.sum +++ b/go.sum @@ -19,6 +19,9 @@ github.com/jsimonetti/rtnetlink v0.0.0-20201110080708-d2c240429e6c h1:7cpGGTQO6+ github.com/jsimonetti/rtnetlink v0.0.0-20201110080708-d2c240429e6c/go.mod h1:huN4d1phzjhlOsNIjFsw2SVRbwIHj3fJDMEU2SDPTmg= github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= +github.com/mdlayher/arp v0.0.0-20191213142603-f72070a231fc h1:m7rJJJeXrYCFpsxXYapkDW53wJCDmf9bsIXUg0HoeQY= +github.com/mdlayher/arp v0.0.0-20191213142603-f72070a231fc/go.mod h1:eOj1DDj3NAZ6yv+WafaKzY37MFZ58TdfIhQ+8nQbiis= +github.com/mdlayher/ethernet v0.0.0-20190313224307-5b5fc417d966/go.mod h1:5s5p/sMJ6sNsFl6uCh85lkFGV8kLuIYJCRJLavVJwvg= github.com/mdlayher/ethernet v0.0.0-20190606142754-0394541c37b7 h1:lez6TS6aAau+8wXUP3G9I3TGlmPFEq2CTxBaRqY6AGE= github.com/mdlayher/ethernet v0.0.0-20190606142754-0394541c37b7/go.mod h1:U6ZQobyTjI/tJyq2HG+i/dfSoFUt8/aZCM+GKtmFk/Y= github.com/mdlayher/netlink v0.0.0-20190409211403-11939a169225/go.mod h1:eQB3mZE4aiYnlUsyGGCOpPETfdQq4Jhsgf1fk3cwQaA= @@ -26,6 +29,7 @@ github.com/mdlayher/netlink v1.0.0/go.mod h1:KxeJAFOFLG6AjpyDkQ/iIhxygIUKD+vcwqc github.com/mdlayher/netlink v1.1.0/go.mod h1:H4WCitaheIsdF9yOYu8CFmCgQthAPIWZmcKp9uZHgmY= github.com/mdlayher/netlink v1.1.1 h1:VqG+Voq9V4uZ+04vjIrcSCWDpf91B1xxbP4QBUmUJE8= github.com/mdlayher/netlink v1.1.1/go.mod h1:WTYpFb/WTvlRJAyKhZL5/uy69TDDpHHu2VZmb2XgV7o= +github.com/mdlayher/raw v0.0.0-20190313224157-43dbcdd7739d/go.mod h1:r1fbeITl2xL/zLbVnNHFyOzQJTgr/3fpf1lJX/cjzR8= github.com/mdlayher/raw v0.0.0-20190606142536-fef19f00fc18/go.mod h1:7EpbotpCmVZcu+KCX4g9WaRNuu11uyhiW7+Le1dKawg= github.com/mdlayher/raw v0.0.0-20191009151244-50f2db8cc065 h1:aFkJ6lx4FPip+S+Uw4aTegFMct9shDvP+79PsSxpm3w= github.com/mdlayher/raw v0.0.0-20191009151244-50f2db8cc065/go.mod h1:7EpbotpCmVZcu+KCX4g9WaRNuu11uyhiW7+Le1dKawg= @@ -41,9 +45,14 @@ github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/u-root/uio v0.0.0-20210528114334-82958018845c h1:BFvcl34IGnw8yvJi8hlqLFo9EshRInwWBs2M5fGWzQA= github.com/u-root/uio v0.0.0-20210528114334-82958018845c/go.mod h1:LpEX5FO/cB+WF4TYGY1V5qktpaZLkKkSegbr0V4eYXA= +github.com/vishvananda/netlink v1.1.0 h1:1iyaYNBLmP6L0220aDnYQpo1QEV4t4hJ+xEEhhJH8j0= +github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYppBueQtXaqoE= +github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df h1:OviZH7qLw/7ZovXvuNyL3XQl8UFofeikI1NW1Gypu7k= +github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190313220215-9f648a60d977/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190419010253-1f3472d942ba/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= @@ -59,6 +68,7 @@ golang.org/x/sys v0.0.0-20190411185658-b44545bcd369/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190418153312-f0ce4c0180be/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190606122018-79a91cf218c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190606203320-7fc4e5ec1444/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191008105621-543471e840be/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=