diff --git a/pkg/security/oidcauth/BUILD.bazel b/pkg/security/oidcauth/BUILD.bazel index f97c8af145a6..2ed7c9c6afee 100644 --- a/pkg/security/oidcauth/BUILD.bazel +++ b/pkg/security/oidcauth/BUILD.bazel @@ -69,6 +69,7 @@ go_test( "//pkg/util/leaktest", "//pkg/util/log", "//pkg/util/randutil", + "//pkg/util/syncutil", "@com_github_cockroachdb_errors//:errors", "@com_github_coreos_go_oidc_v3//oidc", "@com_github_lestrrat_go_jwx_v2//jwa", diff --git a/pkg/security/oidcauth/authorization_oidc_test.go b/pkg/security/oidcauth/authorization_oidc_test.go index 86939cd8b729..74a4b76f6eab 100644 --- a/pkg/security/oidcauth/authorization_oidc_test.go +++ b/pkg/security/oidcauth/authorization_oidc_test.go @@ -25,6 +25,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" "github.com/cockroachdb/cockroach/pkg/util/leaktest" "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/syncutil" "github.com/cockroachdb/errors" "github.com/coreos/go-oidc/v3/oidc" "github.com/lestrrat-go/jwx/v2/jwa" @@ -369,8 +370,14 @@ func TestOIDCAuthorization_TokenPaths(t *testing.T) { // TestOIDCAuthorization_UserinfoPaths exercises the fallback that fetches the // groups list from the provider's /userinfo endpoint when the ID token and -// access token contain no usable claim. Initialization of the test server is -// done only once to avoid flakes related to test server reinitialization. +// access token contain no usable claim. +// +// The test uses a single shared mock OIDC provider for the 5 cases that share +// the same discovery document (all with a userinfo_endpoint). Between cases, +// only the /userinfo response body changes — no cluster settings are modified, +// so no OIDC manager reinitialization occurs. The 2 cases that need different +// discovery documents (absent endpoint, network error) run separately with +// their own mock servers. func TestOIDCAuthorization_UserinfoPaths(t *testing.T) { defer leaktest.AfterTest(t)() defer log.Scope(t).Close(t) @@ -401,202 +408,282 @@ func TestOIDCAuthorization_UserinfoPaths(t *testing.T) { _ = jwkKey.Set(jwk.KeyIDKey, "test-key-id") _ = jwkKey.Set(jwk.AlgorithmKey, jwa.RS256) - // The public key is what the verifier will use. publicKey, err := jwk.PublicKeyOf(jwkKey) require.NoError(t, err) jwks := jwk.NewSet() _ = jwks.AddKey(publicKey) - // Set the common oidc cluster settings + // Mutable state for /userinfo responses, protected by a mutex. + var mu syncutil.Mutex + var userinfoStatus int + var userinfoBody string + + // Shared mock OIDC provider for cases with a userinfo endpoint. + var ts *httptest.Server + ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/.well-known/openid-configuration": + doc := fmt.Sprintf(`{ + "issuer": "%s", + "token_endpoint": "%s/token", + "userinfo_endpoint": "%s/userinfo", + "jwks_uri": "%s/.well-known/jwks.json" + }`, ts.URL, ts.URL, ts.URL, ts.URL) + _, _ = io.WriteString(w, doc) + case "/.well-known/jwks.json": + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(jwks) + case "/userinfo": + mu.Lock() + s, b := userinfoStatus, userinfoBody + mu.Unlock() + w.WriteHeader(s) + _, _ = io.WriteString(w, b) + case "/token": + idTok := makeJWT(t, map[string]any{ + "iss": ts.URL, + "email": testUser + "@example.com", + "aud": "client", + "exp": time.Now().Add(time.Hour).Unix(), + }, jwkKey) + w.Header().Set("Content-Type", "application/json") + fmt.Fprintf(w, `{ + "access_token":"dummy-access", + "id_token":"%s", + "token_type":"Bearer", + "expires_in":3600 + }`, idTok) + default: + http.NotFound(w, r) + } + })) + defer ts.Close() + + // Set all cluster settings once. Provider URL and redirect are set first + // (while OIDC is disabled, so reloadConfig is a no-op for manager + // creation). OIDCEnabled is set last, triggering exactly one successful + // provider discovery call. st := s.ClusterSettings() + OIDCProviderURL.Override(ctx, &st.SV, ts.URL) + OIDCRedirectURL.Override(ctx, &st.SV, ts.URL+"/callback") OIDCClientID.Override(ctx, &st.SV, "client") OIDCClientSecret.Override(ctx, &st.SV, "secret") OIDCClaimJSONKey.Override(ctx, &st.SV, "email") OIDCPrincipalRegex.Override(ctx, &st.SV, "^([^@]+)@.*$") - OIDCEnabled.Override(ctx, &st.SV, true) OIDCAuthZEnabled.Override(ctx, &st.SV, true) OIDCAuthGroupClaim.Override(ctx, &st.SV, "groups") OIDCAuthUserinfoGroupKey.Override(ctx, &st.SV, "groups") + OIDCEnabled.Override(ctx, &st.SV, true) + + revokeStaleRoles := func(t *testing.T) { + staleRows := sqlDB.QueryStr(t, + `SELECT role FROM system.role_members WHERE member = $1`, testUser) + if len(staleRows) > 0 { + staleRoles := make([]string, len(staleRows)) + for i, r := range staleRows { + staleRoles[i] = r[0] + } + sqlDB.Exec(t, fmt.Sprintf( + `REVOKE %s FROM %s`, strings.Join(staleRoles, ", "), testUser)) + } + } + + doLoginCallback := func(t *testing.T) *http.Response { + resp, err := cl.Get(app.AdminURL().WithPath("/oidc/v1/login").String()) + require.NoError(t, err) + // Find by name; in shared-process multi-tenant mode the tenant + // cookie precedes oidc_secret. + var cookie *http.Cookie + for _, c := range resp.Cookies() { + if c.Name == secretCookieName { + cookie = c + break + } + } + require.NotNil(t, cookie, "expected oidc_secret cookie") + loc, _ := url.Parse(resp.Header.Get("Location")) + state := loc.Query().Get("state") + + cb, _ := http.NewRequest("GET", + app.AdminURL().WithPath("/oidc/v1/callback").String(), nil) + cb.AddCookie(cookie) + q := cb.URL.Query() + q.Set("state", state) + q.Set("code", "dummy") + cb.URL.RawQuery = q.Encode() + cbResp, err := cl.Do(cb) + require.NoError(t, err) + return cbResp + } + + checkRoles := func(t *testing.T, wantRoles []string) { + rows := sqlDB.QueryStr(t, fmt.Sprintf( + `SELECT role FROM system.role_members WHERE member = '%s' ORDER BY role`, + testUser)) + var got []string + for _, r := range rows { + got = append(got, r[0]) + } + require.ElementsMatch(t, wantRoles, got) + } + // Cases that share the mock server — only the /userinfo response varies. type tc struct { name string - discoveryDoc string userinfoStatus int userinfoBody string wantRoles []string // nil => expect 403 wantErr bool } - cases := []tc{ + sharedCases := []tc{ { - name: "userinfo success", - discoveryDoc: `{ - "issuer": "{{url}}", - "token_endpoint": "{{url}}/token", - "userinfo_endpoint": "{{url}}/userinfo", - "jwks_uri": "{{url}}/.well-known/jwks.json" - }`, + name: "userinfo success", userinfoStatus: http.StatusOK, userinfoBody: fmt.Sprintf(`{"groups":["%s","%s"]}`, roleOwners, roleUsers), wantRoles: []string{roleOwners, roleUsers}, }, { - name: "userinfo endpoint absent", - discoveryDoc: `{ - "issuer": "{{url}}", - "token_endpoint": "{{url}}/token", - "jwks_uri": "{{url}}/.well-known/jwks.json" - }`, - // no userinfo: should end in forbidden - wantRoles: nil, - wantErr: true, - }, - { - name: "userinfo network error", - discoveryDoc: `{ - "issuer": "{{url}}", - "token_endpoint": "{{url}}/token", - "userinfo_endpoint": "http://127.0.0.1:0/userinfo", - "jwks_uri": "{{url}}/.well-known/jwks.json" - }`, - wantRoles: nil, - wantErr: true, - }, - { - name: "userinfo empty groups list", - discoveryDoc: `{ - "issuer": "{{url}}", - "token_endpoint": "{{url}}/token", - "userinfo_endpoint": "{{url}}/userinfo", - "jwks_uri": "{{url}}/.well-known/jwks.json" - }`, + name: "userinfo empty groups list", userinfoStatus: http.StatusOK, userinfoBody: `{"groups":[]}`, - wantRoles: nil, // Roles should be revoked, login denied. wantErr: true, }, { - name: "userinfo non-standard body", - discoveryDoc: `{ - "issuer": "{{url}}", - "token_endpoint": "{{url}}/token", - "userinfo_endpoint": "{{url}}/userinfo", - "jwks_uri": "{{url}}/.well-known/jwks.json" - }`, + name: "userinfo non-standard body", userinfoStatus: http.StatusOK, userinfoBody: `this is not json`, - wantRoles: nil, wantErr: true, }, { - name: "userinfo missing groups claim", - discoveryDoc: `{ - "issuer": "{{url}}", - "token_endpoint": "{{url}}/token", - "userinfo_endpoint": "{{url}}/userinfo", - "jwks_uri": "{{url}}/.well-known/jwks.json" - }`, + name: "userinfo missing groups claim", userinfoStatus: http.StatusOK, userinfoBody: `{"email":"test@example.com"}`, - wantRoles: nil, wantErr: true, }, { - name: "userinfo invalid groups claim", - discoveryDoc: `{ - "issuer": "{{url}}", - "token_endpoint": "{{url}}/token", - "userinfo_endpoint": "{{url}}/userinfo", - "jwks_uri": "{{url}}/.well-known/jwks.json" - }`, + name: "userinfo invalid groups claim", userinfoStatus: http.StatusOK, userinfoBody: `{"groups":not-a-list}`, - wantRoles: nil, wantErr: true, }, } - for _, tc := range cases { + for _, tc := range sharedCases { tc := tc t.Run(tc.name, func(t *testing.T) { + revokeStaleRoles(t) + + mu.Lock() + userinfoStatus = tc.userinfoStatus + userinfoBody = tc.userinfoBody + mu.Unlock() + + cbResp := doLoginCallback(t) - var ts *httptest.Server - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wantStatus := http.StatusTemporaryRedirect + if tc.wantErr { + wantStatus = http.StatusForbidden + } + require.Equal(t, wantStatus, cbResp.StatusCode) + checkRoles(t, tc.wantRoles) + }) + } + + // Special case: discovery document has no userinfo_endpoint. The provider + // cannot fall back to userinfo, so the login must be denied. + t.Run("userinfo endpoint absent", func(t *testing.T) { + revokeStaleRoles(t) + + var absentTS *httptest.Server + absentTS = httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/.well-known/openid-configuration": - doc := strings.ReplaceAll(tc.discoveryDoc, "{{url}}", ts.URL) + doc := fmt.Sprintf(`{ + "issuer": "%s", + "token_endpoint": "%s/token", + "jwks_uri": "%s/.well-known/jwks.json" + }`, absentTS.URL, absentTS.URL, absentTS.URL) _, _ = io.WriteString(w, doc) case "/.well-known/jwks.json": w.Header().Set("Content-Type", "application/json") - err := json.NewEncoder(w).Encode(jwks) - require.NoError(t, err) - case "/userinfo": - w.WriteHeader(tc.userinfoStatus) - _, _ = io.WriteString(w, tc.userinfoBody) + _ = json.NewEncoder(w).Encode(jwks) case "/token": idTok := makeJWT(t, map[string]any{ - "iss": ts.URL, // issuer must match provider URL + "iss": absentTS.URL, "email": testUser + "@example.com", "aud": "client", "exp": time.Now().Add(time.Hour).Unix(), }, jwkKey) - resp := `{ + w.Header().Set("Content-Type", "application/json") + fmt.Fprintf(w, `{ "access_token":"dummy-access", - "id_token":"` + idTok + `", + "id_token":"%s", "token_type":"Bearer", "expires_in":3600 - }` - w.Header().Set("Content-Type", "application/json") - _, _ = io.WriteString(w, resp) + }`, idTok) default: http.NotFound(w, r) } - }) - ts = httptest.NewServer(handler) - defer ts.Close() + })) + defer absentTS.Close() - OIDCProviderURL.Override(ctx, &st.SV, ts.URL) - OIDCRedirectURL.Override(ctx, &st.SV, ts.URL+"/callback") + OIDCProviderURL.Override(ctx, &st.SV, absentTS.URL) + OIDCRedirectURL.Override(ctx, &st.SV, absentTS.URL+"/callback") - resp, err := cl.Get(app.AdminURL().WithPath("/oidc/v1/login").String()) - require.NoError(t, err) - // Find by name; in shared-process multi-tenant mode the tenant cookie precedes oidc_secret. - var cookie *http.Cookie - for _, c := range resp.Cookies() { - if c.Name == secretCookieName { - cookie = c - break - } - } - require.NotNil(t, cookie, "expected oidc_secret cookie") - loc, _ := url.Parse(resp.Header.Get("Location")) - state := loc.Query().Get("state") + cbResp := doLoginCallback(t) + require.Equal(t, http.StatusForbidden, cbResp.StatusCode) + checkRoles(t, nil) + }) - cb, _ := http.NewRequest("GET", app.AdminURL().WithPath("/oidc/v1/callback").String(), nil) - cb.AddCookie(cookie) - q := cb.URL.Query() - q.Set("state", state) - q.Set("code", "dummy") - cb.URL.RawQuery = q.Encode() - cbResp, err := cl.Do(cb) - require.NoError(t, err) + // Special case: discovery document points userinfo_endpoint to an + // unreachable address. The network error must be handled gracefully. + t.Run("userinfo network error", func(t *testing.T) { + revokeStaleRoles(t) - wantStatus := http.StatusTemporaryRedirect - if tc.wantErr { - // The two error cases should now result in a Forbidden, because the - // userinfo fallback will fail gracefully but no groups will be found. - wantStatus = http.StatusForbidden - } - require.Equal(t, wantStatus, cbResp.StatusCode) + var netErrTS *httptest.Server + netErrTS = httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/.well-known/openid-configuration": + doc := fmt.Sprintf(`{ + "issuer": "%s", + "token_endpoint": "%s/token", + "userinfo_endpoint": "http://127.0.0.1:0/userinfo", + "jwks_uri": "%s/.well-known/jwks.json" + }`, netErrTS.URL, netErrTS.URL, netErrTS.URL) + _, _ = io.WriteString(w, doc) + case "/.well-known/jwks.json": + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(jwks) + case "/token": + idTok := makeJWT(t, map[string]any{ + "iss": netErrTS.URL, + "email": testUser + "@example.com", + "aud": "client", + "exp": time.Now().Add(time.Hour).Unix(), + }, jwkKey) + w.Header().Set("Content-Type", "application/json") + fmt.Fprintf(w, `{ + "access_token":"dummy-access", + "id_token":"%s", + "token_type":"Bearer", + "expires_in":3600 + }`, idTok) + default: + http.NotFound(w, r) + } + })) + defer netErrTS.Close() - rows := sqlDB.QueryStr(t, fmt.Sprintf(`SELECT role FROM system.role_members WHERE member = '%s' ORDER BY role`, testUser)) - var got []string - for _, r := range rows { - got = append(got, r[0]) - } - require.ElementsMatch(t, tc.wantRoles, got) - }) - } + OIDCProviderURL.Override(ctx, &st.SV, netErrTS.URL) + OIDCRedirectURL.Override(ctx, &st.SV, netErrTS.URL+"/callback") + + cbResp := doLoginCallback(t) + require.Equal(t, http.StatusForbidden, cbResp.StatusCode) + checkRoles(t, nil) + }) } // TestOIDCAuthorization_RoleGrantAndRevoke tests that roles are granted and revoked as expected