Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changes to nclient4 and nclient6 from our testing phase #548

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions dhcpv4/nclient4/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

//go:build go1.12
// +build go1.12

// Package nclient4 is a small, minimum-functionality client for DHCPv4.
Expand Down Expand Up @@ -62,6 +63,12 @@ var (

// ErrNoIfaceHWAddr is returned when NewWithConn is called with nil-value as ifaceHWAddr
ErrNoIfaceHWAddr = errors.New("ifaceHWAddr is nil")

// ErrNotAUDPSocket is returned when WithReadBuffer is called but the underlying socket is not UDP.
ErrNotAUDPSocket = errors.New("the underlying socket is not UDP")

// ErrZeroReadBufferSize is returned when WithReadBuffer is called with a buffer size of zero.
ErrZeroReadBufferSize = errors.New("read buffer size cannot be zero")
)

// pendingCh is a channel associated with a pending TransactionID.
Expand Down Expand Up @@ -362,6 +369,26 @@ func WithUnicast(srcAddr *net.UDPAddr) ClientOpt {
}
}

// WithReadBuffer sets the size of the read buffer for the underlying socket.
// This has the effect of setting the SO_RCVBUF option to the given value.
// The underlying socket must be UDP.
// The buffer size must be a positive integer.
func WithReadBuffer(bufSize uint) ClientOpt {
return func(c *Client) (err error) {
if bufSize == 0 {
return ErrZeroReadBufferSize
}
udpConn, ok := c.conn.(*net.UDPConn)
if !ok {
return ErrNotAUDPSocket
}
if err := udpConn.SetReadBuffer(int(bufSize)); err != nil {
return fmt.Errorf("unable to set read buffer: %w", err)
}
return
}
}

// WithHWAddr tells to the Client to receive messages destinated to selected
// hardware address
func WithHWAddr(hwAddr net.HardwareAddr) ClientOpt {
Expand Down
82 changes: 58 additions & 24 deletions dhcpv6/nclient6/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ type pendingCh struct {
done <-chan struct{}

// ch is used by the receive loop to distribute DHCP messages.
ch chan<- *dhcpv6.Message
ch chan<- dhcpv6.DHCPv6
}

// Client is a DHCPv6 client.
Expand Down Expand Up @@ -84,13 +84,13 @@ type Client struct {

type logger interface {
Printf(format string, v ...interface{})
PrintMessage(prefix string, message *dhcpv6.Message)
PrintMessage(prefix string, message dhcpv6.DHCPv6)
}

type emptyLogger struct{}

func (e emptyLogger) Printf(format string, v ...interface{}) {}
func (e emptyLogger) PrintMessage(prefix string, message *dhcpv6.Message) {}
func (e emptyLogger) Printf(format string, v ...interface{}) {}
func (e emptyLogger) PrintMessage(prefix string, message dhcpv6.DHCPv6) {}

type shortSummaryLogger struct {
*log.Logger
Expand All @@ -99,7 +99,7 @@ type shortSummaryLogger struct {
func (s shortSummaryLogger) Printf(format string, v ...interface{}) {
s.Logger.Printf(format, v...)
}
func (s shortSummaryLogger) PrintMessage(prefix string, message *dhcpv6.Message) {
func (s shortSummaryLogger) PrintMessage(prefix string, message dhcpv6.DHCPv6) {
s.Printf("%s: %s", prefix, message)
}

Expand All @@ -110,7 +110,7 @@ type debugLogger struct {
func (d debugLogger) Printf(format string, v ...interface{}) {
d.Logger.Printf(format, v...)
}
func (d debugLogger) PrintMessage(prefix string, message *dhcpv6.Message) {
func (d debugLogger) PrintMessage(prefix string, message dhcpv6.DHCPv6) {
d.Printf("%s: %s", prefix, message.Summary())
}

Expand Down Expand Up @@ -218,7 +218,7 @@ func (c *Client) receiveLoop() {
return
}

msg, err := dhcpv6.MessageFromBytes(b[:n])
msg, err := dhcpv6.FromBytes(b[:n])
if err != nil {
// Not a valid DHCP packet; keep listening.
if c.printDropped {
Expand All @@ -230,13 +230,24 @@ func (c *Client) receiveLoop() {
continue
}

inner, err := msg.GetInnerMessage()
if err != nil {
if c.printDropped {
if len(b) > 12 {
b = b[:12]
}
c.logger.Printf("Invalid DHCPv6 message received (len %d bytes), first 12 bytes: %#x", n, b)
}
continue
}

c.pendingMu.Lock()
p, ok := c.pending[msg.TransactionID]
p, ok := c.pending[inner.TransactionID]
if ok {
select {
case <-p.done:
close(p.ch)
delete(c.pending, msg.TransactionID)
delete(c.pending, inner.TransactionID)

// This send may block.
case p.ch <- msg:
Expand Down Expand Up @@ -355,14 +366,19 @@ func (c *Client) RapidSolicit(ctx context.Context, modifiers ...dhcpv6.Modifier)
return nil, err
}

switch msg.MessageType {
inner, err := msg.GetInnerMessage()
if err != nil {
return nil, err
}

switch msg.Type() {
case dhcpv6.MessageTypeReply:
// We got RapidCommitted.
return msg, nil
return inner, nil

case dhcpv6.MessageTypeAdvertise:
// We didn't get RapidCommitted. Request regular lease.
return c.Request(ctx, msg, modifiers...)
return c.Request(ctx, inner, modifiers...)

default:
return nil, fmt.Errorf("invalid message type: cannot happen")
Expand All @@ -380,7 +396,11 @@ func (c *Client) Solicit(ctx context.Context, modifiers ...dhcpv6.Modifier) (*dh
if err != nil {
return nil, err
}
return msg, nil
inner, err := msg.GetInnerMessage()
if err != nil {
return nil, err
}
return inner, nil
}

// Request requests an IP Assignment from peer given an advertise message.
Expand All @@ -389,7 +409,12 @@ func (c *Client) Request(ctx context.Context, advertise *dhcpv6.Message, modifie
if err != nil {
return nil, err
}
return c.SendAndRead(ctx, c.serverAddr, request, nil)
msg, err := c.SendAndRead(ctx, c.serverAddr, request, nil)
inner, err := msg.GetInnerMessage()
if err != nil {
return nil, err
}
return inner, nil
}

// send sends p to destination and returns a response channel.
Expand All @@ -398,16 +423,21 @@ func (c *Client) Request(ctx context.Context, advertise *dhcpv6.Message, modifie
// received.
//
// Responses will be matched by transaction ID.
func (c *Client) send(dest net.Addr, msg *dhcpv6.Message) (<-chan *dhcpv6.Message, func(), error) {
func (c *Client) send(dest net.Addr, msg dhcpv6.DHCPv6) (<-chan dhcpv6.DHCPv6, func(), error) {
inner, err := msg.GetInnerMessage()
if err != nil {
return nil, nil, err
}

c.pendingMu.Lock()
if _, ok := c.pending[msg.TransactionID]; ok {
if _, ok := c.pending[inner.TransactionID]; ok {
c.pendingMu.Unlock()
return nil, nil, fmt.Errorf("transaction ID %s already in use", msg.TransactionID)
return nil, nil, fmt.Errorf("transaction ID %s already in use", inner.TransactionID)
}

ch := make(chan *dhcpv6.Message, c.bufferCap)
ch := make(chan dhcpv6.DHCPv6, c.bufferCap)
done := make(chan struct{})
c.pending[msg.TransactionID] = &pendingCh{done: done, ch: ch}
c.pending[inner.TransactionID] = &pendingCh{done: done, ch: ch}
c.pendingMu.Unlock()

cancel := func() {
Expand All @@ -420,9 +450,9 @@ func (c *Client) send(dest net.Addr, msg *dhcpv6.Message) (<-chan *dhcpv6.Messag
close(done)

c.pendingMu.Lock()
if p, ok := c.pending[msg.TransactionID]; ok {
if p, ok := c.pending[inner.TransactionID]; ok {
close(p.ch)
delete(c.pending, msg.TransactionID)
delete(c.pending, inner.TransactionID)
}
c.pendingMu.Unlock()
}
Expand All @@ -441,8 +471,8 @@ var errDeadlineExceeded = errors.New("INTERNAL ERROR: deadline exceeded")
// response matching `match` as well as its Transaction ID.
//
// If match is nil, the first packet matching the Transaction ID is returned.
func (c *Client) SendAndRead(ctx context.Context, dest *net.UDPAddr, msg *dhcpv6.Message, match Matcher) (*dhcpv6.Message, error) {
var response *dhcpv6.Message
func (c *Client) SendAndRead(ctx context.Context, dest *net.UDPAddr, msg dhcpv6.DHCPv6, match Matcher) (dhcpv6.DHCPv6, error) {
var response dhcpv6.DHCPv6
err := c.retryFn(func(timeout time.Duration) error {
ch, rem, err := c.send(dest, msg)
if err != nil {
Expand All @@ -463,7 +493,11 @@ func (c *Client) SendAndRead(ctx context.Context, dest *net.UDPAddr, msg *dhcpv6
return ctx.Err()

case packet := <-ch:
if match == nil || match(packet) {
inner, err := packet.GetInnerMessage()
if err != nil {
return err
}
if match == nil || match(inner) {
c.logger.PrintMessage("received message", packet)
response = packet
return nil
Expand Down
Loading