@@ -32,7 +32,6 @@ import (
3232 "io"
3333 "strconv"
3434 "strings"
35- "sync"
3635
3736 kgzip "github.com/klauspost/compress/gzip"
3837 "github.com/klauspost/compress/zstd"
@@ -138,42 +137,29 @@ func getZstdLevelOrDefault(level int) zstd.EncoderLevel {
138137 return zstd .SpeedDefault
139138}
140139
141- type sema struct {
142- c chan struct {}
140+ type compressionPipelineBuilder struct {
141+ zstdEncoders map [zstd. EncoderLevel ] * sharedZstdEncoder
143142}
144143
145- func (s * sema ) Lock () { s .c <- struct {}{} }
146- func (s * sema ) Unlock () { <- s .c }
147-
148- var (
149- // compressionMux protects zstdEncoder. It must be locked
150- // when doing compression that _might_ use zstdEncoder.
151- //
152- // It's a channel-based semaphore rather than a mutex
153- // so that contention on it doesn't appear in the mutex
154- // profile (we expect it to be contended).
155- // This is a kludge. We only really want it for zstdEncoder,
156- // but the places where we actually do compression just
157- // take a compressor interface. It's easier for now to just
158- // have this global semaphore than to plumb the locking
159- // through the existing abstractions.
160- compressionMux = & sema {c : make (chan struct {}, 1 )}
161-
162- zstdEncoderOnce sync.Once
163- zstdEncoder * zstd.Encoder
164- zstdEncoderErr error
165- )
166-
167- func getZstdEncoder (opts ... zstd.EOption ) (* zstd.Encoder , error ) {
168- zstdEncoderOnce .Do (func () {
169- zstdEncoder , zstdEncoderErr = zstd .NewWriter (nil , opts ... )
170- })
171- return zstdEncoder , zstdEncoderErr
144+ func (b * compressionPipelineBuilder ) getZstdEncoder (level zstd.EncoderLevel ) (* sharedZstdEncoder , error ) {
145+ if b .zstdEncoders == nil {
146+ b .zstdEncoders = make (map [zstd.EncoderLevel ]* sharedZstdEncoder )
147+ }
148+ encoder , ok := b .zstdEncoders [level ]
149+ if ! ok {
150+ var err error
151+ encoder , err = newSharedZstdEncoder (level )
152+ if err != nil {
153+ return nil , err
154+ }
155+ b .zstdEncoders [level ] = encoder
156+ }
157+ return encoder , nil
172158}
173159
174- // newCompressionPipeline returns a compressor that converts the data written to
175- // it from the expected input compression to the given output compression.
176- func newCompressionPipeline (in compression , out compression ) (compressor , error ) {
160+ // Build returns a compressor that converts the data written to it from the
161+ // expected input compression to the given output compression.
162+ func ( b * compressionPipelineBuilder ) Build (in compression , out compression ) (compressor , error ) {
177163 if in == out {
178164 return newPassthroughCompressor (), nil
179165 }
@@ -183,11 +169,15 @@ func newCompressionPipeline(in compression, out compression) (compressor, error)
183169 }
184170
185171 if in == noCompression && out .algorithm == compressionAlgorithmZstd {
186- return getZstdEncoder (zstd . WithEncoderLevel ( getZstdLevelOrDefault (out .level ) ))
172+ return b . getZstdEncoder (getZstdLevelOrDefault (out .level ))
187173 }
188174
189175 if in .algorithm == compressionAlgorithmGzip && out .algorithm == compressionAlgorithmZstd {
190- return newZstdRecompressor (getZstdLevelOrDefault (out .level ))
176+ encoder , err := b .getZstdEncoder (getZstdLevelOrDefault (out .level ))
177+ if err != nil {
178+ return nil , err
179+ }
180+ return newZstdRecompressor (encoder ), nil
191181 }
192182
193183 return nil , fmt .Errorf ("unsupported recompression: %s -> %s" , in , out )
@@ -198,8 +188,11 @@ func newCompressionPipeline(in compression, out compression) (compressor, error)
198188// the data from one format and then re-compresses it into another format.
199189type compressor interface {
200190 io.Writer
201- io.Closer
191+ // Reset reset the compressor to the given writer. It may also acquire a
192+ // shared underlaying resource, so callers must always call Close().
202193 Reset (w io.Writer )
194+ // Close closes the compressor and releases any shared underlaying resource.
195+ Close () error
203196}
204197
205198// newPassthroughCompressor returns a compressor that simply passes all data
@@ -220,20 +213,16 @@ func (r *passthroughCompressor) Close() error {
220213 return nil
221214}
222215
223- func newZstdRecompressor (level zstd.EncoderLevel ) (* zstdRecompressor , error ) {
224- zstdOut , err := getZstdEncoder (zstd .WithEncoderLevel (level ))
225- if err != nil {
226- return nil , err
227- }
228- return & zstdRecompressor {zstdOut : zstdOut , err : make (chan error )}, nil
216+ func newZstdRecompressor (encoder * sharedZstdEncoder ) * zstdRecompressor {
217+ return & zstdRecompressor {zstdOut : encoder , err : make (chan error )}
229218}
230219
231220type zstdRecompressor struct {
232221 // err synchronizes finishing writes after closing pw and reports any
233222 // error during recompression
234223 err chan error
235224 pw io.WriteCloser
236- zstdOut * zstd. Encoder
225+ zstdOut * sharedZstdEncoder
237226}
238227
239228func (r * zstdRecompressor ) Reset (w io.Writer ) {
@@ -260,3 +249,36 @@ func (r *zstdRecompressor) Close() error {
260249 err := <- r .err
261250 return cmp .Or (err , r .zstdOut .Close ())
262251}
252+
253+ // newSharedZstdEncoder creates a new shared Zstd encoder with the given level.
254+ // It expects the Reset and Close method to be used in an acquire and release
255+ // fashion.
256+ func newSharedZstdEncoder (level zstd.EncoderLevel ) (* sharedZstdEncoder , error ) {
257+ encoder , err := zstd .NewWriter (nil , zstd .WithEncoderLevel (level ))
258+ if err != nil {
259+ return nil , err
260+ }
261+ return & sharedZstdEncoder {encoder : encoder , sema : make (chan struct {}, 1 )}, nil
262+ }
263+
264+ type sharedZstdEncoder struct {
265+ encoder * zstd.Encoder
266+ sema chan struct {}
267+ }
268+
269+ // Reset acquires the semaphore and resets the encoder to the given writer.
270+ func (s * sharedZstdEncoder ) Reset (w io.Writer ) {
271+ s .sema <- struct {}{}
272+ s .encoder .Reset (w )
273+ }
274+
275+ func (s * sharedZstdEncoder ) Write (p []byte ) (int , error ) {
276+ return s .encoder .Write (p )
277+ }
278+
279+ // Close releases the semaphore and closes the encoder.
280+ func (s * sharedZstdEncoder ) Close () error {
281+ err := s .encoder .Close ()
282+ <- s .sema
283+ return err
284+ }
0 commit comments