From b1c30adb434ed620baaba918b5d9ed34ce7e12c7 Mon Sep 17 00:00:00 2001 From: Cody A Price Date: Wed, 14 Aug 2024 12:35:42 -0500 Subject: [PATCH 1/3] adds TrustedOriginPredicateFunc option --- csrf.go | 29 +++++++++++++++++------------ doc.go | 5 ++--- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/csrf.go b/csrf.go index 97a3925..26e0a64 100644 --- a/csrf.go +++ b/csrf.go @@ -85,14 +85,15 @@ type options struct { Path string // Note that the function and field names match the case of the associated // http.Cookie field instead of the "correct" HTTPOnly name that golint suggests. - HttpOnly bool - Secure bool - SameSite SameSiteMode - RequestHeader string - FieldName string - ErrorHandler http.Handler - CookieName string - TrustedOrigins []string + HttpOnly bool + Secure bool + SameSite SameSiteMode + RequestHeader string + FieldName string + ErrorHandler http.Handler + CookieName string + TrustedOrigins []string + TrustedOriginPredicateFunc func(referer string) bool } // Protect is HTTP middleware that provides Cross-Site Request Forgery @@ -258,10 +259,14 @@ func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) { valid := sameOrigin(r.URL, referer) if !valid { - for _, trustedOrigin := range cs.opts.TrustedOrigins { - if referer.Host == trustedOrigin { - valid = true - break + if cs.opts.TrustedOriginPredicateFunc != nil { + valid = cs.opts.TrustedOriginPredicateFunc(referer.Host) + } else { + for _, trustedOrigin := range cs.opts.TrustedOrigins { + if referer.Host == trustedOrigin { + valid = true + break + } } } } diff --git a/doc.go b/doc.go index 503c948..f12be51 100644 --- a/doc.go +++ b/doc.go @@ -17,8 +17,8 @@ field. gorilla/csrf is easy to use: add the middleware to individual handlers with the below: - CSRF := csrf.Protect([]byte("32-byte-long-auth-key")) - http.HandlerFunc("/route", CSRF(YourHandler)) + CSRF := csrf.Protect([]byte("32-byte-long-auth-key")) + http.HandlerFunc("/route", CSRF(YourHandler)) ... and then collect the token with `csrf.Token(r)` before passing it to the template, JSON body or HTTP header (you pick!). gorilla/csrf inspects the form body @@ -171,6 +171,5 @@ important. and the one-time-pad used for masking them. This library does not seek to be adventurous. - */ package csrf From 3e92a5ddfeee518f6a73076f6ec18001981a2dbb Mon Sep 17 00:00:00 2001 From: Cody A Price Date: Wed, 14 Aug 2024 12:55:27 -0500 Subject: [PATCH 2/3] add TrustedOriginPredicateFunc method to options --- options.go | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/options.go b/options.go index c61d301..1d70e68 100644 --- a/options.go +++ b/options.go @@ -131,6 +131,21 @@ func TrustedOrigins(origins []string) Option { } } +// TrustedOriginPredicateFunc configures a predicate function that can be used +// to determine if a given Referer is trusted. +// Like TrustedOrigins, this will allow cross-domain CSRF use-cases - e.g. where +// the front-end is served from a different domain than the API server - to +// correctly pass a CSRF check. +// However, this function allows for more complex logic to be applied to determine +// if a Referer is trusted than strict equality string matching. +// +// You should only pass origins you own or have full control over. +func TrustedOriginPredicateFunc(f func(referer string) bool) Option { + return func(cs *csrf) { + cs.opts.TrustedOriginPredicateFunc = f + } +} + // setStore sets the store used by the CSRF middleware. // Note: this is private (for now) to allow for internal API changes. func setStore(s store) Option { From 9baf06986113ff3672ee506ace1d604423f765d1 Mon Sep 17 00:00:00 2001 From: Cody A Price Date: Wed, 14 Aug 2024 12:55:35 -0500 Subject: [PATCH 3/3] add test --- csrf_test.go | 76 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/csrf_test.go b/csrf_test.go index 0559a59..cbc8188 100644 --- a/csrf_test.go +++ b/csrf_test.go @@ -336,6 +336,82 @@ func TestTrustedReferer(t *testing.T) { } } +// TestTrustedOriginPredicateFunc checks that HTTPS requests with a Referer that does not +// match the request URL correctly but is a trusted origin pass CSRF validation. +func TestTrustedOriginPredicateFunc(t *testing.T) { + + testTable := []struct { + predicate func(referer string) bool + shouldPass bool + }{ + {func(referer string) bool { + return referer == "golang.org" + }, true}, + {func(referer string) bool { + return referer == "api.example.com" || referer == "golang.org" + }, true}, + {func(referer string) bool { + return referer == "http://golang.org" + }, false}, + {func(referer string) bool { + return referer == "https://golang.org" + }, false}, + {func(referer string) bool { + return referer == "http://example.com" + }, false}, + {func(referer string) bool { + return referer == "example.com" + }, false}, + } + + for _, item := range testTable { + s := http.NewServeMux() + + p := Protect(testKey, TrustedOriginPredicateFunc(item.predicate))(s) + + var token string + s.Handle("/", http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + token = Token(r) + })) + + // Obtain a CSRF cookie via a GET request. + r, err := http.NewRequest("GET", "https://www.gorillatoolkit.org/", nil) + if err != nil { + t.Fatal(err) + } + + rr := httptest.NewRecorder() + p.ServeHTTP(rr, r) + + // POST the token back in the header. + r, err = http.NewRequest("POST", "https://www.gorillatoolkit.org/", nil) + if err != nil { + t.Fatal(err) + } + + setCookie(rr, r) + r.Header.Set("X-CSRF-Token", token) + + // Set a non-matching Referer header. + r.Header.Set("Referer", "http://golang.org/") + + rr = httptest.NewRecorder() + p.ServeHTTP(rr, r) + + if item.shouldPass { + if rr.Code != http.StatusOK { + t.Fatalf("middleware failed to pass to the next handler: got %v want %v", + rr.Code, http.StatusOK) + } + } else { + if rr.Code != http.StatusForbidden { + t.Fatalf("middleware failed reject a non-matching Referer header: got %v want %v", + rr.Code, http.StatusForbidden) + } + } + } +} + // Requests with a valid Referer should pass. func TestWithReferer(t *testing.T) { s := http.NewServeMux()