Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add strict padding setting to the nclient4 #545

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion dhcpv4/dhcpv4.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,13 @@ func NewReleaseFromACK(ack *DHCPv4, modifiers ...Modifier) (*DHCPv4, error) {
// FromBytes decodes a DHCPv4 packet from a sequence of bytes, and returns an
// error if the packet is not valid.
func FromBytes(q []byte) (*DHCPv4, error) {
return FromBytesWithStrictPadding(q, false)
}

// FromBytesWithStrictPadding decodes a DHCPv4 packet from a sequence of bytes, and returns an
// error if the packet is not valid.
// Octets after the End option are checked or not according to pad option.
func FromBytesWithStrictPadding(q []byte, strictPadding bool) (*DHCPv4, error) {
var p DHCPv4
buf := uio.NewBigEndianBuffer(q)

Expand Down Expand Up @@ -353,7 +360,7 @@ func FromBytes(q []byte) (*DHCPv4, error) {
}

p.Options = make(Options)
if err := p.Options.fromBytesCheckEnd(buf.Data(), true); err != nil {
if err := p.Options.fromBytesWithStrictPadding(buf.Data(), true, strictPadding); err != nil {
return nil, err
}
return &p, nil
Expand Down
17 changes: 16 additions & 1 deletion dhcpv4/nclient4/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,11 @@ type Client struct {
// bufferCap is the channel capacity for each TransactionID.
bufferCap int

// strictPadding is a flag that allows you to configure packet parsing behavior.
// The RFC says that octets after the End option SHOULD be pad options. Not MUST.
// But if necessary, then using this parameter you can consider such packages invalid.
strictPadding bool

// serverAddr is the UDP address to send all packets to.
//
// This may be an actual broadcast address, or a unicast address.
Expand Down Expand Up @@ -267,7 +272,7 @@ func (c *Client) receiveLoop() {
return
}

msg, err := dhcpv4.FromBytes(b[:n])
msg, err := dhcpv4.FromBytesWithStrictPadding(b[:n], c.strictPadding)
if err != nil {
// Not a valid DHCP packet; keep listening.
continue
Expand Down Expand Up @@ -396,6 +401,16 @@ func WithServerAddr(n *net.UDPAddr) ClientOpt {
}
}

// WithStrictPadding is an option that allows you to configure packet parsing behavior.
// The RFC says that octets after the End option SHOULD be pad options. Not MUST.
// But if necessary, then using this option you can consider such packages invalid.
func WithStrictPadding() ClientOpt {
return func(c *Client) (err error) {
c.strictPadding = true
return
}
}

// Matcher matches DHCP packets.
type Matcher func(*dhcpv4.DHCPv4) bool

Expand Down
90 changes: 90 additions & 0 deletions dhcpv4/nclient4/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,40 @@ func serveAndClient(ctx context.Context, responses [][]*dhcpv4.DHCPv4, opts ...C
return mc, serverConn
}

func serveAndClientWithBytesResp(response []byte, opts ...ClientOpt) *Client {
// Fake PacketConn connection.
clientRawConn, serverRawConn, err := socketpair.PacketSocketPair()
if err != nil {
panic(err)
}

clientConn := NewBroadcastUDPConn(clientRawConn, &net.UDPAddr{Port: ClientPort})
serverConn := NewBroadcastUDPConn(serverRawConn, &net.UDPAddr{Port: ServerPort})

o := []ClientOpt{WithRetry(1), WithTimeout(2 * time.Second)}
o = append(o, opts...)
mc, err := NewWithConn(clientConn, net.HardwareAddr{0xa, 0xb, 0xc, 0xd, 0xe, 0xf}, o...)
if err != nil {
panic(err)
}

// Fake server.
go func() {
b := make([]byte, 4096)
_, peer, err := serverConn.ReadFrom(b)
if err != nil {
panic(err)
}

_, err = serverConn.WriteTo(response, peer)
if err != nil {
panic(err)
}
}()

return mc
}

func ComparePacket(got *dhcpv4.DHCPv4, want *dhcpv4.DHCPv4) error {
if got == nil && got == want {
return nil
Expand Down Expand Up @@ -281,6 +315,62 @@ func TestSimpleSendAndReadDiscardGarbage(t *testing.T) {
}
}

func TestSendAndReadWithDefaultPadding(t *testing.T) {
pkt := newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33})

// Both the server and client only get 2 seconds.
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

response, err := dhcpv4.NewReplyFromRequest(pkt)
if err != nil {
t.Errorf("NewReplyFromRequest(%v) = %v, want nil", pkt, err)
}

bytes := response.ToBytes()

// Add garbage to the end.
bytes[len(bytes)-1] = 0x01

mc := serveAndClientWithBytesResp(bytes)
defer mc.Close()

rcvd, err := mc.SendAndRead(ctx, DefaultServers, pkt, nil)
if err != nil {
t.Errorf("SendAndRead(%v) = %v, want nil", pkt, err)
}

if err := ComparePacket(rcvd, response); err != nil {
t.Errorf("got unexpected packets: %v", err)
}
}

func TestSendAndReadWithStrictPadding(t *testing.T) {
pkt := newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33})

// Both the server and client only get 2 seconds.
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

response, err := dhcpv4.NewReplyFromRequest(pkt)
if err != nil {
t.Errorf("NewReplyFromRequest(%v) = %v, want nil", pkt, err)
}

bytes := response.ToBytes()

// Add garbage to the end.
bytes[len(bytes)-1] = 0x01

mc := serveAndClientWithBytesResp(bytes, WithStrictPadding())
defer mc.Close()

_, err = mc.SendAndRead(ctx, DefaultServers, pkt, nil)
if err == nil {
t.Errorf("SendAndRead(%v) error is nil, want error", pkt)
}
}

func TestMultipleSendAndRead(t *testing.T) {
for _, tt := range []struct {
desc string
Expand Down
10 changes: 7 additions & 3 deletions dhcpv4/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func (o Options) ToBytes() []byte {
//
// Returns an error if any invalid option or length is found.
func (o Options) FromBytes(data []byte) error {
return o.fromBytesCheckEnd(data, false)
return o.fromBytesWithStrictPadding(data, false, false)
}

const (
Expand All @@ -115,9 +115,9 @@ const (
optEnd = 255
)

// FromBytesCheckEnd parses Options from byte sequences using the
// fromBytesWithStrictPadding parses Options from byte sequences using the
// parsing function that is passed in as a paremeter
func (o Options) fromBytesCheckEnd(data []byte, checkEndOption bool) error {
func (o Options) fromBytesWithStrictPadding(data []byte, checkEndOption bool, strictPadding bool) error {
if len(data) == 0 {
return nil
}
Expand Down Expand Up @@ -161,6 +161,10 @@ func (o Options) fromBytesCheckEnd(data []byte, checkEndOption bool) error {
return io.ErrUnexpectedEOF
}

if !strictPadding {
return nil
}

// Any bytes left must be padding.
var pad uint8
for buf.Len() >= 1 {
Expand Down
21 changes: 14 additions & 7 deletions dhcpv4/options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,9 +227,10 @@ func TestOptionsMarshal(t *testing.T) {

func TestOptionsUnmarshal(t *testing.T) {
for i, tt := range []struct {
input []byte
want Options
wantError bool
input []byte
strictPadding bool
want Options
wantError bool
}{
{
// Buffer missing data.
Expand All @@ -255,9 +256,15 @@ func TestOptionsUnmarshal(t *testing.T) {
wantError: true,
},
{
// Option present after the End is a nono.
input: []byte{byte(OptionEnd), 3},
wantError: true,
// Option present after the End if not strictPadding.
input: []byte{byte(OptionEnd), 3},
want: Options{},
},
{
// Option present after the End if strictPadding.
input: []byte{byte(OptionEnd), 3},
strictPadding: true,
wantError: true,
},
{
input: []byte{byte(OptionEnd)},
Expand Down Expand Up @@ -306,7 +313,7 @@ func TestOptionsUnmarshal(t *testing.T) {
} {
t.Run(fmt.Sprintf("Test %02d", i), func(t *testing.T) {
opt := make(Options)
err := opt.fromBytesCheckEnd(tt.input, true)
err := opt.fromBytesWithStrictPadding(tt.input, true, tt.strictPadding)
if tt.wantError {
require.Error(t, err)
} else {
Expand Down
Loading