Skip to content

Commit 22f3138

Browse files
Skip basic auth when auth token is set (#511)
* Skip basic auth when auth token is set In addition, I also rename basicAuth to authBasic for naming consistency with authToken. * Rename test
1 parent 966a34d commit 22f3138

File tree

4 files changed

+106
-8
lines changed

4 files changed

+106
-8
lines changed
+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package middleware
2+
3+
import (
4+
"net/http"
5+
"strings"
6+
7+
"github.com/go-chi/chi/v5/middleware"
8+
)
9+
10+
func BasicAuth(realm string, creds map[string]string) func(next http.Handler) http.Handler {
11+
return func(next http.Handler) http.Handler {
12+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
13+
authHeader := strings.Split(r.Header.Get("Authorization"), " ")
14+
if skipBasicAuth(authHeader) {
15+
next.ServeHTTP(w, r)
16+
return
17+
}
18+
middleware.BasicAuth("restricted", creds)(next).ServeHTTP(w, r)
19+
})
20+
}
21+
}
22+
23+
// skipBasicAuth skips basic auth middleware when the auth token is set
24+
func skipBasicAuth(authHeader []string) bool {
25+
return authToken != nil &&
26+
len(authHeader) >= 2 &&
27+
authHeader[0] == "Bearer"
28+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
package middleware
2+
3+
import (
4+
"net/http"
5+
"net/http/httptest"
6+
"testing"
7+
8+
"github.com/stretchr/testify/require"
9+
)
10+
11+
func TestBasicAuth(t *testing.T) {
12+
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
13+
w.WriteHeader(http.StatusOK)
14+
w.Write([]byte("OK"))
15+
})
16+
fakeAuthHeader := "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6Ikpva"
17+
fakeAuthToken := AuthToken{
18+
Token: "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6Ikpva",
19+
}
20+
testCase := []struct {
21+
name string
22+
authHeader string
23+
authToken *AuthToken
24+
httpStatus int
25+
}{
26+
{
27+
name: "auth header set, auth token set",
28+
authHeader: fakeAuthHeader,
29+
authToken: &fakeAuthToken,
30+
httpStatus: http.StatusOK,
31+
},
32+
{
33+
name: "auth header set, auth token unset",
34+
authHeader: fakeAuthHeader,
35+
authToken: nil,
36+
httpStatus: http.StatusUnauthorized,
37+
},
38+
{
39+
name: "auth header unset, auth token set",
40+
authHeader: "",
41+
authToken: &fakeAuthToken,
42+
httpStatus: http.StatusUnauthorized,
43+
},
44+
{
45+
name: "auth header unset, auth token unset",
46+
authHeader: "",
47+
authToken: nil,
48+
httpStatus: http.StatusUnauthorized,
49+
},
50+
}
51+
// incorrectCreds triggers HTTP 401 Unauthorized upon basic auth
52+
incorrectCreds := map[string]string{
53+
"INCORRECT_USERNAME": "INCORRECT_PASSWORD",
54+
}
55+
for _, tc := range testCase {
56+
t.Run(tc.name, func(t *testing.T) {
57+
r, err := http.NewRequest("GET", "/test", nil)
58+
require.NoError(t, err)
59+
r.Header.Add("Authorization", tc.authHeader)
60+
w := httptest.NewRecorder()
61+
authToken = tc.authToken
62+
BasicAuth(
63+
"restricted",
64+
incorrectCreds,
65+
)(testHandler).ServeHTTP(w, r)
66+
require.Equal(t, tc.httpStatus, w.Result().StatusCode)
67+
})
68+
}
69+
}

service/frontend/middleware/global.go

+8-7
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@ func SetupGlobalMiddleware(handler http.Handler) http.Handler {
1717
next = TokenAuth("restricted", authToken.Token)(next)
1818
}
1919

20-
if basicAuth != nil {
21-
next = middleware.BasicAuth(
22-
"restricted", map[string]string{basicAuth.Username: basicAuth.Password},
20+
if authBasic != nil {
21+
next = BasicAuth(
22+
"restricted",
23+
map[string]string{authBasic.Username: authBasic.Password},
2324
)(next)
2425
}
2526
next = prefixChecker(next)
@@ -29,17 +30,17 @@ func SetupGlobalMiddleware(handler http.Handler) http.Handler {
2930

3031
var (
3132
defaultHandler http.Handler
32-
basicAuth *BasicAuth
33+
authBasic *AuthBasic
3334
authToken *AuthToken
3435
)
3536

3637
type Options struct {
3738
Handler http.Handler
38-
BasicAuth *BasicAuth
39+
AuthBasic *AuthBasic
3940
AuthToken *AuthToken
4041
}
4142

42-
type BasicAuth struct {
43+
type AuthBasic struct {
4344
Username string
4445
Password string
4546
}
@@ -50,7 +51,7 @@ type AuthToken struct {
5051

5152
func Setup(opts *Options) {
5253
defaultHandler = opts.Handler
53-
basicAuth = opts.BasicAuth
54+
authBasic = opts.AuthBasic
5455
authToken = opts.AuthToken
5556
}
5657

service/frontend/server/server.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ func (svr *Server) Serve(ctx context.Context) (err error) {
9191
}
9292
}
9393
if svr.basicAuth != nil {
94-
middlewareOptions.BasicAuth = &pkgmiddleware.BasicAuth{
94+
middlewareOptions.AuthBasic = &pkgmiddleware.AuthBasic{
9595
Username: svr.basicAuth.Username,
9696
Password: svr.basicAuth.Password,
9797
}

0 commit comments

Comments
 (0)