Skip to content

Commit 0b2a15d

Browse files
authored
Merge pull request #1 from ngrok-oss/add-trusted-origin-predicate-func
adds TrustedOriginPredicateFunc option, tests
2 parents a009743 + 9baf069 commit 0b2a15d

File tree

4 files changed

+110
-15
lines changed

4 files changed

+110
-15
lines changed

csrf.go

+17-12
Original file line numberDiff line numberDiff line change
@@ -85,14 +85,15 @@ type options struct {
8585
Path string
8686
// Note that the function and field names match the case of the associated
8787
// http.Cookie field instead of the "correct" HTTPOnly name that golint suggests.
88-
HttpOnly bool
89-
Secure bool
90-
SameSite SameSiteMode
91-
RequestHeader string
92-
FieldName string
93-
ErrorHandler http.Handler
94-
CookieName string
95-
TrustedOrigins []string
88+
HttpOnly bool
89+
Secure bool
90+
SameSite SameSiteMode
91+
RequestHeader string
92+
FieldName string
93+
ErrorHandler http.Handler
94+
CookieName string
95+
TrustedOrigins []string
96+
TrustedOriginPredicateFunc func(referer string) bool
9697
}
9798

9899
// Protect is HTTP middleware that provides Cross-Site Request Forgery
@@ -258,10 +259,14 @@ func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) {
258259
valid := sameOrigin(r.URL, referer)
259260

260261
if !valid {
261-
for _, trustedOrigin := range cs.opts.TrustedOrigins {
262-
if referer.Host == trustedOrigin {
263-
valid = true
264-
break
262+
if cs.opts.TrustedOriginPredicateFunc != nil {
263+
valid = cs.opts.TrustedOriginPredicateFunc(referer.Host)
264+
} else {
265+
for _, trustedOrigin := range cs.opts.TrustedOrigins {
266+
if referer.Host == trustedOrigin {
267+
valid = true
268+
break
269+
}
265270
}
266271
}
267272
}

csrf_test.go

+76
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,82 @@ func TestTrustedReferer(t *testing.T) {
336336
}
337337
}
338338

339+
// TestTrustedOriginPredicateFunc checks that HTTPS requests with a Referer that does not
340+
// match the request URL correctly but is a trusted origin pass CSRF validation.
341+
func TestTrustedOriginPredicateFunc(t *testing.T) {
342+
343+
testTable := []struct {
344+
predicate func(referer string) bool
345+
shouldPass bool
346+
}{
347+
{func(referer string) bool {
348+
return referer == "golang.org"
349+
}, true},
350+
{func(referer string) bool {
351+
return referer == "api.example.com" || referer == "golang.org"
352+
}, true},
353+
{func(referer string) bool {
354+
return referer == "http://golang.org"
355+
}, false},
356+
{func(referer string) bool {
357+
return referer == "https://golang.org"
358+
}, false},
359+
{func(referer string) bool {
360+
return referer == "http://example.com"
361+
}, false},
362+
{func(referer string) bool {
363+
return referer == "example.com"
364+
}, false},
365+
}
366+
367+
for _, item := range testTable {
368+
s := http.NewServeMux()
369+
370+
p := Protect(testKey, TrustedOriginPredicateFunc(item.predicate))(s)
371+
372+
var token string
373+
s.Handle("/", http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
374+
token = Token(r)
375+
}))
376+
377+
// Obtain a CSRF cookie via a GET request.
378+
r, err := http.NewRequest("GET", "https://www.gorillatoolkit.org/", nil)
379+
if err != nil {
380+
t.Fatal(err)
381+
}
382+
383+
rr := httptest.NewRecorder()
384+
p.ServeHTTP(rr, r)
385+
386+
// POST the token back in the header.
387+
r, err = http.NewRequest("POST", "https://www.gorillatoolkit.org/", nil)
388+
if err != nil {
389+
t.Fatal(err)
390+
}
391+
392+
setCookie(rr, r)
393+
r.Header.Set("X-CSRF-Token", token)
394+
395+
// Set a non-matching Referer header.
396+
r.Header.Set("Referer", "http://golang.org/")
397+
398+
rr = httptest.NewRecorder()
399+
p.ServeHTTP(rr, r)
400+
401+
if item.shouldPass {
402+
if rr.Code != http.StatusOK {
403+
t.Fatalf("middleware failed to pass to the next handler: got %v want %v",
404+
rr.Code, http.StatusOK)
405+
}
406+
} else {
407+
if rr.Code != http.StatusForbidden {
408+
t.Fatalf("middleware failed reject a non-matching Referer header: got %v want %v",
409+
rr.Code, http.StatusForbidden)
410+
}
411+
}
412+
}
413+
}
414+
339415
// Requests with a valid Referer should pass.
340416
func TestWithReferer(t *testing.T) {
341417
s := http.NewServeMux()

doc.go

+2-3
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ field.
1717
gorilla/csrf is easy to use: add the middleware to individual handlers with
1818
the below:
1919
20-
CSRF := csrf.Protect([]byte("32-byte-long-auth-key"))
21-
http.HandlerFunc("/route", CSRF(YourHandler))
20+
CSRF := csrf.Protect([]byte("32-byte-long-auth-key"))
21+
http.HandlerFunc("/route", CSRF(YourHandler))
2222
2323
... and then collect the token with `csrf.Token(r)` before passing it to the
2424
template, JSON body or HTTP header (you pick!). gorilla/csrf inspects the form body
@@ -171,6 +171,5 @@ important.
171171
and the one-time-pad used for masking them.
172172
173173
This library does not seek to be adventurous.
174-
175174
*/
176175
package csrf

options.go

+15
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,21 @@ func TrustedOrigins(origins []string) Option {
131131
}
132132
}
133133

134+
// TrustedOriginPredicateFunc configures a predicate function that can be used
135+
// to determine if a given Referer is trusted.
136+
// Like TrustedOrigins, this will allow cross-domain CSRF use-cases - e.g. where
137+
// the front-end is served from a different domain than the API server - to
138+
// correctly pass a CSRF check.
139+
// However, this function allows for more complex logic to be applied to determine
140+
// if a Referer is trusted than strict equality string matching.
141+
//
142+
// You should only pass origins you own or have full control over.
143+
func TrustedOriginPredicateFunc(f func(referer string) bool) Option {
144+
return func(cs *csrf) {
145+
cs.opts.TrustedOriginPredicateFunc = f
146+
}
147+
}
148+
134149
// setStore sets the store used by the CSRF middleware.
135150
// Note: this is private (for now) to allow for internal API changes.
136151
func setStore(s store) Option {

0 commit comments

Comments
 (0)