Skip to content

Commit 41c9937

Browse files
authored
Implement user confirmation for re-authenticating (#249)
* implement user confirmation for re-authenticating * implement PromtForEnter * update PromptForEnter * add dependencies * rollback PromptForEnter function * add comment to condition * update comment
1 parent 530511e commit 41c9937

File tree

7 files changed

+38
-10
lines changed

7 files changed

+38
-10
lines changed

internal/cmd/auth/login/login.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ func NewCmd(p *print.Printer) *cobra.Command {
2323
"$ stackit auth login"),
2424
),
2525
RunE: func(cmd *cobra.Command, args []string) error {
26-
err := auth.AuthorizeUser()
26+
err := auth.AuthorizeUser(p, false)
2727
if err != nil {
2828
return fmt.Errorf("authorization failed: %w", err)
2929
}

internal/pkg/auth/auth.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ type tokenClaims struct {
2222
// It returns the configuration option that can be used to create an authenticated SDK client.
2323
//
2424
// If the user was logged in and the user session expired, reauthorizeUserRoutine is called to reauthenticate the user again.
25-
func AuthenticationConfig(p *print.Printer, reauthorizeUserRoutine func() error) (authCfgOption sdkConfig.ConfigurationOption, err error) {
25+
func AuthenticationConfig(p *print.Printer, reauthorizeUserRoutine func(p *print.Printer, isReauthentication bool) error) (authCfgOption sdkConfig.ConfigurationOption, err error) {
2626
flow, err := GetAuthFlow()
2727
if err != nil {
2828
return nil, fmt.Errorf("get authentication flow: %w", err)
@@ -57,8 +57,7 @@ func AuthenticationConfig(p *print.Printer, reauthorizeUserRoutine func() error)
5757
authCfgOption = sdkConfig.WithCustomAuth(keyFlow)
5858
case AUTH_FLOW_USER_TOKEN:
5959
if userSessionExpired {
60-
p.Warn("Session expired, logging in again...\n")
61-
err = reauthorizeUserRoutine()
60+
err = reauthorizeUserRoutine(p, true)
6261
if err != nil {
6362
return nil, fmt.Errorf("user login: %w", err)
6463
}

internal/pkg/auth/auth_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ func TestAuthenticationConfig(t *testing.T) {
192192
}
193193

194194
reauthorizeUserCalled := false
195-
reauthenticateUser := func() error {
195+
reauthenticateUser := func(p *print.Printer, isReauthentication bool) error {
196196
if reauthorizeUserCalled {
197197
t.Errorf("user reauthorized more than once")
198198
}

internal/pkg/auth/user_login.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ import (
1717
"time"
1818

1919
"golang.org/x/oauth2"
20+
21+
"github.com/stackitcloud/stackit-cli/internal/pkg/print"
2022
)
2123

2224
const (
@@ -36,7 +38,14 @@ type User struct {
3638
}
3739

3840
// AuthorizeUser implements the PKCE OAuth2 flow.
39-
func AuthorizeUser() error {
41+
func AuthorizeUser(p *print.Printer, isReauthentication bool) error {
42+
if isReauthentication {
43+
err := p.PromptForEnter("Your session has expired, press Enter to login again...")
44+
if err != nil {
45+
return err
46+
}
47+
}
48+
4049
listener, err := net.Listen("tcp", ":0")
4150
if err != nil {
4251
return fmt.Errorf("bind port for login redirect: %w", err)

internal/pkg/auth/user_token_flow.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import (
1414

1515
type userTokenFlow struct {
1616
printer *print.Printer
17-
reauthorizeUserRoutine func() error // Called if the user needs to login again
17+
reauthorizeUserRoutine func(p *print.Printer, isReauthentication bool) error // Called if the user needs to login again
1818
client *http.Client
1919
authFlow AuthFlow
2020
accessToken string
@@ -59,7 +59,6 @@ func (utf *userTokenFlow) RoundTrip(req *http.Request) (*http.Response, error) {
5959
}
6060

6161
if !accessTokenValid {
62-
utf.printer.Warn("Session expired, logging in again...")
6362
err = reauthenticateUser(utf)
6463
if err != nil {
6564
return nil, fmt.Errorf("reauthenticate user: %w", err)
@@ -91,7 +90,7 @@ func loadVarsFromStorage(utf *userTokenFlow) error {
9190
}
9291

9392
func reauthenticateUser(utf *userTokenFlow) error {
94-
err := utf.reauthorizeUserRoutine()
93+
err := utf.reauthorizeUserRoutine(utf.printer, true)
9594
if err != nil {
9695
return fmt.Errorf("authenticate user: %w", err)
9796
}

internal/pkg/auth/user_token_flow_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ func TestRoundTrip(t *testing.T) {
271271
authorizeUserCalled: &authorizeUserCalled,
272272
tokensRefreshed: &tokensRefreshed,
273273
}
274-
authorizeUserRoutine := func() error {
274+
authorizeUserRoutine := func(p *print.Printer, isReauthentication bool) error {
275275
return reauthorizeUser(authorizeUserContext)
276276
}
277277

internal/pkg/print/print.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"bufio"
55
"errors"
66
"fmt"
7+
78
"log/slog"
89
"os"
910
"os/exec"
@@ -130,6 +131,26 @@ func (p *Printer) PromptForConfirmation(prompt string) error {
130131
return fmt.Errorf("max number of wrong inputs")
131132
}
132133

134+
// Prompts the user for confirmation by pressing Enter.
135+
//
136+
// Returns nil only if the user (explicitly) press directly enter.
137+
// Returns ErrAborted if the user press anything else before pressing enter.
138+
func (p *Printer) PromptForEnter(prompt string) error {
139+
reader := bufio.NewReaderSize(p.Cmd.InOrStdin(), 1)
140+
141+
p.Cmd.PrintErr(prompt)
142+
answer, err := reader.ReadByte()
143+
if err != nil {
144+
return fmt.Errorf("read user response: %w", err)
145+
}
146+
147+
// ASCII code for Enter (newline) is 10.
148+
if answer == 10 {
149+
return nil
150+
}
151+
return errAborted
152+
}
153+
133154
// Shows the content in the command's stdout using the "less" command
134155
// If output format is set to none, it does nothing
135156
func (p *Printer) PagerDisplay(content string) error {

0 commit comments

Comments
 (0)