Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[compress] Create compress writer only in case the content is compressible #928

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 54 additions & 13 deletions middleware/compress.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,10 @@ func (c *Compressor) SetEncoder(encoding string, fn EncoderFunc) {

// If the encoder supports Resetting (IoReseterWriter), then it can be pooled.
encoder := fn(ioutil.Discard, c.level)
// Release resources of the created encoder
if enc, ok := encoder.(io.WriteCloser); ok {
defer enc.Close()
}
if encoder != nil {
if _, ok := encoder.(ioResetterWriter); ok {
pool := &sync.Pool{
Expand Down Expand Up @@ -189,29 +193,26 @@ func (c *Compressor) SetEncoder(encoding string, fn EncoderFunc) {
// current Compressor.
func (c *Compressor) Handler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
encoder, encoding, cleanup := c.selectEncoder(r.Header, w)
encoder, encoding := c.selectEncoder(r.Header, w)

cw := &compressResponseWriter{
ResponseWriter: w,
w: w,
contentTypes: c.allowedTypes,
contentWildcards: c.allowedWildcards,
encoding: encoding,
compressible: false, // determined in post-handler
}
if encoder != nil {
cw.w = encoder
cw.encoder = encoder
}
// Re-add the encoder to the pool if applicable.
defer cleanup()
defer cw.Close()

next.ServeHTTP(cw, r)
})
}

// selectEncoder returns the encoder, the name of the encoder, and a closer function.
func (c *Compressor) selectEncoder(h http.Header, w io.Writer) (io.Writer, string, func()) {
func (c *Compressor) selectEncoder(h http.Header, w io.Writer) (func() io.Writer, string) {
header := h.Get("Accept-Encoding")

// Parse the names of all accepted algorithms from the header.
Expand All @@ -221,23 +222,31 @@ func (c *Compressor) selectEncoder(h http.Header, w io.Writer) (io.Writer, strin
for _, name := range c.encodingPrecedence {
if matchAcceptEncoding(accepted, name) {
if pool, ok := c.pooledEncoders[name]; ok {
encoder := pool.Get().(ioResetterWriter)
cleanup := func() {
pool.Put(encoder)
fn := func() io.Writer {
enc := pool.Get().(ioResetterWriter)
enc.Reset(w)
return &pooledEncoder{
Writer: enc,
pool: pool,
}
}
encoder.Reset(w)
return encoder, name, cleanup
return fn, name

}
if fn, ok := c.encoders[name]; ok {
return fn(w, c.level), name, func() {}
fn := func() io.Writer {
return &encoder{
Writer: fn(w, c.level),
}
}
return fn, name
}
}

}

// No encoder found to match the accepted encoding
return nil, "", func() {}
return nil, ""
}

func matchAcceptEncoding(accepted []string, encoding string) bool {
Expand Down Expand Up @@ -272,6 +281,8 @@ type compressResponseWriter struct {
encoding string
wroteHeader bool
compressible bool

encoder func() io.Writer
}

func (cw *compressResponseWriter) isCompressible() bool {
Expand Down Expand Up @@ -331,6 +342,9 @@ func (cw *compressResponseWriter) Write(p []byte) (int, error) {

func (cw *compressResponseWriter) writer() io.Writer {
if cw.compressible {
if cw.w == nil {
cw.w = cw.encoder()
}
return cw.w
}
return cw.ResponseWriter
Expand Down Expand Up @@ -381,6 +395,33 @@ func (cw *compressResponseWriter) Unwrap() http.ResponseWriter {
return cw.ResponseWriter
}

type (
encoder struct {
io.Writer
}

pooledEncoder struct {
io.Writer
pool *sync.Pool
}
)

func (e *encoder) Close() error {
if c, ok := e.Writer.(io.WriteCloser); ok {
return c.Close()
}
return nil
}

func (e *pooledEncoder) Close() error {
var err error
if w, ok := e.Writer.(io.WriteCloser); ok {
err = w.Close()
}
e.pool.Put(e.Writer)
return err
}

func encoderGzip(w io.Writer, level int) io.Writer {
gw, err := gzip.NewWriterLevel(w, level)
if err != nil {
Expand Down
95 changes: 91 additions & 4 deletions middleware/compress_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package middleware

import (
"bytes"
"compress/flate"
"compress/gzip"
"fmt"
Expand All @@ -26,8 +27,13 @@ func TestCompressor(t *testing.T) {
return w
})

if len(compressor.encoders) != 1 {
t.Errorf("nop encoder should be stored in the encoders map")
var sideEffect int
compressor.SetEncoder("test", func(w io.Writer, _ int) io.Writer {
return newSideEffectWriter(w, &sideEffect)
})

if len(compressor.encoders) != 2 {
t.Errorf("nop and test encoders should be stored in the encoders map")
}

r.Use(compressor.Handler)
Expand All @@ -43,7 +49,12 @@ func TestCompressor(t *testing.T) {
})

r.Get("/getplain", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html")
w.Header().Set("Content-Type", "text/plain")
w.Write([]byte("textstring"))
})

r.Get("/getimage", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "image/png")
w.Write([]byte("textstring"))
})

Expand All @@ -55,18 +66,21 @@ func TestCompressor(t *testing.T) {
path string
expectedEncoding string
acceptedEncodings []string
checkRawResponse bool
}{
{
name: "no expected encodings due to no accepted encodings",
path: "/gethtml",
acceptedEncodings: nil,
acceptedEncodings: []string{""},
expectedEncoding: "",
checkRawResponse: true,
},
{
name: "no expected encodings due to content type",
path: "/getplain",
acceptedEncodings: nil,
expectedEncoding: "",
checkRawResponse: true,
},
{
name: "gzip is only encoding",
Expand All @@ -92,12 +106,30 @@ func TestCompressor(t *testing.T) {
path: "/getcss",
acceptedEncodings: []string{"nop, gzip, deflate"},
expectedEncoding: "nop",
checkRawResponse: true,
},
{
name: "test encoder is used",
path: "/getimage",
acceptedEncodings: []string{"test"},
expectedEncoding: "",
checkRawResponse: true,
},
{
name: "test encoder is used and Close is called",
path: "/gethtml",
acceptedEncodings: []string{"test"},
expectedEncoding: "test",
checkRawResponse: true,
},
}

for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
if tc.checkRawResponse {
testRequestRawResponse(t, ts, "GET", tc.path, []byte("textstring"), tc.acceptedEncodings...)
}
resp, respString := testRequestWithAcceptedEncodings(t, ts, "GET", tc.path, tc.acceptedEncodings...)
if respString != "textstring" {
t.Errorf("response text doesn't match; expected:%q, got:%q", "textstring", respString)
Expand All @@ -109,6 +141,9 @@ func TestCompressor(t *testing.T) {
})

}
if sideEffect != 0 {
t.Errorf("side effect should be cleared after close")
}
}

func TestCompressorWildcards(t *testing.T) {
Expand Down Expand Up @@ -171,6 +206,35 @@ func TestCompressorWildcards(t *testing.T) {
}
}

func testRequestRawResponse(t *testing.T, ts *httptest.Server, method, path string, exp []byte, encodings ...string) {
req, err := http.NewRequest(method, ts.URL+path, nil)
if err != nil {
t.Fatal(err)
return
}
if len(encodings) > 0 {
encodingsString := strings.Join(encodings, ",")
req.Header.Set("Accept-Encoding", encodingsString)
}

resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
return
}

respBody, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
return
}
defer resp.Body.Close()

if !bytes.Equal(respBody, exp) {
t.Errorf("expected %q but got %q", exp, respBody)
}
}

func testRequestWithAcceptedEncodings(t *testing.T, ts *httptest.Server, method, path string, encodings ...string) (*http.Response, string) {
req, err := http.NewRequest(method, ts.URL+path, nil)
if err != nil {
Expand Down Expand Up @@ -217,3 +281,26 @@ func decodeResponseBody(t *testing.T, resp *http.Response) string {

return string(respBody)
}

type (
sideEffectWriter struct {
w io.Writer
s *int
}
)

func newSideEffectWriter(w io.Writer, sideEffect *int) io.Writer {
*sideEffect = *sideEffect + 1

return &sideEffectWriter{w: w, s: sideEffect}
}

func (w *sideEffectWriter) Write(p []byte) (n int, err error) {
return w.w.Write(p)
}

func (w *sideEffectWriter) Close() error {
*w.s = *w.s - 1

return nil
}