Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"github.com/klauspost/compress/gzhttp"
)

func NewCompressionGuardMiddleware(next http.Handler) http.Handler {
func NewCompressionGuardHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Check for user-specific headers in the request
if hasUserSpecificRequestHeaders(r) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"github.com/stretchr/testify/assert"
)

func TestCompressionGuardMiddleware(t *testing.T) {
func TestCompressionGuardHandler(t *testing.T) {
tests := []struct {
name string
requestHeaders map[string]string
Expand Down Expand Up @@ -85,7 +85,7 @@ func TestCompressionGuardMiddleware(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
handler := NewCompressionGuardMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handler := NewCompressionGuardHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
for k, v := range tt.responseHeader {
w.Header().Set(k, v)
}
Expand Down
37 changes: 37 additions & 0 deletions internal/compression_handler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package internal

import (
"net/http"

"github.com/klauspost/compress/gzhttp"
)

func NewCompressionHandler(jitter int, disableOnAuth bool, next http.Handler) http.Handler {
var wrapper func(http.Handler) http.HandlerFunc
var err error

if jitter > 0 {
wrapper, err = gzhttp.NewWrapper(
gzhttp.MinSize(1024),
gzhttp.CompressionLevel(6),
gzhttp.RandomJitter(jitter, 0, false),
)
} else {
wrapper, err = gzhttp.NewWrapper(
gzhttp.MinSize(1024),
gzhttp.CompressionLevel(6),
)
}

if err != nil {
panic("failed to create gzip wrapper: " + err.Error())
}

handler := wrapper(next)

if disableOnAuth {
return NewCompressionGuardHandler(handler)
}

return handler
}
88 changes: 88 additions & 0 deletions internal/compression_handler_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package internal

import (
"compress/gzip"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestCompressionHandler(t *testing.T) {
largeBody := strings.Repeat("A", 2000)

upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain")
_, err := w.Write([]byte(largeBody))
require.NoError(t, err)
})

t.Run("compresses responses", func(t *testing.T) {
handler := NewCompressionHandler(0, false, upstream)

req := httptest.NewRequest("GET", "/", nil)
req.Header.Set("Accept-Encoding", "gzip")
rr := httptest.NewRecorder()

handler.ServeHTTP(rr, req)

assert.Equal(t, "gzip", rr.Header().Get("Content-Encoding"))

reader, err := gzip.NewReader(rr.Body)
require.NoError(t, err)
defer reader.Close()
body, err := io.ReadAll(reader)
require.NoError(t, err)
assert.Equal(t, largeBody, string(body))
})

t.Run("applies jitter when configured", func(t *testing.T) {
handler := NewCompressionHandler(32, false, upstream)

req := httptest.NewRequest("GET", "/", nil)
req.Header.Set("Accept-Encoding", "gzip")
rr := httptest.NewRecorder()

handler.ServeHTTP(rr, req)

require.Equal(t, "gzip", rr.Header().Get("Content-Encoding"))

// Check for GZIP header with FCOMMENT flag (0x10)
bodyBytes := rr.Body.Bytes()
require.Greater(t, len(bodyBytes), 10)
hasComment := (bodyBytes[3] & 0x10) != 0
assert.True(t, hasComment, "Expected FCOMMENT flag due to jitter")
})

t.Run("wraps with guard when disableOnAuth is true", func(t *testing.T) {
handler := NewCompressionHandler(0, true, upstream)

req := httptest.NewRequest("GET", "/", nil)
req.Header.Set("Accept-Encoding", "gzip")
req.Header.Set("Cookie", "session=secret")
rr := httptest.NewRecorder()

handler.ServeHTTP(rr, req)

// Should NOT be compressed due to Cookie header
assert.Empty(t, rr.Header().Get("Content-Encoding"))
assert.Equal(t, largeBody, rr.Body.String())
})

t.Run("compresses authenticated requests when disableOnAuth is false", func(t *testing.T) {
handler := NewCompressionHandler(0, false, upstream)

req := httptest.NewRequest("GET", "/", nil)
req.Header.Set("Accept-Encoding", "gzip")
req.Header.Set("Cookie", "session=secret")
rr := httptest.NewRecorder()

handler.ServeHTTP(rr, req)

assert.Equal(t, "gzip", rr.Header().Get("Content-Encoding"))
})
}
36 changes: 3 additions & 33 deletions internal/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ import (
"log/slog"
"net/http"
"net/url"

"github.com/klauspost/compress/gzhttp"
)

type HandlerOptions struct {
Expand All @@ -26,46 +24,18 @@ func NewHandler(options HandlerOptions) http.Handler {
handler := NewProxyHandler(options.targetUrl, options.badGatewayPage, options.forwardHeaders)
handler = NewCacheHandler(options.cache, options.maxCacheableResponseBody, handler)
handler = NewSendfileHandler(options.xSendfileEnabled, handler)
handler = NewRequestStartMiddleware(handler)
handler = NewRequestStartHandler(handler)

if options.gzipCompressionEnabled {
var wrapper func(http.Handler) http.HandlerFunc
var err error

if options.gzipCompressionJitter > 0 {
wrapper, err = gzhttp.NewWrapper(
gzhttp.MinSize(1024),
gzhttp.CompressionLevel(6),
gzhttp.RandomJitter(options.gzipCompressionJitter, 0, false),
)
} else {
wrapper, err = gzhttp.NewWrapper(
gzhttp.MinSize(1024),
gzhttp.CompressionLevel(6),
)
}

if err != nil {
// If we cannot create the wrapper with the requested configuration (including jitter),
// we must fail hard rather than silently downgrading security or performance.
panic("failed to create gzip wrapper: " + err.Error())
}

gzipHandler := wrapper(handler)

if options.gzipCompressionDisableOnAuth {
handler = NewCompressionGuardMiddleware(gzipHandler)
} else {
handler = gzipHandler
}
handler = NewCompressionHandler(options.gzipCompressionJitter, options.gzipCompressionDisableOnAuth, handler)
}

if options.maxRequestBody > 0 {
handler = http.MaxBytesHandler(handler, int64(options.maxRequestBody))
}

if options.logRequests {
handler = NewLoggingMiddleware(slog.Default(), handler)
handler = NewLoggingHandler(slog.Default(), handler)
}

return handler
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,19 @@ import (
"time"
)

type LoggingMiddleware struct {
type LoggingHandler struct {
logger *slog.Logger
next http.Handler
}

func NewLoggingMiddleware(logger *slog.Logger, next http.Handler) *LoggingMiddleware {
return &LoggingMiddleware{
func NewLoggingHandler(logger *slog.Logger, next http.Handler) *LoggingHandler {
return &LoggingHandler{
logger: logger,
next: next,
}
}

func (h *LoggingMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func (h *LoggingHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
writer := newResponseWriter(w)

started := time.Now()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ import (
"github.com/stretchr/testify/require"
)

func TestMiddleware_LoggingMiddleware(t *testing.T) {
func TestLoggingHandler(t *testing.T) {
out := &strings.Builder{}
logger := slog.New(slog.NewJSONHandler(out, nil))
middleware := NewLoggingMiddleware(logger, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handler := NewLoggingHandler(logger, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Cache", "miss")
w.Header().Set("Content-Type", "text/html")
w.WriteHeader(http.StatusCreated)
Expand All @@ -29,7 +29,7 @@ func TestMiddleware_LoggingMiddleware(t *testing.T) {
req.Header.Set("User-Agent", "Robot/1")
req.Header.Set("Content-Type", "application/json")

middleware.ServeHTTP(httptest.NewRecorder(), req)
handler.ServeHTTP(httptest.NewRecorder(), req)

logline := struct {
Path string `json:"path"`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"time"
)

func NewRequestStartMiddleware(next http.Handler) http.Handler {
func NewRequestStartHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("X-Request-Start") == "" {
timestamp := time.Now().UnixMilli()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,18 @@ import (
"github.com/stretchr/testify/assert"
)

func TestRequestStartMiddleware(t *testing.T) {
func TestRequestStartHandler(t *testing.T) {
var capturedHeader string
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedHeader = r.Header.Get("X-Request-Start")
})

middleware := NewRequestStartMiddleware(nextHandler)
handler := NewRequestStartHandler(nextHandler)

before := time.Now().UnixMilli()
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
middleware.ServeHTTP(w, req)
handler.ServeHTTP(w, req)
after := time.Now().UnixMilli()

assert.NotEmpty(t, capturedHeader)
Expand All @@ -34,19 +34,19 @@ func TestRequestStartMiddleware(t *testing.T) {
assert.LessOrEqual(t, timestamp, after)
}

func TestRequestStartMiddlewareDoesNotOverwriteExistingHeader(t *testing.T) {
func TestRequestStartHandlerDoesNotOverwriteExistingHeader(t *testing.T) {
existingHeader := "t=1234567890"
var capturedHeader string
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedHeader = r.Header.Get("X-Request-Start")
})

middleware := NewRequestStartMiddleware(nextHandler)
handler := NewRequestStartHandler(nextHandler)

req := httptest.NewRequest("GET", "/", nil)
req.Header.Set("X-Request-Start", existingHeader)
w := httptest.NewRecorder()
middleware.ServeHTTP(w, req)
handler.ServeHTTP(w, req)

assert.Equal(t, existingHeader, capturedHeader)
}