diff --git a/dhcpv6/client6/client.go b/dhcpv6/client6/client.go index b1f8e111..c7bd318f 100644 --- a/dhcpv6/client6/client.go +++ b/dhcpv6/client6/client.go @@ -199,7 +199,11 @@ func (c *Client) sendReceive(ifname string, packet dhcpv6.DHCPv6, expectedType d // an error if any. The modifiers will be applied to the Solicit before sending // it, see modifiers.go func (c *Client) Solicit(ifname string, modifiers ...dhcpv6.Modifier) (dhcpv6.DHCPv6, dhcpv6.DHCPv6, error) { - solicit, err := dhcpv6.NewSolicitForInterface(ifname) + iface, err := net.InterfaceByName(ifname) + if err != nil { + return nil, nil, err + } + solicit, err := dhcpv6.NewSolicit(iface.HardwareAddr) if err != nil { return nil, nil, err } diff --git a/dhcpv6/dhcpv6.go b/dhcpv6/dhcpv6.go index c8c42927..7505dfcd 100644 --- a/dhcpv6/dhcpv6.go +++ b/dhcpv6/dhcpv6.go @@ -16,6 +16,11 @@ type DHCPv6 interface { String() string Summary() string IsRelay() bool + + // GetInnerMessage returns the innermost encapsulated DHCPv6 message. + // + // If it is already a message, it will be returned. If it is a relay + // message, the encapsulated message will be recursively extracted. GetInnerMessage() (*Message, error) GetOption(code OptionCode) []Option @@ -108,11 +113,11 @@ func DecapsulateRelay(l DHCPv6) (DHCPv6, error) { } opt := l.GetOneOption(OptionRelayMsg) if opt == nil { - return nil, fmt.Errorf("No OptRelayMsg found") + return nil, fmt.Errorf("malformed Relay message: no OptRelayMsg found") } relayOpt := opt.(*OptRelayMsg) if relayOpt.RelayMessage() == nil { - return nil, fmt.Errorf("Relay message cannot be nil") + return nil, fmt.Errorf("malformed Relay message: encapsulated message is empty") } return relayOpt.RelayMessage(), nil } diff --git a/dhcpv6/dhcpv6message.go b/dhcpv6/dhcpv6message.go index 9e03fa7e..9237c1b6 100644 --- a/dhcpv6/dhcpv6message.go +++ b/dhcpv6/dhcpv6message.go @@ -70,18 +70,14 @@ func NewSolicitWithCID(duid Duid, modifiers ...Modifier) (*Message, error) { return m, nil } -// NewSolicitForInterface creates a new SOLICIT message with DUID-LLT, using the +// NewSolicit creates a new SOLICIT message with DUID-LLT, using the // given network interface's hardware address and current time -func NewSolicitForInterface(ifname string, modifiers ...Modifier) (*Message, error) { - iface, err := net.InterfaceByName(ifname) - if err != nil { - return nil, err - } +func NewSolicit(ifaceHWAddr net.HardwareAddr, modifiers ...Modifier) (*Message, error) { duid := Duid{ Type: DUID_LLT, HwType: iana.HWTypeEthernet, Time: GetTime(), - LinkLayerAddr: iface.HardwareAddr, + LinkLayerAddr: ifaceHWAddr, } return NewSolicitWithCID(duid, modifiers...) } diff --git a/dhcpv6/modifiers.go b/dhcpv6/modifiers.go index 8c75ea5a..eaa370d8 100644 --- a/dhcpv6/modifiers.go +++ b/dhcpv6/modifiers.go @@ -99,6 +99,11 @@ func WithDomainSearchList(searchlist ...string) Modifier { } } +// WithRapidCommit adds the rapid commit option to a message. +func WithRapidCommit(d DHCPv6) { + d.UpdateOption(&OptionGeneric{OptionCode: OptionRapidCommit}) +} + // WithRequestedOptions adds requested options to the packet func WithRequestedOptions(optionCodes ...OptionCode) Modifier { return func(d DHCPv6) { diff --git a/dhcpv6/nclient6/client.go b/dhcpv6/nclient6/client.go new file mode 100644 index 00000000..dc5dd330 --- /dev/null +++ b/dhcpv6/nclient6/client.go @@ -0,0 +1,371 @@ +// Copyright 2018 the u-root Authors and Andrea Barberio. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package nclient6 + +import ( + "context" + "errors" + "fmt" + "log" + "net" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/insomniacslk/dhcp/dhcpv6" +) + +// Broadcast destination IP addresses as defined by RFC 3315 +var ( + AllDHCPRelayAgentsAndServers = &net.UDPAddr{ + IP: net.ParseIP("ff02::1:2"), + Port: dhcpv6.DefaultServerPort, + } + AllDHCPServers = &net.UDPAddr{ + IP: net.ParseIP("ff05::1:3"), + Port: dhcpv6.DefaultServerPort, + } +) + +const ( + maxUDPReceivedPacketSize = 8192 + maxMessageSize = 1500 +) + +var ( + // ErrNoResponse is returned when no response packet is received. + ErrNoResponse = errors.New("no matching response packet received") +) + +// pendingCh is a channel associated with a pending TransactionID. +type pendingCh struct { + // SendAndRead closes done to indicate that it wishes for no more + // messages for this particular XID. + done <-chan struct{} + + // ch is used by the receive loop to distribute DHCP messages. + ch chan<- *dhcpv6.Message +} + +// Client is a DHCPv6 client. +type Client struct { + ifaceHWAddr net.HardwareAddr + conn net.PacketConn + timeout time.Duration + retry int + + // bufferCap is the channel capacity for each TransactionID. + bufferCap int + + // serverAddr is the UDP address to send all packets to. + // + // This may be an actual broadcast address, or a unicast address. + serverAddr *net.UDPAddr + + // closed is an atomic bool set to 1 when done is closed. + closed uint32 + + // done is closed to unblock the receive loop. + done chan struct{} + + // wg protects the receiveLoop. + wg sync.WaitGroup + + pendingMu sync.Mutex + // pending stores the distribution channels for each pending + // TransactionID. receiveLoop uses this map to determine which channel + // to send a new DHCP message to. + pending map[dhcpv6.TransactionID]*pendingCh +} + +// NewIPv6UDPConn returns a UDP connection bound to both the interface and port +// given based on a IPv6 DGRAM socket. +func NewIPv6UDPConn(iface string, port int) (net.PacketConn, error) { + return net.ListenUDP("udp6", &net.UDPAddr{ + Port: port, + Zone: iface, + }) +} + +// New creates a new DHCP client that sends and receives packets on the given +// interface. +func New(ifaceHWAddr net.HardwareAddr, opts ...ClientOpt) (*Client, error) { + c := &Client{ + ifaceHWAddr: ifaceHWAddr, + timeout: 5 * time.Second, + retry: 3, + serverAddr: AllDHCPServers, + bufferCap: 5, + + done: make(chan struct{}), + pending: make(map[dhcpv6.TransactionID]*pendingCh), + } + + for _, opt := range opts { + opt(c) + } + + if c.conn == nil { + return nil, fmt.Errorf("require a connection") + } + + c.receiveLoop() + return c, nil +} + +// Close closes the underlying connection. +func (c *Client) Close() error { + // Make sure not to close done twice. + if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) { + return nil + } + + err := c.conn.Close() + + // Closing c.done sets off a chain reaction: + // + // Any SendAndRead unblocks trying to receive more messages, which + // means rem() gets called. + // + // rem() should be unblocking receiveLoop if it is blocked. + // + // receiveLoop should then exit gracefully. + close(c.done) + + // Wait for receiveLoop to stop. + c.wg.Wait() + + return err +} + +func isErrClosing(err error) bool { + // Unfortunately, the epoll-connection-closed error is internal to the + // net library. + return strings.Contains(err.Error(), "use of closed network connection") +} + +func (c *Client) receiveLoop() { + c.wg.Add(1) + go func() { + defer c.wg.Done() + for { + // TODO: Clients can send a "max packet size" option in their + // packets, IIRC. Choose a reasonable size and set it. + b := make([]byte, 1500) + n, _, err := c.conn.ReadFrom(b) + if err != nil { + if !isErrClosing(err) { + log.Printf("error reading from UDP connection: %v", err) + } + return + } + + msg, err := dhcpv6.MessageFromBytes(b[:n]) + if err != nil { + // Not a valid DHCP packet; keep listening. + continue + } + + c.pendingMu.Lock() + p, ok := c.pending[msg.TransactionID] + if ok { + select { + case <-p.done: + close(p.ch) + delete(c.pending, msg.TransactionID) + + // This send may block. + case p.ch <- msg: + } + } + c.pendingMu.Unlock() + } + }() +} + +// ClientOpt is a function that configures the Client. +type ClientOpt func(*Client) + +func withBufferCap(n int) ClientOpt { + return func(c *Client) { + c.bufferCap = n + } +} + +// WithTimeout configures the retransmission timeout. +// +// Default is 5 seconds. +func WithTimeout(d time.Duration) ClientOpt { + return func(c *Client) { + c.timeout = d + } +} + +// WithRetry configures the number of retransmissions to attempt. +// +// Default is 3. +func WithRetry(r int) ClientOpt { + return func(c *Client) { + c.retry = r + } +} + +// WithConn configures the packet connection to use. +func WithConn(conn net.PacketConn) ClientOpt { + return func(c *Client) { + c.conn = conn + } +} + +// WithBroadcastAddr configures the address to broadcast to. +func WithBroadcastAddr(n *net.UDPAddr) ClientOpt { + return func(c *Client) { + c.serverAddr = n + } +} + +// Matcher matches DHCP packets. +type Matcher func(*dhcpv6.Message) bool + +// IsMessageType returns a matcher that checks for the message type. +// +// If t is MessageTypeNone, all packets are matched. +func IsMessageType(t dhcpv6.MessageType) Matcher { + return func(p *dhcpv6.Message) bool { + return p.MessageType == t || t == dhcpv6.MessageTypeNone + } +} + +// Solicit sends a solicitation message and returns the first valid +// advertisement received. +func (c *Client) Solicit(ctx context.Context, modifiers ...dhcpv6.Modifier) (*dhcpv6.Message, error) { + solicit, err := dhcpv6.NewSolicit(c.ifaceHWAddr, modifiers...) + if err != nil { + return nil, err + } + msg, err := c.SendAndRead(ctx, c.serverAddr, solicit, IsMessageType(dhcpv6.MessageTypeAdvertise)) + if err != nil { + return nil, err + } + return msg, nil +} + +// Request requests an IP Assignment from peer given an advertise message. +func (c *Client) Request(ctx context.Context, advertise *dhcpv6.Message, modifiers ...dhcpv6.Modifier) (*dhcpv6.Message, error) { + request, err := dhcpv6.NewRequestFromAdvertise(advertise, modifiers...) + if err != nil { + return nil, err + } + return c.SendAndRead(ctx, c.serverAddr, request, nil) +} + +// send sends p to destination and returns a response channel. +// +// The returned function must be called after all desired responses have been +// received. +// +// Responses will be matched by transaction ID. +func (c *Client) send(dest net.Addr, msg *dhcpv6.Message) (<-chan *dhcpv6.Message, func(), error) { + c.pendingMu.Lock() + if _, ok := c.pending[msg.TransactionID]; ok { + c.pendingMu.Unlock() + return nil, nil, fmt.Errorf("transaction ID %s already in use", msg.TransactionID) + } + + ch := make(chan *dhcpv6.Message, c.bufferCap) + done := make(chan struct{}) + c.pending[msg.TransactionID] = &pendingCh{done: done, ch: ch} + c.pendingMu.Unlock() + + cancel := func() { + // Why can't we just close ch here? + // + // Because receiveLoop may potentially be blocked trying to + // send on ch. We gotta unblock it first, so it'll unlock the + // lock, and then we can take the lock and remove the XID from + // the pending transaction map. + close(done) + + c.pendingMu.Lock() + if p, ok := c.pending[msg.TransactionID]; ok { + close(p.ch) + delete(c.pending, msg.TransactionID) + } + c.pendingMu.Unlock() + } + + if _, err := c.conn.WriteTo(msg.ToBytes(), dest); err != nil { + cancel() + return nil, nil, fmt.Errorf("error writing packet to connection: %v", err) + } + return ch, cancel, nil +} + +// This should never be visible to a user. +var errDeadlineExceeded = errors.New("INTERNAL ERROR: deadline exceeded") + +// SendAndRead sends a packet p to a destination dest and waits for the first +// response matching `match` as well as its Transaction ID. +// +// If match is nil, the first packet matching the Transaction ID is returned. +func (c *Client) SendAndRead(ctx context.Context, dest *net.UDPAddr, msg *dhcpv6.Message, match Matcher) (*dhcpv6.Message, error) { + var response *dhcpv6.Message + err := c.retryFn(func(timeout time.Duration) error { + ch, rem, err := c.send(dest, msg) + if err != nil { + return err + } + defer rem() + + for { + select { + case <-c.done: + return ErrNoResponse + + case <-time.After(timeout): + return errDeadlineExceeded + + case <-ctx.Done(): + return ctx.Err() + + case packet := <-ch: + if match == nil || match(packet) { + response = packet + return nil + } + } + } + }) + if err == errDeadlineExceeded { + return nil, ErrNoResponse + } + if err != nil { + return nil, err + } + return response, nil +} + +func (c *Client) retryFn(fn func(timeout time.Duration) error) error { + timeout := c.timeout + + // Each retry takes the amount of timeout at worst. + for i := 0; i < c.retry || c.retry < 0; i++ { + switch err := fn(timeout); err { + case nil: + // Got it! + return nil + + case errDeadlineExceeded: + // Double timeout, then retry. + timeout *= 2 + + default: + return err + } + } + + return errDeadlineExceeded +} diff --git a/dhcpv6/nclient6/client_test.go b/dhcpv6/nclient6/client_test.go new file mode 100644 index 00000000..cba4ef8e --- /dev/null +++ b/dhcpv6/nclient6/client_test.go @@ -0,0 +1,258 @@ +// Copyright 2018 the u-root Authors and Andrea Barberio. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.12 + +package nclient6 + +import ( + "bytes" + "context" + "fmt" + "net" + "sync" + "testing" + "time" + + "github.com/hugelgupf/socketpair" + "github.com/insomniacslk/dhcp/dhcpv6" + "github.com/insomniacslk/dhcp/dhcpv6/server6" +) + +type handler struct { + mu sync.Mutex + received []*dhcpv6.Message + + // Each received packet can have more than one response (in theory, + // from different servers sending different Advertise, for example). + responses [][]*dhcpv6.Message +} + +func (h *handler) handle(conn net.PacketConn, peer net.Addr, msg dhcpv6.DHCPv6) { + h.mu.Lock() + defer h.mu.Unlock() + + m := msg.(*dhcpv6.Message) + + h.received = append(h.received, m) + + if len(h.responses) > 0 { + resps := h.responses[0] + // What should we send in response? + for _, resp := range resps { + conn.WriteTo(resp.ToBytes(), peer) + } + h.responses = h.responses[1:] + } +} + +func serveAndClient(ctx context.Context, responses [][]*dhcpv6.Message, opt ...ClientOpt) (*Client, net.PacketConn) { + // Fake connection between client and server. No raw sockets, no port + // weirdness. + clientRawConn, serverRawConn, err := socketpair.PacketSocketPair() + if err != nil { + panic(err) + } + + o := []ClientOpt{WithConn(clientRawConn), WithRetry(1), WithTimeout(2 * time.Second)} + o = append(o, opt...) + mc, err := New(net.HardwareAddr{0xa, 0xb, 0xc, 0xd, 0xe, 0xf}, o...) + if err != nil { + panic(err) + } + + h := &handler{ + responses: responses, + } + s, err := server6.NewServer(nil, h.handle, server6.WithConn(serverRawConn)) + if err != nil { + panic(err) + } + go s.Serve() + + return mc, serverRawConn +} + +func ComparePacket(got *dhcpv6.Message, want *dhcpv6.Message) error { + if got == nil && got == want { + return nil + } + if (want == nil || got == nil) && (got != want) { + return fmt.Errorf("packet got %v, want %v", got, want) + } + if bytes.Compare(got.ToBytes(), want.ToBytes()) != 0 { + return fmt.Errorf("packet got %v, want %v", got, want) + } + return nil +} + +func pktsExpected(got []*dhcpv6.Message, want []*dhcpv6.Message) error { + if len(got) != len(want) { + return fmt.Errorf("got %d packets, want %d packets", len(got), len(want)) + } + + for i := range got { + if err := ComparePacket(got[i], want[i]); err != nil { + return err + } + } + return nil +} + +func newPacket(xid dhcpv6.TransactionID) *dhcpv6.Message { + p, err := dhcpv6.NewMessage() + if err != nil { + panic(fmt.Sprintf("newpacket: %v", err)) + } + p.TransactionID = xid + return p +} + +func TestSendAndReadUntil(t *testing.T) { + for _, tt := range []struct { + desc string + send *dhcpv6.Message + server []*dhcpv6.Message + + // If want is nil, we assume server contains what is wanted. + want *dhcpv6.Message + wantErr error + }{ + { + desc: "two response packets", + send: newPacket([3]byte{0x33, 0x33, 0x33}), + server: []*dhcpv6.Message{ + newPacket([3]byte{0x33, 0x33, 0x33}), + newPacket([3]byte{0x33, 0x33, 0x33}), + }, + want: newPacket([3]byte{0x33, 0x33, 0x33}), + }, + { + desc: "one response packet", + send: newPacket([3]byte{0x33, 0x33, 0x33}), + server: []*dhcpv6.Message{ + newPacket([3]byte{0x33, 0x33, 0x33}), + }, + want: newPacket([3]byte{0x33, 0x33, 0x33}), + }, + { + desc: "one response packet, one invalid XID", + send: newPacket([3]byte{0x33, 0x33, 0x33}), + server: []*dhcpv6.Message{ + newPacket([3]byte{0x77, 0x33, 0x33}), + newPacket([3]byte{0x33, 0x33, 0x33}), + }, + want: newPacket([3]byte{0x33, 0x33, 0x33}), + }, + { + desc: "discard wrong XID", + send: newPacket([3]byte{0x33, 0x33, 0x33}), + server: []*dhcpv6.Message{ + newPacket([3]byte{0, 0, 0}), + }, + want: nil, + wantErr: ErrNoResponse, + }, + { + desc: "no response, timeout", + send: newPacket([3]byte{0x33, 0x33, 0x33}), + wantErr: ErrNoResponse, + }, + } { + t.Run(tt.desc, func(t *testing.T) { + // Both server and client only get 2 seconds. + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + mc, _ := serveAndClient(ctx, [][]*dhcpv6.Message{tt.server}, + // Use an unbuffered channel to make sure we + // have no deadlocks. + withBufferCap(0)) + defer mc.Close() + + rcvd, err := mc.SendAndRead(context.Background(), AllDHCPServers, tt.send, nil) + if err != tt.wantErr { + t.Error(err) + } + + if err := ComparePacket(rcvd, tt.want); err != nil { + t.Errorf("got unexpected packets: %v", err) + } + }) + } +} + +func TestSimpleSendAndReadDiscardGarbage(t *testing.T) { + pkt := newPacket([3]byte{0x33, 0x33, 0x33}) + + responses := []*dhcpv6.Message{ + newPacket([3]byte{0x33, 0x33, 0x33}), + } + + // Both the server and client only get 2 seconds. + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + mc, udpConn := serveAndClient(ctx, [][]*dhcpv6.Message{responses}) + defer mc.Close() + + // Too short for valid DHCPv4 packet. + udpConn.WriteTo([]byte{0x01}, nil) + + rcvd, err := mc.SendAndRead(context.Background(), AllDHCPServers, pkt, nil) + if err != nil { + t.Error(err) + } + + if err := ComparePacket(rcvd, responses[0]); err != nil { + t.Errorf("got unexpected packets: %v", err) + } +} + +func TestMultipleSendAndReadOne(t *testing.T) { + for _, tt := range []struct { + desc string + send []*dhcpv6.Message + server [][]*dhcpv6.Message + wantErr []error + }{ + { + desc: "two requests, two responses", + send: []*dhcpv6.Message{ + newPacket([3]byte{0x33, 0x33, 0x33}), + newPacket([3]byte{0x44, 0x44, 0x44}), + }, + server: [][]*dhcpv6.Message{ + []*dhcpv6.Message{ // Response for first packet. + newPacket([3]byte{0x33, 0x33, 0x33}), + }, + []*dhcpv6.Message{ // Response for second packet. + newPacket([3]byte{0x44, 0x44, 0x44}), + }, + }, + wantErr: []error{ + nil, + nil, + }, + }, + } { + // Both server and client only get 2 seconds. + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + mc, _ := serveAndClient(ctx, tt.server) + defer mc.conn.Close() + + for i, send := range tt.send { + rcvd, err := mc.SendAndRead(context.Background(), AllDHCPServers, send, nil) + + if wantErr := tt.wantErr[i]; err != wantErr { + t.Errorf("SendAndReadOne(%v): got %v, want %v", send, err, wantErr) + } + if err := pktsExpected([]*dhcpv6.Message{rcvd}, tt.server[i]); err != nil { + t.Errorf("got unexpected packets: %v", err) + } + } + } +} diff --git a/dhcpv6/server6/server.go b/dhcpv6/server6/server.go index 6a6d0b7a..f6fb8269 100644 --- a/dhcpv6/server6/server.go +++ b/dhcpv6/server6/server.go @@ -1,11 +1,8 @@ package server6 import ( - "fmt" "log" "net" - "sync" - "time" "github.com/insomniacslk/dhcp/dhcpv6" ) @@ -16,102 +13,66 @@ type Handler func(conn net.PacketConn, peer net.Addr, m dhcpv6.DHCPv6) // Server represents a DHCPv6 server object type Server struct { - conn net.PacketConn - connMutex sync.Mutex - shouldStop chan bool - Handler Handler - localAddr net.UDPAddr + conn net.PacketConn + handler Handler } -// LocalAddr returns the local address of the listening socket, or nil if not -// listening -func (s *Server) LocalAddr() net.Addr { - s.connMutex.Lock() - defer s.connMutex.Unlock() - if s.conn == nil { - return nil - } - return s.conn.LocalAddr() -} - -// ActivateAndServe starts the DHCPv6 server. The listener will run in -// background, and can be interrupted with `Server.Close`. -func (s *Server) ActivateAndServe() error { - s.connMutex.Lock() - if s.conn != nil { - // this may panic if s.conn is closed but not reset properly. For that - // you should use `Server.Close`. - s.Close() - } - conn, err := net.ListenUDP("udp6", &s.localAddr) - if err != nil { - s.connMutex.Unlock() - return err - } - s.conn = conn - s.connMutex.Unlock() - var ( - pc *net.UDPConn - ok bool - ) - if pc, ok = s.conn.(*net.UDPConn); !ok { - return fmt.Errorf("error: not an UDPConn") - } - if pc == nil { - return fmt.Errorf("ActivateAndServe: invalid nil PacketConn") - } - log.Printf("Server listening on %s", pc.LocalAddr()) +// Serve starts the DHCPv6 server. The listener will run in background, and can +// be interrupted with `Server.Close`. +func (s *Server) Serve() { + log.Printf("Server listening on %s", s.conn.LocalAddr()) log.Print("Ready to handle requests") + for { - select { - case <-s.shouldStop: - break - case <-time.After(time.Millisecond): - } - pc.SetReadDeadline(time.Now().Add(time.Second)) rbuf := make([]byte, 4096) // FIXME this is bad - n, peer, err := pc.ReadFrom(rbuf) + n, peer, err := s.conn.ReadFrom(rbuf) if err != nil { - switch err.(type) { - case net.Error: - if !err.(net.Error).Timeout() { - return err - } - // if timeout, silently skip and continue - default: - // complain and continue - log.Printf("Error reading from packet conn: %v", err) - } - continue + log.Printf("Error reading from packet conn: %v", err) + return } log.Printf("Handling request from %v", peer) - m, err := dhcpv6.FromBytes(rbuf[:n]) + + d, err := dhcpv6.FromBytes(rbuf[:n]) if err != nil { log.Printf("Error parsing DHCPv6 request: %v", err) continue } - go s.Handler(pc, peer, m) + + go s.handler(s.conn, peer, d) } } // Close sends a termination request to the server, and closes the UDP listener func (s *Server) Close() error { - s.shouldStop <- true - s.connMutex.Lock() - defer s.connMutex.Unlock() - if s.conn != nil { - ret := s.conn.Close() - s.conn = nil - return ret + return s.conn.Close() +} + +// A ServerOpt configures a Server. +type ServerOpt func(s *Server) + +// WithConn configures a server with the given connection. +func WithConn(conn net.PacketConn) ServerOpt { + return func(s *Server) { + s.conn = conn } - return nil } // NewServer initializes and returns a new Server object -func NewServer(addr net.UDPAddr, handler Handler) *Server { - return &Server{ - localAddr: addr, - Handler: handler, - shouldStop: make(chan bool, 1), +func NewServer(addr *net.UDPAddr, handler Handler, opt ...ServerOpt) (*Server, error) { + s := &Server{ + handler: handler, + } + + for _, o := range opt { + o(s) + } + + if s.conn == nil { + conn, err := net.ListenUDP("udp6", addr) + if err != nil { + return nil, err + } + s.conn = conn } + return s, nil } diff --git a/dhcpv6/server6/server_test.go b/dhcpv6/server6/server_test.go index 3d2a3652..05d62cb0 100644 --- a/dhcpv6/server6/server_test.go +++ b/dhcpv6/server6/server_test.go @@ -1,62 +1,58 @@ package server6 import ( + "context" "log" "net" "testing" - "time" "github.com/insomniacslk/dhcp/dhcpv6" - "github.com/insomniacslk/dhcp/dhcpv6/client6" + "github.com/insomniacslk/dhcp/dhcpv6/nclient6" "github.com/insomniacslk/dhcp/interfaces" "github.com/stretchr/testify/require" ) +type fakeUnconnectedConn struct { + *net.UDPConn +} + +func (f fakeUnconnectedConn) WriteTo(b []byte, _ net.Addr) (int, error) { + return f.UDPConn.Write(b) +} + +func (f fakeUnconnectedConn) ReadFrom(b []byte) (int, net.Addr, error) { + n, err := f.Read(b) + return n, nil, err +} + // utility function to set up a client and a server instance and run it in // background. The caller needs to call Server.Close() once finished. -func setUpClientAndServer(handler Handler) (*client6.Client, *Server) { - laddr := net.UDPAddr{ +func setUpClientAndServer(handler Handler) (*nclient6.Client, *Server) { + laddr := &net.UDPAddr{ IP: net.ParseIP("::1"), Port: 0, } - s := NewServer(laddr, handler) - go s.ActivateAndServe() - - c := client6.NewClient() - c.LocalAddr = &net.UDPAddr{ - IP: net.ParseIP("::1"), - } - for { - if s.LocalAddr() != nil { - break - } - time.Sleep(10 * time.Millisecond) - log.Printf("Waiting for server to run...") - } - c.RemoteAddr = &net.UDPAddr{ - IP: net.ParseIP("::1"), - Port: s.LocalAddr().(*net.UDPAddr).Port, + s, err := NewServer(laddr, handler) + if err != nil { + panic(err) } + go s.Serve() - return c, s -} + clientConn, err := net.DialUDP("udp6", &net.UDPAddr{IP: net.ParseIP("::1")}, s.conn.LocalAddr().(*net.UDPAddr)) + if err != nil { + panic(err) + } -func TestNewServer(t *testing.T) { - laddr := net.UDPAddr{ - IP: net.ParseIP("::1"), - Port: 0, + c, err := nclient6.New(net.HardwareAddr{1, 2, 3, 4, 5, 6}, + nclient6.WithConn(fakeUnconnectedConn{clientConn})) + if err != nil { + panic(err) } - handler := func(conn net.PacketConn, peer net.Addr, m dhcpv6.DHCPv6) {} - s := NewServer(laddr, handler) - defer s.Close() - require.NotNil(t, s) - require.Nil(t, s.conn) - require.Equal(t, laddr, s.localAddr) - require.NotNil(t, s.Handler) + return c, s } -func TestServerActivateAndServe(t *testing.T) { +func TestServer(t *testing.T) { handler := func(conn net.PacketConn, peer net.Addr, m dhcpv6.DHCPv6) { msg := m.(*dhcpv6.Message) adv, err := dhcpv6.NewAdvertiseFromSolicit(msg) @@ -68,6 +64,7 @@ func TestServerActivateAndServe(t *testing.T) { log.Printf("Cannot reply to client: %v", err) } } + c, s := setUpClientAndServer(handler) defer s.Close() @@ -75,6 +72,6 @@ func TestServerActivateAndServe(t *testing.T) { require.NoError(t, err) require.NotEqual(t, 0, len(ifaces)) - _, _, err = c.Solicit(ifaces[0].Name) + _, err = c.Solicit(context.Background(), dhcpv6.WithRapidCommit) require.NoError(t, err) } diff --git a/examples/packetcrafting6/main.go b/examples/packetcrafting6/main.go index 7e2fb1bf..f2c18550 100644 --- a/examples/packetcrafting6/main.go +++ b/examples/packetcrafting6/main.go @@ -15,9 +15,14 @@ func main() { // that implement the `dhcpv6.DHCPv6` interface. // Then print the wire-format representation of the packet. + iface, err := net.InterfaceByName("eth0") + if err != nil { + log.Fatal(err) + } + // Create the DHCPv6 Solicit first, using the interface "eth0" // to get the MAC address - msg, err := dhcpv6.NewSolicitForInterface("eth0") + msg, err := dhcpv6.NewSolicit(iface.HardwareAddr) if err != nil { log.Fatal(err) } diff --git a/examples/server6/main.go b/examples/server6/main.go index 174f798b..318937f2 100644 --- a/examples/server6/main.go +++ b/examples/server6/main.go @@ -14,14 +14,14 @@ func handler(conn net.PacketConn, peer net.Addr, m dhcpv6.DHCPv6) { } func main() { - laddr := net.UDPAddr{ + laddr := &net.UDPAddr{ IP: net.ParseIP("::1"), Port: dhcpv6.DefaultServerPort, } - server := server6.NewServer(laddr, handler) - - defer server.Close() - if err := server.ActivateAndServe(); err != nil { - log.Panic(err) + server, err := server6.NewServer(laddr, handler) + if err != nil { + log.Fatal(err) } + + server.Serve() } diff --git a/examples/server6/server6 b/examples/server6/server6 deleted file mode 100755 index 16f61533..00000000 Binary files a/examples/server6/server6 and /dev/null differ