Skip to content

Commit 2526406

Browse files
Get IDP endpoints from well-known config (#470)
* Get IDP endpoints from well-known config * Refactor parsing of the well known config * Fix lint * Add unit test * Fix resource-manager custom endpoint
1 parent 394f4dd commit 2526406

File tree

11 files changed

+236
-64
lines changed

11 files changed

+236
-64
lines changed

internal/cmd/config/set/set.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@ import (
1717
)
1818

1919
const (
20-
sessionTimeLimitFlag = "session-time-limit"
21-
identityProviderCustomEndpointFlag = "identity-provider-custom-endpoint"
22-
identityProviderCustomClientIdFlag = "identity-provider-custom-client-id"
23-
allowedUrlDomainFlag = "allowed-url-domain"
20+
sessionTimeLimitFlag = "session-time-limit"
21+
identityProviderCustomWellKnownConfigurationFlag = "identity-provider-custom-well-known-configuration"
22+
identityProviderCustomClientIdFlag = "identity-provider-custom-client-id"
23+
allowedUrlDomainFlag = "allowed-url-domain"
2424

2525
authorizationCustomEndpointFlag = "authorization-custom-endpoint"
2626
dnsCustomEndpointFlag = "dns-custom-endpoint"
@@ -131,7 +131,7 @@ Use "{{.CommandPath}} [command] --help" for more information about a command.{{e
131131

132132
func configureFlags(cmd *cobra.Command) {
133133
cmd.Flags().String(sessionTimeLimitFlag, "", "Maximum time before authentication is required again. After this time, you will be prompted to login again to execute commands that require authentication. Can't be larger than 24h. Requires authentication after being set to take effect. Examples: 3h, 5h30m40s (BETA: currently values greater than 2h have no effect)")
134-
cmd.Flags().String(identityProviderCustomEndpointFlag, "", "Identity Provider base URL, used for user authentication")
134+
cmd.Flags().String(identityProviderCustomWellKnownConfigurationFlag, "", "Identity Provider well-known OpenID configuration URL, used for user authentication")
135135
cmd.Flags().String(identityProviderCustomClientIdFlag, "", "Identity Provider client ID, used for user authentication")
136136
cmd.Flags().String(allowedUrlDomainFlag, "", `Domain name, used for the verification of the URLs that are given in the custom identity provider endpoint and "STACKIT curl" command`)
137137
cmd.Flags().String(observabilityCustomEndpointFlag, "", "Observability API base URL, used in calls to this API")
@@ -159,7 +159,7 @@ func configureFlags(cmd *cobra.Command) {
159159

160160
err := viper.BindPFlag(config.SessionTimeLimitKey, cmd.Flags().Lookup(sessionTimeLimitFlag))
161161
cobra.CheckErr(err)
162-
err = viper.BindPFlag(config.IdentityProviderCustomEndpointKey, cmd.Flags().Lookup(identityProviderCustomEndpointFlag))
162+
err = viper.BindPFlag(config.IdentityProviderCustomWellKnownConfigurationKey, cmd.Flags().Lookup(identityProviderCustomWellKnownConfigurationFlag))
163163
cobra.CheckErr(err)
164164
err = viper.BindPFlag(config.IdentityProviderCustomClientIdKey, cmd.Flags().Lookup(identityProviderCustomClientIdFlag))
165165
cobra.CheckErr(err)
@@ -190,7 +190,7 @@ func configureFlags(cmd *cobra.Command) {
190190
cobra.CheckErr(err)
191191
err = viper.BindPFlag(config.RedisCustomEndpointKey, cmd.Flags().Lookup(redisCustomEndpointFlag))
192192
cobra.CheckErr(err)
193-
err = viper.BindPFlag(config.ResourceManagerEndpointKey, cmd.Flags().Lookup(skeCustomEndpointFlag))
193+
err = viper.BindPFlag(config.ResourceManagerEndpointKey, cmd.Flags().Lookup(resourceManagerCustomEndpointFlag))
194194
cobra.CheckErr(err)
195195
err = viper.BindPFlag(config.SecretsManagerCustomEndpointKey, cmd.Flags().Lookup(secretsManagerCustomEndpointFlag))
196196
cobra.CheckErr(err)

internal/cmd/config/unset/unset.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@ const (
2020
projectIdFlag = globalflags.ProjectIdFlag
2121
verbosityFlag = globalflags.VerbosityFlag
2222

23-
sessionTimeLimitFlag = "session-time-limit"
24-
identityProviderCustomEndpointFlag = "identity-provider-custom-endpoint"
25-
identityProviderCustomClientIdFlag = "identity-provider-custom-client-id"
26-
allowedUrlDomainFlag = "allowed-url-domain"
23+
sessionTimeLimitFlag = "session-time-limit"
24+
identityProviderCustomWellKnownConfigurationFlag = "identity-provider-custom-well-known-configuration"
25+
identityProviderCustomClientIdFlag = "identity-provider-custom-client-id"
26+
allowedUrlDomainFlag = "allowed-url-domain"
2727

2828
authorizationCustomEndpointFlag = "authorization-custom-endpoint"
2929
dnsCustomEndpointFlag = "dns-custom-endpoint"
@@ -121,7 +121,7 @@ func NewCmd(p *print.Printer) *cobra.Command {
121121
viper.Set(config.SessionTimeLimitKey, config.SessionTimeLimitDefault)
122122
}
123123
if model.IdentityProviderCustomEndpoint {
124-
viper.Set(config.IdentityProviderCustomEndpointKey, "")
124+
viper.Set(config.IdentityProviderCustomWellKnownConfigurationKey, "")
125125
}
126126
if model.IdentityProviderCustomClientID {
127127
viper.Set(config.IdentityProviderCustomClientIdKey, "")
@@ -215,7 +215,7 @@ func configureFlags(cmd *cobra.Command) {
215215
cmd.Flags().Bool(verbosityFlag, false, "Verbosity of the CLI")
216216

217217
cmd.Flags().Bool(sessionTimeLimitFlag, false, fmt.Sprintf("Maximum time before authentication is required again. If unset, defaults to %s", config.SessionTimeLimitDefault))
218-
cmd.Flags().Bool(identityProviderCustomEndpointFlag, false, "Identity Provider base URL. If unset, uses the default base URL")
218+
cmd.Flags().Bool(identityProviderCustomWellKnownConfigurationFlag, false, "Identity Provider well-known OpenID configuration URL. If unset, uses the default identity provider")
219219
cmd.Flags().Bool(identityProviderCustomClientIdFlag, false, "Identity Provider client ID, used for user authentication")
220220
cmd.Flags().Bool(allowedUrlDomainFlag, false, fmt.Sprintf("Domain name, used for the verification of the URLs that are given in the IDP endpoint and curl commands. If unset, defaults to %s", config.AllowedUrlDomainDefault))
221221

@@ -251,7 +251,7 @@ func parseInput(p *print.Printer, cmd *cobra.Command) *inputModel {
251251
Verbosity: flags.FlagToBoolValue(p, cmd, verbosityFlag),
252252

253253
SessionTimeLimit: flags.FlagToBoolValue(p, cmd, sessionTimeLimitFlag),
254-
IdentityProviderCustomEndpoint: flags.FlagToBoolValue(p, cmd, identityProviderCustomEndpointFlag),
254+
IdentityProviderCustomEndpoint: flags.FlagToBoolValue(p, cmd, identityProviderCustomWellKnownConfigurationFlag),
255255
IdentityProviderCustomClientID: flags.FlagToBoolValue(p, cmd, identityProviderCustomClientIdFlag),
256256
AllowedUrlDomain: flags.FlagToBoolValue(p, cmd, allowedUrlDomainFlag),
257257

internal/cmd/config/unset/unset_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ func fixtureFlagValues(mods ...func(flagValues map[string]bool)) map[string]bool
1616
projectIdFlag: true,
1717
verbosityFlag: true,
1818

19-
sessionTimeLimitFlag: true,
20-
identityProviderCustomEndpointFlag: true,
21-
identityProviderCustomClientIdFlag: true,
22-
allowedUrlDomainFlag: true,
19+
sessionTimeLimitFlag: true,
20+
identityProviderCustomWellKnownConfigurationFlag: true,
21+
identityProviderCustomClientIdFlag: true,
22+
allowedUrlDomainFlag: true,
2323

2424
authorizationCustomEndpointFlag: true,
2525
dnsCustomEndpointFlag: true,
@@ -157,7 +157,7 @@ func TestParseInput(t *testing.T) {
157157
{
158158
description: "identity provider custom endpoint empty",
159159
flagValues: fixtureFlagValues(func(flagValues map[string]bool) {
160-
flagValues[identityProviderCustomEndpointFlag] = false
160+
flagValues[identityProviderCustomWellKnownConfigurationFlag] = false
161161
}),
162162
isValid: true,
163163
expectedModel: fixtureInputModel(func(model *inputModel) {

internal/pkg/auth/storage.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ const (
3737
SERVICE_ACCOUNT_KEY authFieldKey = "service_account_key"
3838
PRIVATE_KEY authFieldKey = "private_key"
3939
TOKEN_CUSTOM_ENDPOINT authFieldKey = "token_custom_endpoint"
40+
IDP_TOKEN_ENDPOINT authFieldKey = "idp_token_endpoint" //nolint:gosec // linter false positive
4041
)
4142

4243
const (
@@ -57,6 +58,7 @@ var authFieldKeys = []authFieldKey{
5758
SERVICE_ACCOUNT_KEY,
5859
PRIVATE_KEY,
5960
TOKEN_CUSTOM_ENDPOINT,
61+
IDP_TOKEN_ENDPOINT,
6062
authFlowType,
6163
}
6264

internal/pkg/auth/user_login.go

Lines changed: 68 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ import (
2323
)
2424

2525
const (
26-
defaultIDPEndpoint = "https://accounts.stackit.cloud/oauth/v2"
27-
defaultCLIClientID = "stackit-cli-0000-0000-000000000001"
26+
defaultWellKnownConfig = "https://accounts.stackit.cloud/.well-known/openid-configuration"
27+
defaultCLIClientID = "stackit-cli-0000-0000-000000000001"
2828

2929
loginSuccessPath = "/login-successful"
3030
stackitLandingPage = "https://www.stackit.de"
@@ -44,20 +44,31 @@ type User struct {
4444
Email string
4545
}
4646

47+
type apiClient interface {
48+
Do(req *http.Request) (*http.Response, error)
49+
}
50+
4751
// AuthorizeUser implements the PKCE OAuth2 flow.
4852
func AuthorizeUser(p *print.Printer, isReauthentication bool) error {
49-
idpEndpoint, err := getIDPEndpoint()
53+
idpWellKnownConfigURL, err := getIDPWellKnownConfigURL()
5054
if err != nil {
51-
return err
55+
return fmt.Errorf("get IDP well-known configuration: %w", err)
5256
}
53-
if idpEndpoint != defaultIDPEndpoint {
54-
p.Warn("You are using a custom identity provider (%s) for authentication.\n", idpEndpoint)
57+
if idpWellKnownConfigURL != defaultWellKnownConfig {
58+
p.Warn("You are using a custom identity provider well-known configuration (%s) for authentication.\n", idpWellKnownConfigURL)
5559
err := p.PromptForEnter("Press Enter to proceed with the login...")
5660
if err != nil {
5761
return err
5862
}
5963
}
6064

65+
p.Debug(print.DebugLevel, "get IDP well-known configuration from %s", idpWellKnownConfigURL)
66+
httpClient := &http.Client{}
67+
idpWellKnownConfig, err := parseWellKnownConfiguration(httpClient, idpWellKnownConfigURL)
68+
if err != nil {
69+
return fmt.Errorf("parse IDP well-known configuration: %w", err)
70+
}
71+
6172
idpClientID, err := getIDPClientID()
6273
if err != nil {
6374
return err
@@ -100,7 +111,7 @@ func AuthorizeUser(p *print.Printer, isReauthentication bool) error {
100111
conf := &oauth2.Config{
101112
ClientID: idpClientID,
102113
Endpoint: oauth2.Endpoint{
103-
AuthURL: fmt.Sprintf("%s/authorize", idpEndpoint),
114+
AuthURL: idpWellKnownConfig.AuthorizationEndpoint,
104115
},
105116
Scopes: []string{"openid offline_access email"},
106117
RedirectURL: redirectURL,
@@ -147,7 +158,7 @@ func AuthorizeUser(p *print.Printer, isReauthentication bool) error {
147158
p.Debug(print.DebugLevel, "trading authorization code for access and refresh tokens")
148159

149160
// Trade the authorization code and the code verifier for access and refresh tokens
150-
accessToken, refreshToken, err := getUserAccessAndRefreshTokens(idpEndpoint, idpClientID, codeVerifier, code, redirectURL)
161+
accessToken, refreshToken, err := getUserAccessAndRefreshTokens(idpWellKnownConfig, idpClientID, codeVerifier, code, redirectURL)
151162
if err != nil {
152163
errServer = fmt.Errorf("retrieve tokens: %w", err)
153164
return
@@ -222,7 +233,7 @@ func AuthorizeUser(p *print.Printer, isReauthentication bool) error {
222233
})
223234

224235
p.Debug(print.DebugLevel, "opening browser for authentication")
225-
p.Debug(print.DebugLevel, "using authentication server on %s", idpEndpoint)
236+
p.Debug(print.DebugLevel, "using authentication server on %s", idpWellKnownConfig.Issuer)
226237
p.Debug(print.DebugLevel, "using client ID %s for authentication ", idpClientID)
227238

228239
// Open a browser window to the authorizationURL
@@ -248,9 +259,8 @@ func AuthorizeUser(p *print.Printer, isReauthentication bool) error {
248259
}
249260

250261
// getUserAccessAndRefreshTokens trades the authorization code retrieved from the first OAuth2 leg for an access token and a refresh token
251-
func getUserAccessAndRefreshTokens(authDomain, clientID, codeVerifier, authorizationCode, callbackURL string) (accessToken, refreshToken string, err error) {
252-
// Set the authUrl and form-encoded data for the POST to the access token endpoint
253-
authUrl := fmt.Sprintf("%s/token", authDomain)
262+
func getUserAccessAndRefreshTokens(idpWellKnownConfig *wellKnownConfig, clientID, codeVerifier, authorizationCode, callbackURL string) (accessToken, refreshToken string, err error) {
263+
// Set form-encoded data for the POST to the access token endpoint
254264
data := fmt.Sprintf(
255265
"grant_type=authorization_code&client_id=%s"+
256266
"&code_verifier=%s"+
@@ -260,7 +270,7 @@ func getUserAccessAndRefreshTokens(authDomain, clientID, codeVerifier, authoriza
260270
payload := strings.NewReader(data)
261271

262272
// Create the request and execute it
263-
req, _ := http.NewRequest("POST", authUrl, payload)
273+
req, _ := http.NewRequest("POST", idpWellKnownConfig.TokenEndpoint, payload)
264274
req.Header.Add("content-type", "application/x-www-form-urlencoded")
265275
httpClient := &http.Client{}
266276
res, err := httpClient.Do(req)
@@ -331,3 +341,48 @@ func openBrowser(pageUrl string) error {
331341
}
332342
return nil
333343
}
344+
345+
// parseWellKnownConfiguration gets the well-known OpenID configuration from the provided URL and returns it as a JSON
346+
// the method also stores the IDP token endpoint in the authentication storage
347+
func parseWellKnownConfiguration(httpClient apiClient, wellKnownConfigURL string) (wellKnownConfig *wellKnownConfig, err error) {
348+
req, _ := http.NewRequest("GET", wellKnownConfigURL, http.NoBody)
349+
res, err := httpClient.Do(req)
350+
if err != nil {
351+
return nil, fmt.Errorf("make the request: %w", err)
352+
}
353+
354+
// Process the response
355+
defer func() {
356+
closeErr := res.Body.Close()
357+
if closeErr != nil {
358+
err = fmt.Errorf("close response body: %w", closeErr)
359+
}
360+
}()
361+
body, err := io.ReadAll(res.Body)
362+
if err != nil {
363+
return nil, fmt.Errorf("read response body: %w", err)
364+
}
365+
366+
err = json.Unmarshal(body, &wellKnownConfig)
367+
if err != nil {
368+
return nil, fmt.Errorf("unmarshal response: %w", err)
369+
}
370+
if wellKnownConfig == nil {
371+
return nil, fmt.Errorf("nil well-known configuration response")
372+
}
373+
if wellKnownConfig.Issuer == "" {
374+
return nil, fmt.Errorf("found no issuer")
375+
}
376+
if wellKnownConfig.AuthorizationEndpoint == "" {
377+
return nil, fmt.Errorf("found no authorization endpoint")
378+
}
379+
if wellKnownConfig.TokenEndpoint == "" {
380+
return nil, fmt.Errorf("found no token endpoint")
381+
}
382+
383+
err = SetAuthField(IDP_TOKEN_ENDPOINT, wellKnownConfig.TokenEndpoint)
384+
if err != nil {
385+
return nil, fmt.Errorf("set token endpoint in the authentication storage: %w", err)
386+
}
387+
return wellKnownConfig, err
388+
}

internal/pkg/auth/user_login_test.go

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
package auth
2+
3+
import (
4+
"fmt"
5+
"io"
6+
"net/http"
7+
"strings"
8+
"testing"
9+
10+
"github.com/google/go-cmp/cmp"
11+
"github.com/zalando/go-keyring"
12+
)
13+
14+
type apiClientMocked struct {
15+
getFails bool
16+
getResponse string
17+
}
18+
19+
func (a *apiClientMocked) Do(_ *http.Request) (*http.Response, error) {
20+
if a.getFails {
21+
return &http.Response{
22+
StatusCode: http.StatusNotFound,
23+
}, fmt.Errorf("not found")
24+
}
25+
return &http.Response{
26+
Status: "200 OK",
27+
StatusCode: http.StatusAccepted,
28+
Body: io.NopCloser(strings.NewReader(a.getResponse)),
29+
}, nil
30+
}
31+
32+
func TestParseWellKnownConfig(t *testing.T) {
33+
tests := []struct {
34+
name string
35+
getFails bool
36+
getResponse string
37+
isValid bool
38+
expected *wellKnownConfig
39+
}{
40+
{
41+
name: "success",
42+
getFails: false,
43+
getResponse: `{"issuer":"issuer","authorization_endpoint":"auth","token_endpoint":"token"}`,
44+
isValid: true,
45+
expected: &wellKnownConfig{
46+
Issuer: "issuer",
47+
AuthorizationEndpoint: "auth",
48+
TokenEndpoint: "token",
49+
},
50+
},
51+
{
52+
name: "get_fails",
53+
getFails: true,
54+
getResponse: "",
55+
isValid: false,
56+
expected: nil,
57+
},
58+
{
59+
name: "empty_response",
60+
getFails: true,
61+
getResponse: "",
62+
isValid: false,
63+
expected: nil,
64+
},
65+
{
66+
name: "missing_issuer",
67+
getFails: true,
68+
getResponse: `{"authorization_endpoint":"auth","token_endpoint":"token"}`,
69+
isValid: false,
70+
expected: nil,
71+
},
72+
{
73+
name: "missing_authorization",
74+
getFails: true,
75+
getResponse: `{"issuer":"issuer","token_endpoint":"token"}`,
76+
isValid: false,
77+
expected: nil,
78+
},
79+
{
80+
name: "missing_token",
81+
getFails: true,
82+
getResponse: `{"issuer":"issuer","authorization_endpoint":"auth"}`,
83+
isValid: false,
84+
expected: nil,
85+
},
86+
}
87+
for _, tt := range tests {
88+
t.Run(tt.name, func(t *testing.T) {
89+
keyring.MockInit()
90+
91+
testClient := apiClientMocked{
92+
tt.getFails,
93+
tt.getResponse,
94+
}
95+
96+
got, err := parseWellKnownConfiguration(&testClient, "")
97+
98+
if tt.isValid && err != nil {
99+
t.Fatalf("expected no error, got %v", err)
100+
}
101+
if !tt.isValid && err == nil {
102+
t.Fatalf("expected error, got none")
103+
}
104+
105+
if tt.isValid && !cmp.Equal(*got, *tt.expected) {
106+
t.Fatalf("expected %v, got %v", tt.expected, got)
107+
}
108+
})
109+
}
110+
}

0 commit comments

Comments
 (0)