Skip to content

Commit a7e4280

Browse files
committed
updates
1 parent f412b2b commit a7e4280

File tree

2 files changed

+73
-109
lines changed

2 files changed

+73
-109
lines changed

serialization/serialization.go

+63-98
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import (
1010
"encoding/binary"
1111
"errors"
1212
"fmt"
13-
"io"
1413
"math/bits"
1514
"slices"
1615
"strings"
@@ -120,61 +119,42 @@ func (f FieldString) String() string {
120119
}
121120

122121
func Unmarshal(data []byte, v interface{}) error {
123-
r := bytes.NewReader(data)
124122
switch m := v.(type) {
125123
case *Message:
126-
messageLen := 1
127-
tmpVer := make([]byte, messageLen)
128-
_, err := r.Read(tmpVer)
129-
if err != nil {
130-
return err
131-
}
132-
m.Version = tmpVer[0] >> 1
133-
134-
err = Unmarshal(data[messageLen:], &m.Format)
124+
m.Version = data[0] >> 1
125+
err := Unmarshal(data[1:], &m.Format)
135126
if err != nil {
136127
return err
137128
}
138129
case *Format:
139-
formatLen := 2
140-
tmpFormat := make([]byte, formatLen)
141-
_, err := r.Read(tmpFormat)
142-
if err != nil {
143-
return err
144-
}
145-
m.Size = uint64(tmpFormat[0] >> 1)
146-
m.LastNonIgnorableField = int(tmpFormat[1] >> 1)
130+
pos := uint64(0)
131+
m.Size = uint64(data[pos] >> 1)
132+
pos++
133+
m.LastNonIgnorableField = int(data[pos] >> 1)
134+
pos++
147135

148136
for i := 0; i < len(m.Fields); i++ {
149-
tmpField := make([]byte, 1)
150-
_, err := r.Read(tmpField)
151-
if err != nil {
152-
if errors.Is(err, io.EOF) {
153-
break
154-
}
155-
return err
156-
}
157-
if int(tmpField[0]/2) != i {
137+
if int(pos)+1 > len(data) || int(data[pos]>>1) != i {
158138
// The field number we got doesn't match what we expect,
159-
// so a field was skipped. Rewind the reader and skip.
139+
// so a field was skipped.
160140
m.Fields[i].ID = i
161141
m.Fields[i].Skipped = true
162-
_, err := r.Seek(-1, io.SeekCurrent)
163-
if err != nil {
164-
return err
165-
}
166142
continue
167143
}
168-
m.Fields[i].ID = int(tmpField[0] >> 1)
144+
m.Fields[i].ID = int(data[pos] >> 1)
145+
pos++
146+
var n uint64
147+
var err error
169148
switch f := m.Fields[i].Type.(type) {
170149
case FieldIntFixed:
171-
f.Value, err = decodeFixed(r, f.Length)
150+
f.Value, n, err = decodeFixed(data, pos, f.Length)
172151
if err != nil {
173152
return err
174153
}
175154
m.Fields[i].Type = f
176155
case FieldUintVar:
177-
val, err := decodeVar(r, true)
156+
var val interface{}
157+
val, n, err = decodeVar(data, pos, true)
178158
if err != nil {
179159
return err
180160
}
@@ -185,7 +165,8 @@ func Unmarshal(data []byte, v interface{}) error {
185165
}
186166
m.Fields[i].Type = f
187167
case FieldIntVar:
188-
val, err := decodeVar(r, false)
168+
var val interface{}
169+
val, n, err = decodeVar(data, pos, false)
189170
if err != nil {
190171
return err
191172
}
@@ -196,14 +177,15 @@ func Unmarshal(data []byte, v interface{}) error {
196177
}
197178
m.Fields[i].Type = f
198179
case FieldString:
199-
f.Value, err = decodeString(r)
180+
f.Value, n, err = decodeString(data, pos)
200181
if err != nil {
201182
return err
202183
}
203184
m.Fields[i].Type = f
204185
default:
205186
return fmt.Errorf("unsupported field type: %T", m.Fields[i].Type)
206187
}
188+
pos = n
207189
}
208190

209191
default:
@@ -212,105 +194,88 @@ func Unmarshal(data []byte, v interface{}) error {
212194
return nil
213195
}
214196

215-
func decodeString(r io.Reader) (string, error) {
216-
firstByte := make([]byte, 1)
217-
_, err := r.Read(firstByte)
218-
if err != nil {
219-
return "", err
197+
func decodeString(data []byte, pos uint64) (string, uint64, error) {
198+
if len(data) < int(pos)+1 {
199+
return "", pos, errors.New("string truncated, expected at least one byte")
220200
}
221-
strBytes := make([]byte, firstByte[0] >> 1)
222-
n, err := r.Read(strBytes)
223-
if err != nil {
224-
return "", err
201+
strLen := int(data[pos] >> 1)
202+
pos++
203+
if len(data) < int(pos)+strLen {
204+
return "", pos, fmt.Errorf("string truncated, expected length: %d", strLen)
225205
}
226-
if n != int(firstByte[0] >> 1) {
227-
return "", fmt.Errorf("only read %d bytes, expected %d", n, firstByte[0]/2)
228-
}
229-
return string(strBytes), nil
206+
return string(data[pos : pos+uint64(strLen)]), pos + uint64(strLen), nil
230207
}
231208

232-
func decodeFixed(r io.Reader, len int) ([]byte, error) {
209+
func decodeFixed(data []byte, pos uint64, intlen int) ([]byte, uint64, error) {
233210
var b bytes.Buffer
234211

235-
tmpInt := make([]byte, 1)
236212
for {
237-
_, err := r.Read(tmpInt)
238-
if err != nil {
239-
return nil, err
213+
if len(data) < int(pos)+1 {
214+
return b.Bytes(), pos, errors.New("data truncated")
240215
}
241-
if tmpInt[0]%2 == 0 {
242-
b.WriteByte(tmpInt[0] >> 1)
216+
if data[pos]%2 == 0 {
217+
b.WriteByte(data[pos] >> 1)
243218
} else {
244-
tmpInt2 := make([]byte, 1)
245-
_, err := r.Read(tmpInt2)
246-
if err != nil {
247-
return nil, err
219+
if len(data) < int(pos)+2 {
220+
return b.Bytes(), pos, errors.New("data truncated")
248221
}
249-
switch tmpInt2[0] {
222+
switch data[pos+1] {
250223
case 0x2:
251-
b.WriteByte((tmpInt[0] >> 2) + 0x80)
224+
b.WriteByte((data[pos] >> 2) + 0x80)
252225
case 0x3:
253-
b.WriteByte((tmpInt[0] >> 2) + 0xc0)
226+
b.WriteByte((data[pos] >> 2) + 0xc0)
254227
default:
255-
return nil, fmt.Errorf("unknown decoding for %v", tmpInt2[0])
228+
return nil, pos, fmt.Errorf("unknown decoding for %v", data[pos])
256229
}
230+
pos++
257231
}
258-
if b.Len() == len {
232+
pos++
233+
if b.Len() == intlen {
259234
break
260235
}
261236
}
262-
return b.Bytes(), nil
237+
return b.Bytes(), pos, nil
263238
}
264239

265-
func decodeVar(r io.ReadSeeker, unsigned bool) (interface{}, error) {
266-
firstByte := make([]byte, 1)
267-
_, err := r.Read(firstByte)
268-
if err != nil {
269-
return 0, err
270-
}
271-
tb := trailingOneBitCount(firstByte[0])
272-
_, err = r.Seek(-1, io.SeekCurrent)
273-
if err != nil {
274-
return 0, err
275-
}
276-
fieldBytes := make([]byte, tb+1)
277-
n, err := r.Read(fieldBytes)
278-
if err != nil {
279-
return 0, err
240+
func decodeVar(data []byte, pos uint64, unsigned bool) (interface{}, uint64, error) {
241+
if len(data) < int(pos)+1 {
242+
return 0, pos, errors.New("data truncated")
280243
}
281-
if n != tb+1 {
282-
return 0, fmt.Errorf("only read %d bytes, expected %d", n, tb+1)
244+
flen := trailingOneBitCount(data[pos]) + 1
245+
if len(data) < int(pos)+flen {
246+
return 0, pos, fmt.Errorf("truncated data, expected length: %d", flen)
283247
}
284248
var tNum uint64
285-
switch len(fieldBytes) {
249+
switch flen {
286250
case 1:
287-
tNum = uint64(fieldBytes[0])
251+
tNum = uint64(data[pos])
288252
case 2:
289-
tNum = uint64(binary.LittleEndian.Uint16(fieldBytes))
253+
tNum = uint64(binary.LittleEndian.Uint16(data[pos : int(pos)+flen]))
290254
case 3:
291255
tNum = uint64(binary.LittleEndian.Uint32(
292-
slices.Concat(fieldBytes, []byte{0x0})))
256+
slices.Concat(data[pos:int(pos)+flen], []byte{0x0})))
293257
case 4:
294-
tNum = uint64(binary.LittleEndian.Uint32(fieldBytes))
258+
tNum = uint64(binary.LittleEndian.Uint32(data[pos : int(pos)+flen]))
295259
case 5:
296260
tNum = binary.LittleEndian.Uint64(
297-
slices.Concat(fieldBytes, []byte{0x0, 0x0, 0x0}))
261+
slices.Concat(data[pos:int(pos)+flen], []byte{0x0, 0x0, 0x0}))
298262
case 6:
299263
tNum = binary.LittleEndian.Uint64(
300-
slices.Concat(fieldBytes, []byte{0x0, 0x0}))
264+
slices.Concat(data[pos:int(pos)+flen], []byte{0x0, 0x0}))
301265
case 7:
302266
tNum = binary.LittleEndian.Uint64(
303-
slices.Concat(fieldBytes, []byte{0x0}))
267+
slices.Concat(data[pos:int(pos)+flen], []byte{0x0}))
304268
case 8:
305-
tNum = binary.LittleEndian.Uint64(fieldBytes)
269+
tNum = binary.LittleEndian.Uint64(data[pos : int(pos)+flen])
306270
}
271+
pos += uint64(flen)
307272
if unsigned {
308-
return tNum >> (tb + 1), nil
273+
return tNum >> flen, pos, nil
309274
}
310-
if positive := (tNum>>(tb+1))&1 == 0; positive {
311-
return int64(tNum >> (tb + 2)), nil
275+
if positive := (tNum>>flen)&1 == 0; positive {
276+
return int64(tNum >> (flen + 1)), pos, nil
312277
}
313-
return int64(-(1 + (tNum >> (tb + 2)))), nil
278+
return int64(-(1 + (tNum >> (flen + 1)))), pos, nil
314279
}
315280

316281
func trailingOneBitCount(b byte) int {

serialization/serialization_test.go

+10-11
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package serialization
22

33
import (
4-
"bytes"
54
"testing"
65

76
"github.com/stretchr/testify/require"
@@ -49,13 +48,13 @@ func TestDecodeFixed(t *testing.T) {
4948
[]byte{0xee, 0x81},
5049
16,
5150
[]byte{},
52-
"EOF",
51+
"data truncated",
5352
},
5453
{
5554
[]byte{},
5655
16,
5756
[]byte{},
58-
"EOF",
57+
"data truncated",
5958
},
6059
{
6160
[]byte{0xee, 0x81, 0x04, 0xc1, 0x02, 0x01, 0x03, 0x41, 0x03, 0x81, 0x03, 0xc1, 0x03, 0xc5, 0x03, 0x22,
@@ -67,7 +66,7 @@ func TestDecodeFixed(t *testing.T) {
6766
}
6867

6968
for _, tc := range testcases {
70-
actual, err := decodeFixed(bytes.NewReader(tc.input), tc.len)
69+
actual, _, err := decodeFixed(tc.input, 0, tc.len)
7170
if tc.err == "" {
7271
require.NoError(t, err)
7372
require.Equal(t, tc.result, actual)
@@ -92,22 +91,22 @@ func TestDecodeString(t *testing.T) {
9291
{
9392
[]byte{0x18, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67},
9493
"",
95-
"only read ",
94+
"string truncated",
9695
},
9796
{
9897
[]byte{},
9998
"",
100-
"EOF",
99+
"string truncated, expected at least one byte",
101100
},
102101
{
103102
[]byte{0x18},
104103
"",
105-
"EOF",
104+
"string truncated, expected length",
106105
},
107106
}
108107

109108
for _, tc := range testcases {
110-
s, err := decodeString(bytes.NewReader(tc.input))
109+
s, _, err := decodeString(tc.input, 0)
111110
if tc.err == "" {
112111
require.NoError(t, err)
113112
require.Equal(t, tc.result, s)
@@ -128,13 +127,13 @@ func TestDecodeVar(t *testing.T) {
128127
[]byte{},
129128
false,
130129
0,
131-
"EOF",
130+
"data truncated",
132131
},
133132
{
134133
[]byte{0xd9},
135134
false,
136135
0,
137-
"only read ",
136+
"truncated data",
138137
},
139138
{
140139
[]byte{0x4},
@@ -205,7 +204,7 @@ func TestDecodeVar(t *testing.T) {
205204
}
206205

207206
for _, tc := range testcases {
208-
r, err := decodeVar(bytes.NewReader(tc.input), tc.unsigned)
207+
r, _, err := decodeVar(tc.input, 0, tc.unsigned)
209208
if tc.err == "" {
210209
require.NoError(t, err)
211210
require.Equal(t, tc.result, r, tc.result)

0 commit comments

Comments
 (0)