diff --git a/internal/compression_guard_middleware.go b/internal/compression_guard_handler.go similarity index 97% rename from internal/compression_guard_middleware.go rename to internal/compression_guard_handler.go index dd2758e..a091448 100644 --- a/internal/compression_guard_middleware.go +++ b/internal/compression_guard_handler.go @@ -9,7 +9,7 @@ import ( "github.com/klauspost/compress/gzhttp" ) -func NewCompressionGuardMiddleware(next http.Handler) http.Handler { +func NewCompressionGuardHandler(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Check for user-specific headers in the request if hasUserSpecificRequestHeaders(r) { diff --git a/internal/compression_guard_middleware_test.go b/internal/compression_guard_handler_test.go similarity index 94% rename from internal/compression_guard_middleware_test.go rename to internal/compression_guard_handler_test.go index 0c78cba..8731065 100644 --- a/internal/compression_guard_middleware_test.go +++ b/internal/compression_guard_handler_test.go @@ -9,7 +9,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestCompressionGuardMiddleware(t *testing.T) { +func TestCompressionGuardHandler(t *testing.T) { tests := []struct { name string requestHeaders map[string]string @@ -85,7 +85,7 @@ func TestCompressionGuardMiddleware(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - handler := NewCompressionGuardMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handler := NewCompressionGuardHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { for k, v := range tt.responseHeader { w.Header().Set(k, v) } diff --git a/internal/compression_handler.go b/internal/compression_handler.go new file mode 100644 index 0000000..e0dfb16 --- /dev/null +++ b/internal/compression_handler.go @@ -0,0 +1,37 @@ +package internal + +import ( + "net/http" + + "github.com/klauspost/compress/gzhttp" +) + +func NewCompressionHandler(jitter int, disableOnAuth bool, next http.Handler) http.Handler { + var wrapper func(http.Handler) http.HandlerFunc + var err error + + if jitter > 0 { + wrapper, err = gzhttp.NewWrapper( + gzhttp.MinSize(1024), + gzhttp.CompressionLevel(6), + gzhttp.RandomJitter(jitter, 0, false), + ) + } else { + wrapper, err = gzhttp.NewWrapper( + gzhttp.MinSize(1024), + gzhttp.CompressionLevel(6), + ) + } + + if err != nil { + panic("failed to create gzip wrapper: " + err.Error()) + } + + handler := wrapper(next) + + if disableOnAuth { + return NewCompressionGuardHandler(handler) + } + + return handler +} diff --git a/internal/compression_handler_test.go b/internal/compression_handler_test.go new file mode 100644 index 0000000..4717897 --- /dev/null +++ b/internal/compression_handler_test.go @@ -0,0 +1,88 @@ +package internal + +import ( + "compress/gzip" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCompressionHandler(t *testing.T) { + largeBody := strings.Repeat("A", 2000) + + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") + _, err := w.Write([]byte(largeBody)) + require.NoError(t, err) + }) + + t.Run("compresses responses", func(t *testing.T) { + handler := NewCompressionHandler(0, false, upstream) + + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("Accept-Encoding", "gzip") + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + assert.Equal(t, "gzip", rr.Header().Get("Content-Encoding")) + + reader, err := gzip.NewReader(rr.Body) + require.NoError(t, err) + defer reader.Close() + body, err := io.ReadAll(reader) + require.NoError(t, err) + assert.Equal(t, largeBody, string(body)) + }) + + t.Run("applies jitter when configured", func(t *testing.T) { + handler := NewCompressionHandler(32, false, upstream) + + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("Accept-Encoding", "gzip") + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + require.Equal(t, "gzip", rr.Header().Get("Content-Encoding")) + + // Check for GZIP header with FCOMMENT flag (0x10) + bodyBytes := rr.Body.Bytes() + require.Greater(t, len(bodyBytes), 10) + hasComment := (bodyBytes[3] & 0x10) != 0 + assert.True(t, hasComment, "Expected FCOMMENT flag due to jitter") + }) + + t.Run("wraps with guard when disableOnAuth is true", func(t *testing.T) { + handler := NewCompressionHandler(0, true, upstream) + + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("Accept-Encoding", "gzip") + req.Header.Set("Cookie", "session=secret") + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + // Should NOT be compressed due to Cookie header + assert.Empty(t, rr.Header().Get("Content-Encoding")) + assert.Equal(t, largeBody, rr.Body.String()) + }) + + t.Run("compresses authenticated requests when disableOnAuth is false", func(t *testing.T) { + handler := NewCompressionHandler(0, false, upstream) + + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("Accept-Encoding", "gzip") + req.Header.Set("Cookie", "session=secret") + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + assert.Equal(t, "gzip", rr.Header().Get("Content-Encoding")) + }) +} diff --git a/internal/handler.go b/internal/handler.go index df6b17e..1758b46 100644 --- a/internal/handler.go +++ b/internal/handler.go @@ -4,8 +4,6 @@ import ( "log/slog" "net/http" "net/url" - - "github.com/klauspost/compress/gzhttp" ) type HandlerOptions struct { @@ -26,38 +24,10 @@ func NewHandler(options HandlerOptions) http.Handler { handler := NewProxyHandler(options.targetUrl, options.badGatewayPage, options.forwardHeaders) handler = NewCacheHandler(options.cache, options.maxCacheableResponseBody, handler) handler = NewSendfileHandler(options.xSendfileEnabled, handler) - handler = NewRequestStartMiddleware(handler) + handler = NewRequestStartHandler(handler) if options.gzipCompressionEnabled { - var wrapper func(http.Handler) http.HandlerFunc - var err error - - if options.gzipCompressionJitter > 0 { - wrapper, err = gzhttp.NewWrapper( - gzhttp.MinSize(1024), - gzhttp.CompressionLevel(6), - gzhttp.RandomJitter(options.gzipCompressionJitter, 0, false), - ) - } else { - wrapper, err = gzhttp.NewWrapper( - gzhttp.MinSize(1024), - gzhttp.CompressionLevel(6), - ) - } - - if err != nil { - // If we cannot create the wrapper with the requested configuration (including jitter), - // we must fail hard rather than silently downgrading security or performance. - panic("failed to create gzip wrapper: " + err.Error()) - } - - gzipHandler := wrapper(handler) - - if options.gzipCompressionDisableOnAuth { - handler = NewCompressionGuardMiddleware(gzipHandler) - } else { - handler = gzipHandler - } + handler = NewCompressionHandler(options.gzipCompressionJitter, options.gzipCompressionDisableOnAuth, handler) } if options.maxRequestBody > 0 { @@ -65,7 +35,7 @@ func NewHandler(options HandlerOptions) http.Handler { } if options.logRequests { - handler = NewLoggingMiddleware(slog.Default(), handler) + handler = NewLoggingHandler(slog.Default(), handler) } return handler diff --git a/internal/logging_middleware.go b/internal/logging_handler.go similarity index 90% rename from internal/logging_middleware.go rename to internal/logging_handler.go index 7e82ea0..f06143c 100644 --- a/internal/logging_middleware.go +++ b/internal/logging_handler.go @@ -9,19 +9,19 @@ import ( "time" ) -type LoggingMiddleware struct { +type LoggingHandler struct { logger *slog.Logger next http.Handler } -func NewLoggingMiddleware(logger *slog.Logger, next http.Handler) *LoggingMiddleware { - return &LoggingMiddleware{ +func NewLoggingHandler(logger *slog.Logger, next http.Handler) *LoggingHandler { + return &LoggingHandler{ logger: logger, next: next, } } -func (h *LoggingMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { +func (h *LoggingHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { writer := newResponseWriter(w) started := time.Now() diff --git a/internal/logging_middleware_test.go b/internal/logging_handler_test.go similarity index 89% rename from internal/logging_middleware_test.go rename to internal/logging_handler_test.go index ebf201f..60dd7ec 100644 --- a/internal/logging_middleware_test.go +++ b/internal/logging_handler_test.go @@ -14,10 +14,10 @@ import ( "github.com/stretchr/testify/require" ) -func TestMiddleware_LoggingMiddleware(t *testing.T) { +func TestLoggingHandler(t *testing.T) { out := &strings.Builder{} logger := slog.New(slog.NewJSONHandler(out, nil)) - middleware := NewLoggingMiddleware(logger, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handler := NewLoggingHandler(logger, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Cache", "miss") w.Header().Set("Content-Type", "text/html") w.WriteHeader(http.StatusCreated) @@ -29,7 +29,7 @@ func TestMiddleware_LoggingMiddleware(t *testing.T) { req.Header.Set("User-Agent", "Robot/1") req.Header.Set("Content-Type", "application/json") - middleware.ServeHTTP(httptest.NewRecorder(), req) + handler.ServeHTTP(httptest.NewRecorder(), req) logline := struct { Path string `json:"path"` diff --git a/internal/request_start_middleware.go b/internal/request_start_handler.go similarity index 82% rename from internal/request_start_middleware.go rename to internal/request_start_handler.go index c8621ab..3600df3 100644 --- a/internal/request_start_middleware.go +++ b/internal/request_start_handler.go @@ -6,7 +6,7 @@ import ( "time" ) -func NewRequestStartMiddleware(next http.Handler) http.Handler { +func NewRequestStartHandler(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Header.Get("X-Request-Start") == "" { timestamp := time.Now().UnixMilli() diff --git a/internal/request_start_middleware_test.go b/internal/request_start_handler_test.go similarity index 78% rename from internal/request_start_middleware_test.go rename to internal/request_start_handler_test.go index 63160d4..cab8dd7 100644 --- a/internal/request_start_middleware_test.go +++ b/internal/request_start_handler_test.go @@ -10,18 +10,18 @@ import ( "github.com/stretchr/testify/assert" ) -func TestRequestStartMiddleware(t *testing.T) { +func TestRequestStartHandler(t *testing.T) { var capturedHeader string nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { capturedHeader = r.Header.Get("X-Request-Start") }) - middleware := NewRequestStartMiddleware(nextHandler) + handler := NewRequestStartHandler(nextHandler) before := time.Now().UnixMilli() req := httptest.NewRequest("GET", "/", nil) w := httptest.NewRecorder() - middleware.ServeHTTP(w, req) + handler.ServeHTTP(w, req) after := time.Now().UnixMilli() assert.NotEmpty(t, capturedHeader) @@ -34,19 +34,19 @@ func TestRequestStartMiddleware(t *testing.T) { assert.LessOrEqual(t, timestamp, after) } -func TestRequestStartMiddlewareDoesNotOverwriteExistingHeader(t *testing.T) { +func TestRequestStartHandlerDoesNotOverwriteExistingHeader(t *testing.T) { existingHeader := "t=1234567890" var capturedHeader string nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { capturedHeader = r.Header.Get("X-Request-Start") }) - middleware := NewRequestStartMiddleware(nextHandler) + handler := NewRequestStartHandler(nextHandler) req := httptest.NewRequest("GET", "/", nil) req.Header.Set("X-Request-Start", existingHeader) w := httptest.NewRecorder() - middleware.ServeHTTP(w, req) + handler.ServeHTTP(w, req) assert.Equal(t, existingHeader, capturedHeader) }