Skip to content

Commit 4c9cf96

Browse files
committed
refactor(profiler): alternative zstd encoder reuse approach
1 parent 58dd4c4 commit 4c9cf96

File tree

4 files changed

+78
-58
lines changed

4 files changed

+78
-58
lines changed

profiler/compression.go

Lines changed: 65 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ import (
3232
"io"
3333
"strconv"
3434
"strings"
35-
"sync"
3635

3736
kgzip "github.com/klauspost/compress/gzip"
3837
"github.com/klauspost/compress/zstd"
@@ -138,42 +137,29 @@ func getZstdLevelOrDefault(level int) zstd.EncoderLevel {
138137
return zstd.SpeedDefault
139138
}
140139

141-
type sema struct {
142-
c chan struct{}
140+
type compressionPipelineBuilder struct {
141+
zstdEncoders map[zstd.EncoderLevel]*sharedZstdEncoder
143142
}
144143

145-
func (s *sema) Lock() { s.c <- struct{}{} }
146-
func (s *sema) Unlock() { <-s.c }
147-
148-
var (
149-
// compressionMux protects zstdEncoder. It must be locked
150-
// when doing compression that _might_ use zstdEncoder.
151-
//
152-
// It's a channel-based semaphore rather than a mutex
153-
// so that contention on it doesn't appear in the mutex
154-
// profile (we expect it to be contended).
155-
// This is a kludge. We only really want it for zstdEncoder,
156-
// but the places where we actually do compression just
157-
// take a compressor interface. It's easier for now to just
158-
// have this global semaphore than to plumb the locking
159-
// through the existing abstractions.
160-
compressionMux = &sema{c: make(chan struct{}, 1)}
161-
162-
zstdEncoderOnce sync.Once
163-
zstdEncoder *zstd.Encoder
164-
zstdEncoderErr error
165-
)
166-
167-
func getZstdEncoder(opts ...zstd.EOption) (*zstd.Encoder, error) {
168-
zstdEncoderOnce.Do(func() {
169-
zstdEncoder, zstdEncoderErr = zstd.NewWriter(nil, opts...)
170-
})
171-
return zstdEncoder, zstdEncoderErr
144+
func (b *compressionPipelineBuilder) getZstdEncoder(level zstd.EncoderLevel) (*sharedZstdEncoder, error) {
145+
if b.zstdEncoders == nil {
146+
b.zstdEncoders = make(map[zstd.EncoderLevel]*sharedZstdEncoder)
147+
}
148+
encoder, ok := b.zstdEncoders[level]
149+
if !ok {
150+
var err error
151+
encoder, err = newSharedZstdEncoder(level)
152+
if err != nil {
153+
return nil, err
154+
}
155+
b.zstdEncoders[level] = encoder
156+
}
157+
return encoder, nil
172158
}
173159

174-
// newCompressionPipeline returns a compressor that converts the data written to
175-
// it from the expected input compression to the given output compression.
176-
func newCompressionPipeline(in compression, out compression) (compressor, error) {
160+
// Build returns a compressor that converts the data written to it from the
161+
// expected input compression to the given output compression.
162+
func (b *compressionPipelineBuilder) Build(in compression, out compression) (compressor, error) {
177163
if in == out {
178164
return newPassthroughCompressor(), nil
179165
}
@@ -183,11 +169,15 @@ func newCompressionPipeline(in compression, out compression) (compressor, error)
183169
}
184170

185171
if in == noCompression && out.algorithm == compressionAlgorithmZstd {
186-
return getZstdEncoder(zstd.WithEncoderLevel(getZstdLevelOrDefault(out.level)))
172+
return b.getZstdEncoder(getZstdLevelOrDefault(out.level))
187173
}
188174

189175
if in.algorithm == compressionAlgorithmGzip && out.algorithm == compressionAlgorithmZstd {
190-
return newZstdRecompressor(getZstdLevelOrDefault(out.level))
176+
encoder, err := b.getZstdEncoder(getZstdLevelOrDefault(out.level))
177+
if err != nil {
178+
return nil, err
179+
}
180+
return newZstdRecompressor(encoder), nil
191181
}
192182

193183
return nil, fmt.Errorf("unsupported recompression: %s -> %s", in, out)
@@ -198,8 +188,11 @@ func newCompressionPipeline(in compression, out compression) (compressor, error)
198188
// the data from one format and then re-compresses it into another format.
199189
type compressor interface {
200190
io.Writer
201-
io.Closer
191+
// Reset resets the compressor to the given writer. It may also acquire a
192+
// shared underlying resource, so callers must always call Close().
202193
Reset(w io.Writer)
194+
// Close closes the compressor and releases any shared underlying resource.
195+
Close() error
203196
}
204197

205198
// newPassthroughCompressor returns a compressor that simply passes all data
@@ -220,20 +213,16 @@ func (r *passthroughCompressor) Close() error {
220213
return nil
221214
}
222215

223-
func newZstdRecompressor(level zstd.EncoderLevel) (*zstdRecompressor, error) {
224-
zstdOut, err := getZstdEncoder(zstd.WithEncoderLevel(level))
225-
if err != nil {
226-
return nil, err
227-
}
228-
return &zstdRecompressor{zstdOut: zstdOut, err: make(chan error)}, nil
216+
func newZstdRecompressor(encoder *sharedZstdEncoder) *zstdRecompressor {
217+
return &zstdRecompressor{zstdOut: encoder, err: make(chan error)}
229218
}
230219

231220
type zstdRecompressor struct {
232221
// err synchronizes finishing writes after closing pw and reports any
233222
// error during recompression
234223
err chan error
235224
pw io.WriteCloser
236-
zstdOut *zstd.Encoder
225+
zstdOut *sharedZstdEncoder
237226
}
238227

239228
func (r *zstdRecompressor) Reset(w io.Writer) {
@@ -260,3 +249,36 @@ func (r *zstdRecompressor) Close() error {
260249
err := <-r.err
261250
return cmp.Or(err, r.zstdOut.Close())
262251
}
252+
253+
// newSharedZstdEncoder creates a new shared Zstd encoder with the given level.
254+
// It expects the Reset and Close methods to be used in an acquire and release
255+
// fashion.
256+
func newSharedZstdEncoder(level zstd.EncoderLevel) (*sharedZstdEncoder, error) {
257+
encoder, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(level))
258+
if err != nil {
259+
return nil, err
260+
}
261+
return &sharedZstdEncoder{encoder: encoder, sema: make(chan struct{}, 1)}, nil
262+
}
263+
264+
type sharedZstdEncoder struct {
265+
encoder *zstd.Encoder
266+
sema chan struct{}
267+
}
268+
269+
// Reset acquires the semaphore and resets the encoder to the given writer.
270+
func (s *sharedZstdEncoder) Reset(w io.Writer) {
271+
s.sema <- struct{}{}
272+
s.encoder.Reset(w)
273+
}
274+
275+
func (s *sharedZstdEncoder) Write(p []byte) (int, error) {
276+
return s.encoder.Write(p)
277+
}
278+
279+
// Close releases the semaphore and closes the encoder.
280+
func (s *sharedZstdEncoder) Close() error {
281+
err := s.encoder.Close()
282+
<-s.sema
283+
return err
284+
}

profiler/compression_test.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ func TestNewCompressionPipeline(t *testing.T) {
4343

4444
for _, test := range tests {
4545
t.Run(fmt.Sprintf("%s->%s", test.in, test.out), func(t *testing.T) {
46-
pipeline, err := newCompressionPipeline(test.in, test.out)
46+
var pipelineBuilder compressionPipelineBuilder
47+
pipeline, err := pipelineBuilder.Build(test.in, test.out)
4748
require.NoError(t, err)
4849
buf := &bytes.Buffer{}
4950
pipeline.Reset(buf)
@@ -172,11 +173,13 @@ func BenchmarkRecompression(b *testing.B) {
172173
b.Run(fmt.Sprintf("%s-%s", in.inAlg.String(), in.outLevel), func(b *testing.B) {
173174
data := compressData(b, inputdata, in.inAlg)
174175
b.ResetTimer()
176+
var pipelineBuilder compressionPipelineBuilder
175177
for i := 0; i < b.N; i++ {
176-
z, err := newZstdRecompressor(in.outLevel)
178+
encoder, err := pipelineBuilder.getZstdEncoder(in.outLevel)
177179
if err != nil {
178180
b.Fatal(err)
179181
}
182+
z := newZstdRecompressor(encoder)
180183
z.Reset(io.Discard)
181184
if _, err := z.Write(data); err != nil {
182185
b.Fatal(err)

profiler/profile.go

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,8 @@ var profileTypes = map[ProfileType]profileType{
114114
p.stopCPUProfile()
115115

116116
c := p.compressors[CPUProfile]
117-
compressionMux.Lock()
118-
defer compressionMux.Unlock()
119117
c.Reset(&buf)
118+
defer c.Close()
120119
if _, err := outBuf.WriteTo(c); err != nil {
121120
return nil, err
122121
}
@@ -183,9 +182,8 @@ var profileTypes = map[ProfileType]profileType{
183182
}
184183

185184
c := p.compressors[expGoroutineWaitProfile]
186-
compressionMux.Lock()
187-
defer compressionMux.Unlock()
188185
c.Reset(pprof)
186+
defer c.Close()
189187
err := goroutineDebug2ToPprof(text, c, now)
190188
err = cmp.Or(err, c.Close())
191189
return pprof.Bytes(), err
@@ -197,9 +195,8 @@ var profileTypes = map[ProfileType]profileType{
197195
Collect: func(p *profiler) ([]byte, error) {
198196
var buf bytes.Buffer
199197
c := p.compressors[MetricsProfile]
200-
compressionMux.Lock()
201-
defer compressionMux.Unlock()
202198
c.Reset(&buf)
199+
defer c.Close()
203200
interrupted := p.interruptibleSleep(p.cfg.period)
204201
err := p.met.report(now(), c)
205202
err = cmp.Or(err, c.Close())
@@ -229,9 +226,8 @@ var profileTypes = map[ProfileType]profileType{
229226
trace.Stop()
230227

231228
c := p.compressors[executionTrace]
232-
compressionMux.Lock()
233-
defer compressionMux.Unlock()
234229
c.Reset(buf)
230+
defer c.Close()
235231
if _, err := outBuf.WriteTo(c); err != nil {
236232
return nil, err
237233
}
@@ -303,9 +299,8 @@ func collectGenericProfile(name string, pt ProfileType) func(p *profiler) ([]byt
303299
dp, ok := p.deltas[pt]
304300
if !ok || !p.cfg.deltaProfiles {
305301
c := p.compressors[pt]
306-
compressionMux.Lock()
307-
defer compressionMux.Unlock()
308302
c.Reset(&buf)
303+
defer c.Close()
309304
err := p.lookupProfile(name, c, 0)
310305
err = cmp.Or(err, c.Close())
311306
return buf.Bytes(), err
@@ -456,9 +451,8 @@ func (fdp *fastDeltaProfiler) Delta(data []byte) (b []byte, err error) {
456451

457452
fdp.buf.Reset()
458453
c := fdp.compressor
459-
compressionMux.Lock()
460-
defer compressionMux.Unlock()
461454
c.Reset(&fdp.buf)
455+
defer c.Close()
462456

463457
if err = fdp.dc.Delta(data, c); err != nil {
464458
return nil, fmt.Errorf("error computing delta: %s", err.Error())

profiler/profiler.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,10 +259,11 @@ func newProfiler(opts ...Option) (*profiler, error) {
259259
if p.cfg.traceConfig.Enabled {
260260
types = append(types, executionTrace)
261261
}
262+
var pipelineBuilder compressionPipelineBuilder
262263
for _, pt := range types {
263264
isDelta := p.cfg.deltaProfiles && len(profileTypes[pt].DeltaValues) > 0
264265
in, out := compressionStrategy(pt, isDelta, p.cfg.compressionConfig)
265-
compressor, err := newCompressionPipeline(in, out)
266+
compressor, err := pipelineBuilder.Build(in, out)
266267
if err != nil {
267268
return nil, err
268269
}

0 commit comments

Comments
 (0)