diff --git a/pkg/sentry/socket/netlink/nlmsg/message.go b/pkg/sentry/socket/netlink/nlmsg/message.go index cf5d0d27af..e554e09a90 100644 --- a/pkg/sentry/socket/netlink/nlmsg/message.go +++ b/pkg/sentry/socket/netlink/nlmsg/message.go @@ -410,49 +410,49 @@ func (v *BytesView) Int32() (int32, bool) { // NetToHostU16 converts a uint16 in network byte order to // host byte order value. func NetToHostU16(v uint16) uint16 { - b := make([]byte, 2) - binary.NativeEndian.PutUint16(b, v) - return binary.BigEndian.Uint16(b) + var b [2]byte + binary.BigEndian.PutUint16(b[:], v) + return binary.NativeEndian.Uint16(b[:]) } // NetToHostU32 converts a uint32 in network byte order to // host byte order value. func NetToHostU32(v uint32) uint32 { - b := make([]byte, 4) - binary.NativeEndian.PutUint32(b, v) - return binary.BigEndian.Uint32(b) + var b [4]byte + binary.BigEndian.PutUint32(b[:], v) + return binary.NativeEndian.Uint32(b[:]) } // NetToHostU64 converts a uint64 in network byte order to // host byte order value. func NetToHostU64(v uint64) uint64 { - b := make([]byte, 8) - binary.NativeEndian.PutUint64(b, v) - return binary.BigEndian.Uint64(b) + var b [8]byte + binary.BigEndian.PutUint64(b[:], v) + return binary.NativeEndian.Uint64(b[:]) } // HostToNetU16 converts a uint16 in host byte order to // network byte order value. func HostToNetU16(v uint16) uint16 { - b := make([]byte, 2) - binary.BigEndian.PutUint16(b, v) - return binary.NativeEndian.Uint16(b) + var b [2]byte + binary.NativeEndian.PutUint16(b[:], v) + return binary.BigEndian.Uint16(b[:]) } // HostToNetU32 converts a uint32 in host byte order to // network byte order value. func HostToNetU32(v uint32) uint32 { - b := make([]byte, 4) - binary.BigEndian.PutUint32(b, v) - return binary.NativeEndian.Uint32(b) + var b [4]byte + binary.NativeEndian.PutUint32(b[:], v) + return binary.BigEndian.Uint32(b[:]) } // HostToNetU64 converts a uint64 in host byte order to // network byte order value. func HostToNetU64(v uint64) uint64 { - b := make([]byte, 8) - binary.BigEndian.PutUint64(b, v) - return binary.NativeEndian.Uint64(b) + var b [8]byte + binary.NativeEndian.PutUint64(b[:], v) + return binary.BigEndian.Uint64(b[:]) } // PutU16 converts a uint16 to network byte order and returns it as a diff --git a/pkg/sentry/socket/netlink/nlmsg/message_test.go b/pkg/sentry/socket/netlink/nlmsg/message_test.go index 5003267972..4ba142a206 100644 --- a/pkg/sentry/socket/netlink/nlmsg/message_test.go +++ b/pkg/sentry/socket/netlink/nlmsg/message_test.go @@ -16,6 +16,7 @@ package message_test import ( "bytes" + "fmt" "reflect" "testing" @@ -373,3 +374,93 @@ func TestBytesView(t *testing.T) { } } } + +func TestHostToNet(t *testing.T) { + tests := []struct { + name string + validate func() error + }{ + { + name: "U16", + validate: func() error { + v := uint16(0x1234) + if got, want := nlmsg.HostToNetU16(v), uint16(0x3412); got != want { + return fmt.Errorf("HostToNetU16(0x%x) = 0x%x, want: 0x%x", v, got, want) + } + return nil + }, + }, + { + name: "U32", + validate: func() error { + v := uint32(0x12345678) + if got, want := nlmsg.HostToNetU32(v), uint32(0x78563412); got != want { + return fmt.Errorf("HostToNetU32(0x%x) = 0x%x, want: 0x%x", v, got, want) + } + return nil + }, + }, + { + name: "U64", + validate: func() error { + v := uint64(0x123456789abcdef) + if got, want := nlmsg.HostToNetU64(v), uint64(0xefcdab8967452301); got != want { + return fmt.Errorf("HostToNetU64(0x%x) = 0x%x, want: 0x%x", v, got, want) + } + return nil + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if err := test.validate(); err != nil { + t.Error(err) + } + }) + } +} + +func TestNetToHost(t *testing.T) { + tests := []struct { + name string + validate func() error + }{ + { + name: "U16", + validate: func() error { + v := uint16(0x1234) + if got, want := nlmsg.NetToHostU16(v), uint16(0x3412); got != want { + return fmt.Errorf("NetToHostU16(0x%x) = 0x%x, want: 0x%x", v, got, want) + } + return nil + }, + }, + { + name: "U32", + validate: func() error { + v := uint32(0x12345678) + if got, want := nlmsg.NetToHostU32(v), uint32(0x78563412); got != want { + return fmt.Errorf("NetToHostU32(0x%x) = 0x%x, want: 0x%x", v, got, want) + } + return nil + }, + }, + { + name: "U64", + validate: func() error { + v := uint64(0x123456789abcdef) + if got, want := nlmsg.NetToHostU64(v), uint64(0xefcdab8967452301); got != want { + return fmt.Errorf("NetToHostU64(0x%x) = 0x%x, want: 0x%x", v, got, want) + } + return nil + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if err := test.validate(); err != nil { + t.Error(err) + } + }) + } +}