Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adds TrustedOriginPredicateFunc option, tests #1

Merged
merged 3 commits into from
Aug 14, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 17 additions & 12 deletions csrf.go
Original file line number Diff line number Diff line change
@@ -85,14 +85,15 @@ type options struct {
Path string
// Note that the function and field names match the case of the associated
// http.Cookie field instead of the "correct" HTTPOnly name that golint suggests.
HttpOnly bool
Secure bool
SameSite SameSiteMode
RequestHeader string
FieldName string
ErrorHandler http.Handler
CookieName string
TrustedOrigins []string
HttpOnly bool
Secure bool
SameSite SameSiteMode
RequestHeader string
FieldName string
ErrorHandler http.Handler
CookieName string
TrustedOrigins []string
TrustedOriginPredicateFunc func(referer string) bool
}

// Protect is HTTP middleware that provides Cross-Site Request Forgery
@@ -258,10 +259,14 @@ func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) {
valid := sameOrigin(r.URL, referer)

if !valid {
for _, trustedOrigin := range cs.opts.TrustedOrigins {
if referer.Host == trustedOrigin {
valid = true
break
if cs.opts.TrustedOriginPredicateFunc != nil {
valid = cs.opts.TrustedOriginPredicateFunc(referer.Host)
} else {
for _, trustedOrigin := range cs.opts.TrustedOrigins {
if referer.Host == trustedOrigin {
valid = true
break
}
}
}
}
76 changes: 76 additions & 0 deletions csrf_test.go
Original file line number Diff line number Diff line change
@@ -336,6 +336,82 @@ func TestTrustedReferer(t *testing.T) {
}
}

// TestTrustedOriginPredicateFunc checks that HTTPS requests with a Referer that does not
// match the request URL correctly but is a trusted origin pass CSRF validation.
func TestTrustedOriginPredicateFunc(t *testing.T) {

testTable := []struct {
predicate func(referer string) bool
shouldPass bool
}{
{func(referer string) bool {
return referer == "golang.org"
}, true},
{func(referer string) bool {
return referer == "api.example.com" || referer == "golang.org"
}, true},
{func(referer string) bool {
return referer == "http://golang.org"
}, false},
{func(referer string) bool {
return referer == "https://golang.org"
}, false},
{func(referer string) bool {
return referer == "http://example.com"
}, false},
{func(referer string) bool {
return referer == "example.com"
}, false},
}

for _, item := range testTable {
s := http.NewServeMux()

p := Protect(testKey, TrustedOriginPredicateFunc(item.predicate))(s)

var token string
s.Handle("/", http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
token = Token(r)
}))

// Obtain a CSRF cookie via a GET request.
r, err := http.NewRequest("GET", "https://www.gorillatoolkit.org/", nil)
if err != nil {
t.Fatal(err)
}

rr := httptest.NewRecorder()
p.ServeHTTP(rr, r)

// POST the token back in the header.
r, err = http.NewRequest("POST", "https://www.gorillatoolkit.org/", nil)
if err != nil {
t.Fatal(err)
}

setCookie(rr, r)
r.Header.Set("X-CSRF-Token", token)

// Set a non-matching Referer header.
r.Header.Set("Referer", "http://golang.org/")

rr = httptest.NewRecorder()
p.ServeHTTP(rr, r)

if item.shouldPass {
if rr.Code != http.StatusOK {
t.Fatalf("middleware failed to pass to the next handler: got %v want %v",
rr.Code, http.StatusOK)
}
} else {
if rr.Code != http.StatusForbidden {
t.Fatalf("middleware failed reject a non-matching Referer header: got %v want %v",
rr.Code, http.StatusForbidden)
}
}
}
}

// Requests with a valid Referer should pass.
func TestWithReferer(t *testing.T) {
s := http.NewServeMux()
5 changes: 2 additions & 3 deletions doc.go
Original file line number Diff line number Diff line change
@@ -17,8 +17,8 @@ field.
gorilla/csrf is easy to use: add the middleware to individual handlers with
the below:

CSRF := csrf.Protect([]byte("32-byte-long-auth-key"))
http.HandlerFunc("/route", CSRF(YourHandler))
CSRF := csrf.Protect([]byte("32-byte-long-auth-key"))
http.HandlerFunc("/route", CSRF(YourHandler))

... and then collect the token with `csrf.Token(r)` before passing it to the
template, JSON body or HTTP header (you pick!). gorilla/csrf inspects the form body
@@ -171,6 +171,5 @@ important.
and the one-time-pad used for masking them.

This library does not seek to be adventurous.

*/
package csrf
15 changes: 15 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
@@ -131,6 +131,21 @@ func TrustedOrigins(origins []string) Option {
}
}

// TrustedOriginPredicateFunc configures a predicate function that can be used
// to determine if a given Referer is trusted.
// Like TrustedOrigins, this will allow cross-domain CSRF use-cases - e.g. where
// the front-end is served from a different domain than the API server - to
// correctly pass a CSRF check.
// However, this function allows for more complex logic to be applied to determine
// if a Referer is trusted than strict equality string matching.
//
// You should only pass origins you own or have full control over.
func TrustedOriginPredicateFunc(f func(referer string) bool) Option {
return func(cs *csrf) {
cs.opts.TrustedOriginPredicateFunc = f
}
}

// setStore sets the store used by the CSRF middleware.
// Note: this is private (for now) to allow for internal API changes.
func setStore(s store) Option {