Skip to content

Commit e26d1c6

Browse files
authored
Merge pull request #607 from saikocat/issue-606
Fix address-instability in MultiSegmentArena by storing *Segment and validating via (id, arena storage)
2 parents 0b34935 + 97b1468 commit e26d1c6

File tree

1 file changed

+35
-32
lines changed

1 file changed

+35
-32
lines changed

arena.go

Lines changed: 35 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ func (ssa *SingleSegmentArena) Release() {
157157
// buffers, allocating new buffers of exponentially-increasing size when
158158
// full. This avoids the potentially-expensive slice copying of SingleSegment.
159159
type MultiSegmentArena struct {
160-
segs []Segment
160+
segs []*Segment
161161

162162
// rawData is set when the individual segments were all demuxed from
163163
// the passed raw data slice.
@@ -204,12 +204,16 @@ func (msa *MultiSegmentArena) Release() {
204204
msa.rawData = nil
205205

206206
for i := range msa.segs {
207+
if msa.segs[i] == nil {
208+
continue
209+
}
207210
if msa.bp != nil {
208211
zeroSlice(msa.segs[i].data)
209212
msa.bp.Put(msa.segs[i].data)
210213
}
211214
msa.segs[i].data = nil
212215
msa.segs[i].BindTo(nil)
216+
msa.segs[i] = nil
213217
}
214218

215219
if msa.segs != nil {
@@ -227,15 +231,17 @@ func (msa *MultiSegmentArena) Release() {
227231
// Like MultiSegment, but doesn't use the pool
228232
func multiSegment(b [][]byte) *MultiSegmentArena {
229233
var bp *bufferpool.Pool
230-
var segs []Segment
234+
var segs []*Segment
231235
if b == nil {
232236
bp = &bufferpool.Default
233-
segs = make([]Segment, 0, 5) // Typical size.
237+
segs = make([]*Segment, 0, 5) // Typical size.
234238
} else {
235-
segs = make([]Segment, len(b))
239+
segs = make([]*Segment, len(b))
236240
for i := range b {
237-
segs[i].data = b[i]
238-
segs[i].id = SegmentID(i)
241+
segs[i] = &Segment{
242+
data: b[i],
243+
id: SegmentID(i),
244+
}
239245
}
240246
}
241247
return &MultiSegmentArena{segs: segs, bp: bp}
@@ -264,7 +270,7 @@ func (msa *MultiSegmentArena) demux(hdr streamHeader, data []byte, bp *bufferpoo
264270
msa.segs = msa.segs[:numSegs]
265271
} else {
266272
inc := numSegs - len(msa.segs)
267-
msa.segs = append(msa.segs, make([]Segment, inc)...)
273+
msa.segs = append(msa.segs, make([]*Segment, inc)...)
268274
}
269275

270276
rawData := data
@@ -273,9 +279,13 @@ func (msa *MultiSegmentArena) demux(hdr streamHeader, data []byte, bp *bufferpoo
273279
if err != nil {
274280
return err
275281
}
276-
277-
msa.segs[i].data, data = data[:sz:sz], data[sz:]
278-
msa.segs[i].id = i
282+
seg := msa.segs[i]
283+
if seg == nil {
284+
seg = &Segment{id: i}
285+
msa.segs[i] = seg
286+
}
287+
seg.data, data = data[:sz:sz], data[sz:]
288+
seg.id = i
279289
}
280290

281291
msa.rawData = rawData
@@ -291,23 +301,14 @@ func (msa *MultiSegmentArena) Segment(id SegmentID) *Segment {
291301
if int(id) >= len(msa.segs) {
292302
return nil
293303
}
294-
return &msa.segs[id]
304+
return msa.segs[id]
295305
}
296306

297307
func (msa *MultiSegmentArena) Allocate(sz Size, msg *Message, seg *Segment) (*Segment, address, error) {
298308
// Prefer allocating in seg if it has capacity.
299309
if seg != nil && hasCapacity(seg.data, sz) {
300-
// Double check this segment is part of this arena.
301-
contains := false
302-
for i := range msa.segs {
303-
if &msa.segs[i] == seg {
304-
contains = true
305-
break
306-
}
307-
}
308-
309-
if !contains {
310-
// This is a usage error.
310+
// Membership check: validate by id and exact pointer equality.
311+
if int(seg.id) >= len(msa.segs) || msa.segs[seg.id] != seg {
311312
return nil, 0, errors.New("preferred segment is not part of the arena")
312313
}
313314

@@ -325,14 +326,17 @@ func (msa *MultiSegmentArena) Allocate(sz Size, msg *Message, seg *Segment) (*Se
325326

326327
var total int64
327328
for i := range msa.segs {
329+
if msa.segs[i] == nil {
330+
continue
331+
}
328332
data := msa.segs[i].data
329333
if hasCapacity(data, sz) {
330334
// Found segment with spare capacity.
331-
addr := address(len(msa.segs[i].data))
335+
addr := address(len(data))
332336
newLen := int(addr) + int(sz)
333-
msa.segs[i].data = msa.segs[i].data[:newLen]
337+
msa.segs[i].data = data[:newLen]
334338
msa.segs[i].BindTo(msg)
335-
return &msa.segs[i], addr, nil
339+
return msa.segs[i], addr, nil
336340
}
337341

338342
if total += int64(cap(data)); total < 0 {
@@ -363,20 +367,19 @@ func (msa *MultiSegmentArena) Allocate(sz Size, msg *Message, seg *Segment) (*Se
363367
return nil, 0, err
364368
}
365369

366-
// We have determined this will be a new segment. Get the backing
367-
// buffer for it.
370+
// We have determined this will be a new segment. Get the backing buffer.
368371
buf := msa.bp.Get(n)
369372
buf = buf[:sz]
370373

371374
// Setup the segment.
372375
id := SegmentID(len(msa.segs))
373-
msa.segs = append(msa.segs, Segment{
376+
newSeg := &Segment{
374377
data: buf,
375378
id: id,
376-
})
377-
res := &msa.segs[int(id)]
378-
res.BindTo(msg)
379-
return res, 0, nil
379+
}
380+
msa.segs = append(msa.segs, newSeg)
381+
newSeg.BindTo(msg)
382+
return newSeg, 0, nil
380383
}
381384

382385
func (msa *MultiSegmentArena) String() string {

0 commit comments

Comments
 (0)