diff --git a/dhcpv4/dhcpv4.go b/dhcpv4/dhcpv4.go index a8754312..4e7e71a9 100644 --- a/dhcpv4/dhcpv4.go +++ b/dhcpv4/dhcpv4.go @@ -300,6 +300,13 @@ func NewReleaseFromACK(ack *DHCPv4, modifiers ...Modifier) (*DHCPv4, error) { // FromBytes decodes a DHCPv4 packet from a sequence of bytes, and returns an // error if the packet is not valid. func FromBytes(q []byte) (*DHCPv4, error) { + return FromBytesWithStrictPadding(q, false) +} + +// FromBytesWithStrictPadding decodes a DHCPv4 packet from a sequence of bytes, and returns an +// error if the packet is not valid. +// Octets after the End option are checked or not according to pad option. +func FromBytesWithStrictPadding(q []byte, strictPadding bool) (*DHCPv4, error) { var p DHCPv4 buf := uio.NewBigEndianBuffer(q) @@ -353,7 +360,7 @@ func FromBytes(q []byte) (*DHCPv4, error) { } p.Options = make(Options) - if err := p.Options.fromBytesCheckEnd(buf.Data(), true); err != nil { + if err := p.Options.fromBytesWithStrictPadding(buf.Data(), true, strictPadding); err != nil { return nil, err } return &p, nil diff --git a/dhcpv4/nclient4/client.go b/dhcpv4/nclient4/client.go index b4e4b567..e5972d41 100644 --- a/dhcpv4/nclient4/client.go +++ b/dhcpv4/nclient4/client.go @@ -143,6 +143,11 @@ type Client struct { // bufferCap is the channel capacity for each TransactionID. bufferCap int + // strictPadding is a flag that allows you to configure packet parsing behavior. + // The RFC says that octets after the End option SHOULD be pad options. Not MUST. + // But if necessary, then using this parameter you can consider such packages invalid. + strictPadding bool + // serverAddr is the UDP address to send all packets to. // // This may be an actual broadcast address, or a unicast address. @@ -267,7 +272,7 @@ func (c *Client) receiveLoop() { return } - msg, err := dhcpv4.FromBytes(b[:n]) + msg, err := dhcpv4.FromBytesWithStrictPadding(b[:n], c.strictPadding) if err != nil { // Not a valid DHCP packet; keep listening. continue @@ -396,6 +401,16 @@ func WithServerAddr(n *net.UDPAddr) ClientOpt { } } +// WithStrictPadding is an option that allows you to configure packet parsing behavior. +// The RFC says that octets after the End option SHOULD be pad options. Not MUST. +// But if necessary, then using this option you can consider such packages invalid. +func WithStrictPadding() ClientOpt { + return func(c *Client) (err error) { + c.strictPadding = true + return + } +} + // Matcher matches DHCP packets. type Matcher func(*dhcpv4.DHCPv4) bool diff --git a/dhcpv4/nclient4/client_test.go b/dhcpv4/nclient4/client_test.go index df851abb..ad194a4b 100644 --- a/dhcpv4/nclient4/client_test.go +++ b/dhcpv4/nclient4/client_test.go @@ -72,6 +72,40 @@ func serveAndClient(ctx context.Context, responses [][]*dhcpv4.DHCPv4, opts ...C return mc, serverConn } +func serveAndClientWithBytesResp(response []byte, opts ...ClientOpt) *Client { + // Fake PacketConn connection. + clientRawConn, serverRawConn, err := socketpair.PacketSocketPair() + if err != nil { + panic(err) + } + + clientConn := NewBroadcastUDPConn(clientRawConn, &net.UDPAddr{Port: ClientPort}) + serverConn := NewBroadcastUDPConn(serverRawConn, &net.UDPAddr{Port: ServerPort}) + + o := []ClientOpt{WithRetry(1), WithTimeout(2 * time.Second)} + o = append(o, opts...) + mc, err := NewWithConn(clientConn, net.HardwareAddr{0xa, 0xb, 0xc, 0xd, 0xe, 0xf}, o...) + if err != nil { + panic(err) + } + + // Fake server. + go func() { + b := make([]byte, 4096) + _, peer, err := serverConn.ReadFrom(b) + if err != nil { + panic(err) + } + + _, err = serverConn.WriteTo(response, peer) + if err != nil { + panic(err) + } + }() + + return mc +} + func ComparePacket(got *dhcpv4.DHCPv4, want *dhcpv4.DHCPv4) error { if got == nil && got == want { return nil @@ -281,6 +315,62 @@ func TestSimpleSendAndReadDiscardGarbage(t *testing.T) { } } +func TestSendAndReadWithDefaultPadding(t *testing.T) { + pkt := newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33}) + + // Both the server and client only get 2 seconds. + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + response, err := dhcpv4.NewReplyFromRequest(pkt) + if err != nil { + t.Errorf("NewReplyFromRequest(%v) = %v, want nil", pkt, err) + } + + bytes := response.ToBytes() + + // Add garbage to the end. + bytes[len(bytes)-1] = 0x01 + + mc := serveAndClientWithBytesResp(bytes) + defer mc.Close() + + rcvd, err := mc.SendAndRead(ctx, DefaultServers, pkt, nil) + if err != nil { + t.Errorf("SendAndRead(%v) = %v, want nil", pkt, err) + } + + if err := ComparePacket(rcvd, response); err != nil { + t.Errorf("got unexpected packets: %v", err) + } +} + +func TestSendAndReadWithStrictPadding(t *testing.T) { + pkt := newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33}) + + // Both the server and client only get 2 seconds. + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + response, err := dhcpv4.NewReplyFromRequest(pkt) + if err != nil { + t.Errorf("NewReplyFromRequest(%v) = %v, want nil", pkt, err) + } + + bytes := response.ToBytes() + + // Add garbage to the end. + bytes[len(bytes)-1] = 0x01 + + mc := serveAndClientWithBytesResp(bytes, WithStrictPadding()) + defer mc.Close() + + _, err = mc.SendAndRead(ctx, DefaultServers, pkt, nil) + if err == nil { + t.Errorf("SendAndRead(%v) error is nil, want error", pkt) + } +} + func TestMultipleSendAndRead(t *testing.T) { for _, tt := range []struct { desc string diff --git a/dhcpv4/options.go b/dhcpv4/options.go index b31a14b1..42f06e17 100644 --- a/dhcpv4/options.go +++ b/dhcpv4/options.go @@ -106,7 +106,7 @@ func (o Options) ToBytes() []byte { // // Returns an error if any invalid option or length is found. func (o Options) FromBytes(data []byte) error { - return o.fromBytesCheckEnd(data, false) + return o.fromBytesWithStrictPadding(data, false, false) } const ( @@ -115,9 +115,9 @@ const ( optEnd = 255 ) -// FromBytesCheckEnd parses Options from byte sequences using the +// fromBytesWithStrictPadding parses Options from byte sequences using the // parsing function that is passed in as a paremeter -func (o Options) fromBytesCheckEnd(data []byte, checkEndOption bool) error { +func (o Options) fromBytesWithStrictPadding(data []byte, checkEndOption bool, strictPadding bool) error { if len(data) == 0 { return nil } @@ -161,6 +161,10 @@ func (o Options) fromBytesCheckEnd(data []byte, checkEndOption bool) error { return io.ErrUnexpectedEOF } + if !strictPadding { + return nil + } + // Any bytes left must be padding. var pad uint8 for buf.Len() >= 1 { diff --git a/dhcpv4/options_test.go b/dhcpv4/options_test.go index ff11f888..1f85cc5f 100644 --- a/dhcpv4/options_test.go +++ b/dhcpv4/options_test.go @@ -227,9 +227,10 @@ func TestOptionsMarshal(t *testing.T) { func TestOptionsUnmarshal(t *testing.T) { for i, tt := range []struct { - input []byte - want Options - wantError bool + input []byte + strictPadding bool + want Options + wantError bool }{ { // Buffer missing data. @@ -255,9 +256,15 @@ func TestOptionsUnmarshal(t *testing.T) { wantError: true, }, { - // Option present after the End is a nono. - input: []byte{byte(OptionEnd), 3}, - wantError: true, + // Option present after the End if not strictPadding. + input: []byte{byte(OptionEnd), 3}, + want: Options{}, + }, + { + // Option present after the End if strictPadding. + input: []byte{byte(OptionEnd), 3}, + strictPadding: true, + wantError: true, }, { input: []byte{byte(OptionEnd)}, @@ -306,7 +313,7 @@ func TestOptionsUnmarshal(t *testing.T) { } { t.Run(fmt.Sprintf("Test %02d", i), func(t *testing.T) { opt := make(Options) - err := opt.fromBytesCheckEnd(tt.input, true) + err := opt.fromBytesWithStrictPadding(tt.input, true, tt.strictPadding) if tt.wantError { require.Error(t, err) } else {