Skip to content

Commit e5035e8

Browse files
committed
refactor: merge user and device code storage
1 parent 9646fe3 commit e5035e8

8 files changed

+130
-219
lines changed

handler/rfc8628/auth_handler.go

+12-28
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,7 @@ type DeviceAuthHandler struct {
3030
func (d *DeviceAuthHandler) HandleDeviceEndpointRequest(ctx context.Context, dar fosite.DeviceRequester, resp fosite.DeviceResponder) error {
3131
var err error
3232

33-
var deviceCode string
34-
deviceCode, err = d.handleDeviceCode(ctx, dar)
35-
if err != nil {
36-
return err
37-
}
38-
39-
var userCode string
40-
userCode, err = d.handleUserCode(ctx, dar)
33+
deviceCode, userCode, err := d.handleDeviceAuthSession(ctx, dar)
4134
if err != nil {
4235
return err
4336
}
@@ -52,41 +45,32 @@ func (d *DeviceAuthHandler) HandleDeviceEndpointRequest(ctx context.Context, dar
5245
return nil
5346
}
5447

55-
func (d *DeviceAuthHandler) handleDeviceCode(ctx context.Context, dar fosite.DeviceRequester) (string, error) {
56-
code, signature, err := d.Strategy.GenerateDeviceCode(ctx)
57-
if err != nil {
58-
return "", errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error()))
59-
}
48+
func (d *DeviceAuthHandler) handleDeviceAuthSession(ctx context.Context, dar fosite.DeviceRequester) (string, string, error) {
49+
var userCode, userCodeSignature string
6050

61-
dar.GetSession().SetExpiresAt(fosite.DeviceCode, time.Now().UTC().Add(d.Config.GetDeviceAndUserCodeLifespan(ctx)))
62-
if err = d.Storage.CreateDeviceCodeSession(ctx, signature, dar.Sanitize(nil)); err != nil {
63-
return "", errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error()))
51+
deviceCode, deviceCodeSignature, err := d.Strategy.GenerateDeviceCode(ctx)
52+
if err != nil {
53+
return "", "", errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error()))
6454
}
6555

66-
return code, nil
67-
}
68-
69-
func (d *DeviceAuthHandler) handleUserCode(ctx context.Context, dar fosite.DeviceRequester) (string, error) {
70-
var err error
71-
var userCode, signature string
56+
dar.GetSession().SetExpiresAt(fosite.UserCode, time.Now().UTC().Add(d.Config.GetDeviceAndUserCodeLifespan(ctx)).Round(time.Second))
7257
// Note: the retries are added here because we need to ensure uniqueness of user codes.
7358
// The chances of duplicates should however be diminishing, because they are the same
7459
// chance an attacker will be able to hit a valid code with few guesses. However, as
7560
// used codes will probably still be around for some time before they get cleaned,
7661
// the chances of hitting a duplicate here can be higher.
7762
// Three retries should be plenty, as otherwise the entropy is definitely off.
7863
for i := 0; i < MaxAttempts; i++ {
79-
userCode, signature, err = d.Strategy.GenerateUserCode(ctx)
64+
userCode, userCodeSignature, err = d.Strategy.GenerateUserCode(ctx)
8065
if err != nil {
81-
return "", err
66+
return "", "", err
8267
}
8368

84-
dar.GetSession().SetExpiresAt(fosite.UserCode, time.Now().UTC().Add(d.Config.GetDeviceAndUserCodeLifespan(ctx)).Round(time.Second))
85-
if err = d.Storage.CreateUserCodeSession(ctx, signature, dar.Sanitize(nil)); err == nil {
86-
return userCode, nil
69+
if err = d.Storage.CreateDeviceAuthSession(ctx, deviceCodeSignature, userCodeSignature, dar.Sanitize(nil)); err == nil {
70+
return deviceCode, userCode, nil
8771
}
8872
}
8973

9074
errMsg := fmt.Sprintf("Exceeded user-code generation max attempts %v: %s", MaxAttempts, err.Error())
91-
return "", errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(errMsg))
75+
return "", "", errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(errMsg))
9276
}

handler/rfc8628/auth_handler_test.go

+6-19
Original file line numberDiff line numberDiff line change
@@ -85,20 +85,15 @@ func Test_HandleDeviceEndpointRequestWithRetry(t *testing.T) {
8585
EXPECT().
8686
GenerateDeviceCode(ctx).
8787
Return("deviceCode", "signature", nil)
88-
mockRFC8628CoreStorage.
89-
EXPECT().
90-
CreateDeviceCodeSession(ctx, "signature", gomock.Any()).
91-
Return(nil)
9288
mockRFC8628CodeStrategy.
9389
EXPECT().
9490
GenerateUserCode(ctx).
95-
Return("userCode", "signature", nil).
91+
Return("userCode", "signature2", nil).
9692
Times(1)
9793
mockRFC8628CoreStorage.
9894
EXPECT().
99-
CreateUserCodeSession(ctx, "signature", gomock.Any()).
100-
Return(nil).
101-
Times(1)
95+
CreateDeviceAuthSession(ctx, "signature", "signature2", gomock.Any()).
96+
Return(nil)
10297
},
10398
check: func(t *testing.T, resp *fosite.DeviceResponse) {
10499
assert.Equal(t, "userCode", resp.GetUserCode())
@@ -111,26 +106,22 @@ func Test_HandleDeviceEndpointRequestWithRetry(t *testing.T) {
111106
EXPECT().
112107
GenerateDeviceCode(ctx).
113108
Return("deviceCode", "signature", nil)
114-
mockRFC8628CoreStorage.
115-
EXPECT().
116-
CreateDeviceCodeSession(ctx, "signature", gomock.Any()).
117-
Return(nil)
118109
gomock.InOrder(
119110
mockRFC8628CodeStrategy.
120111
EXPECT().
121112
GenerateUserCode(ctx).
122113
Return("duplicatedUserCode", "duplicatedSignature", nil),
123114
mockRFC8628CoreStorage.
124115
EXPECT().
125-
CreateUserCodeSession(ctx, "duplicatedSignature", gomock.Any()).
116+
CreateDeviceAuthSession(ctx, "signature", "duplicatedSignature", gomock.Any()).
126117
Return(errors.New("unique constraint violation")),
127118
mockRFC8628CodeStrategy.
128119
EXPECT().
129120
GenerateUserCode(ctx).
130121
Return("uniqueUserCode", "uniqueSignature", nil),
131122
mockRFC8628CoreStorage.
132123
EXPECT().
133-
CreateUserCodeSession(ctx, "uniqueSignature", gomock.Any()).
124+
CreateDeviceAuthSession(ctx, "signature", "uniqueSignature", gomock.Any()).
134125
Return(nil),
135126
)
136127
},
@@ -145,18 +136,14 @@ func Test_HandleDeviceEndpointRequestWithRetry(t *testing.T) {
145136
EXPECT().
146137
GenerateDeviceCode(ctx).
147138
Return("deviceCode", "signature", nil)
148-
mockRFC8628CoreStorage.
149-
EXPECT().
150-
CreateDeviceCodeSession(ctx, "signature", gomock.Any()).
151-
Return(nil)
152139
mockRFC8628CodeStrategy.
153140
EXPECT().
154141
GenerateUserCode(ctx).
155142
Return("duplicatedUserCode", "duplicatedSignature", nil).
156143
Times(rfc8628.MaxAttempts)
157144
mockRFC8628CoreStorage.
158145
EXPECT().
159-
CreateUserCodeSession(ctx, "duplicatedSignature", gomock.Any()).
146+
CreateDeviceAuthSession(ctx, "signature", "duplicatedSignature", gomock.Any()).
160147
Return(errors.New("unique constraint violation")).
161148
Times(rfc8628.MaxAttempts)
162149
},

handler/rfc8628/storage.go

+6-25
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,15 @@ import (
1212

1313
// RFC8628CoreStorage is the storage needed for the DeviceAuthHandler
1414
type RFC8628CoreStorage interface {
15-
DeviceCodeStorage
16-
UserCodeStorage
15+
DeviceAuthStorage
1716
oauth2.AccessTokenStorage
1817
oauth2.RefreshTokenStorage
1918
}
2019

21-
// DeviceCodeStorage handles the device_code storage
22-
type DeviceCodeStorage interface {
23-
// CreateDeviceCodeSession stores the device request for a given device code.
24-
CreateDeviceCodeSession(ctx context.Context, signature string, request fosite.Requester) (err error)
20+
// DeviceAuthStorage handles the device auth session storage
21+
type DeviceAuthStorage interface {
22+
// CreateDeviceAuthSession stores the device auth request session.
23+
CreateDeviceAuthSession(ctx context.Context, deviceCodeSignature, userCodeSignature string, request fosite.Requester) (err error)
2524

2625
// GetDeviceCodeSession hydrates the session based on the given device code and returns the device request.
2726
// If the device code has been invalidated with `InvalidateDeviceCodeSession`, this
@@ -30,26 +29,8 @@ type DeviceCodeStorage interface {
3029
// Make sure to also return the fosite.Requester value when returning the fosite.ErrInvalidatedDeviceCode error!
3130
GetDeviceCodeSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error)
3231

33-
// InvalidateDeviceCodeSession is called when a device code is being used. The state of the user
32+
// InvalidateDeviceCodeSession is called when a device code is being used. The state of the device
3433
// code should be set to invalid and consecutive requests to GetDeviceCodeSession should return the
3534
// ErrInvalidatedDeviceCode error.
3635
InvalidateDeviceCodeSession(ctx context.Context, signature string) (err error)
3736
}
38-
39-
// UserCodeStorage handles the user_code storage
40-
type UserCodeStorage interface {
41-
// CreateUserCodeSession stores the device request for a given user code.
42-
CreateUserCodeSession(ctx context.Context, signature string, request fosite.Requester) (err error)
43-
44-
// GetUserCodeSession hydrates the session based on the given user code and returns the device request.
45-
// If the user code has been invalidated with `InvalidateUserCodeSession`, this
46-
// method should return the ErrInvalidatedUserCode error.
47-
//
48-
// Make sure to also return the fosite.Requester value when returning the fosite.ErrInvalidatedUserCode error!
49-
GetUserCodeSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error)
50-
51-
// InvalidateUserCodeSession is called when a user code is being used. The state of the user
52-
// code should be set to invalid and consecutive requests to GetUserCodeSession should return the
53-
// ErrInvalidatedUserCode error.
54-
InvalidateUserCodeSession(ctx context.Context, signature string) (err error)
55-
}

handler/rfc8628/token_handler_test.go

+24-8
Original file line numberDiff line numberDiff line change
@@ -154,9 +154,11 @@ func TestDeviceUserCode_HandleTokenEndpointRequest(t *testing.T) {
154154
setup: func(t *testing.T, areq *fosite.AccessRequest, authreq *fosite.DeviceRequest) {
155155
code, signature, err := strategy.GenerateDeviceCode(context.TODO())
156156
require.NoError(t, err)
157+
_, userCodeSignature, err := strategy.GenerateUserCode(context.TODO())
158+
require.NoError(t, err)
157159
areq.Form.Add("device_code", code)
158160

159-
require.NoError(t, store.CreateDeviceCodeSession(context.TODO(), signature, authreq))
161+
require.NoError(t, store.CreateDeviceAuthSession(context.TODO(), signature, userCodeSignature, authreq))
160162
},
161163
expectErr: fosite.ErrAuthorizationPending,
162164
},
@@ -192,9 +194,11 @@ func TestDeviceUserCode_HandleTokenEndpointRequest(t *testing.T) {
192194
setup: func(t *testing.T, areq *fosite.AccessRequest, authreq *fosite.DeviceRequest) {
193195
code, signature, err := strategy.GenerateDeviceCode(context.TODO())
194196
require.NoError(t, err)
197+
_, userCodeSignature, err := strategy.GenerateUserCode(context.TODO())
198+
require.NoError(t, err)
195199
areq.Form.Add("device_code", code)
196200

197-
require.NoError(t, store.CreateDeviceCodeSession(context.TODO(), signature, authreq))
201+
require.NoError(t, store.CreateDeviceAuthSession(context.TODO(), signature, userCodeSignature, authreq))
198202
},
199203
expectErr: fosite.ErrDeviceExpiredToken,
200204
},
@@ -227,9 +231,11 @@ func TestDeviceUserCode_HandleTokenEndpointRequest(t *testing.T) {
227231
setup: func(t *testing.T, areq *fosite.AccessRequest, authreq *fosite.DeviceRequest) {
228232
token, signature, err := strategy.GenerateDeviceCode(context.TODO())
229233
require.NoError(t, err)
234+
_, userCodeSignature, err := strategy.GenerateUserCode(context.TODO())
235+
require.NoError(t, err)
230236
areq.Form = url.Values{"device_code": {token}}
231237

232-
require.NoError(t, store.CreateDeviceCodeSession(context.TODO(), signature, authreq))
238+
require.NoError(t, store.CreateDeviceAuthSession(context.TODO(), signature, userCodeSignature, authreq))
233239
},
234240
expectErr: fosite.ErrInvalidGrant,
235241
},
@@ -263,9 +269,11 @@ func TestDeviceUserCode_HandleTokenEndpointRequest(t *testing.T) {
263269
setup: func(t *testing.T, areq *fosite.AccessRequest, authreq *fosite.DeviceRequest) {
264270
token, signature, err := strategy.GenerateDeviceCode(context.TODO())
265271
require.NoError(t, err)
272+
_, userCodeSignature, err := strategy.GenerateUserCode(context.TODO())
273+
require.NoError(t, err)
266274

267275
areq.Form = url.Values{"device_code": {token}}
268-
require.NoError(t, store.CreateDeviceCodeSession(context.TODO(), signature, authreq))
276+
require.NoError(t, store.CreateDeviceAuthSession(context.TODO(), signature, userCodeSignature, authreq))
269277
},
270278
},
271279
}
@@ -342,9 +350,11 @@ func TestDeviceUserCode_HandleTokenEndpointRequest_RateLimiting(t *testing.T) {
342350

343351
token, signature, err := strategy.GenerateDeviceCode(context.TODO())
344352
require.NoError(t, err)
353+
_, userCodeSignature, err := strategy.GenerateUserCode(context.TODO())
354+
require.NoError(t, err)
345355

346356
areq.Form = url.Values{"device_code": {token}}
347-
require.NoError(t, store.CreateDeviceCodeSession(context.TODO(), signature, authreq))
357+
require.NoError(t, store.CreateDeviceAuthSession(context.TODO(), signature, userCodeSignature, authreq))
348358
err = h.HandleTokenEndpointRequest(context.Background(), areq)
349359
require.NoError(t, err, "%+v", err)
350360
err = h.HandleTokenEndpointRequest(context.Background(), areq)
@@ -441,9 +451,11 @@ func TestDeviceUserCode_PopulateTokenEndpointResponse(t *testing.T) {
441451
setup: func(t *testing.T, areq *fosite.AccessRequest, authreq *fosite.DeviceRequest, _ *fosite.Config) {
442452
code, signature, err := strategy.GenerateDeviceCode(context.TODO())
443453
require.NoError(t, err)
454+
_, userCodeSignature, err := strategy.GenerateUserCode(context.TODO())
455+
require.NoError(t, err)
444456
areq.Form.Add("device_code", code)
445457

446-
require.NoError(t, store.CreateDeviceCodeSession(context.TODO(), signature, authreq))
458+
require.NoError(t, store.CreateDeviceAuthSession(context.TODO(), signature, userCodeSignature, authreq))
447459
},
448460
check: func(t *testing.T, aresp *fosite.AccessResponse) {
449461
assert.NotEmpty(t, aresp.AccessToken)
@@ -483,9 +495,11 @@ func TestDeviceUserCode_PopulateTokenEndpointResponse(t *testing.T) {
483495
config.RefreshTokenScopes = []string{}
484496
code, signature, err := strategy.GenerateDeviceCode(context.TODO())
485497
require.NoError(t, err)
498+
_, userCodeSignature, err := strategy.GenerateUserCode(context.TODO())
499+
require.NoError(t, err)
486500
areq.Form.Add("device_code", code)
487501

488-
require.NoError(t, store.CreateDeviceCodeSession(context.TODO(), signature, authreq))
502+
require.NoError(t, store.CreateDeviceAuthSession(context.TODO(), signature, userCodeSignature, authreq))
489503
},
490504
check: func(t *testing.T, aresp *fosite.AccessResponse) {
491505
assert.NotEmpty(t, aresp.AccessToken)
@@ -524,9 +538,11 @@ func TestDeviceUserCode_PopulateTokenEndpointResponse(t *testing.T) {
524538
setup: func(t *testing.T, areq *fosite.AccessRequest, authreq *fosite.DeviceRequest, config *fosite.Config) {
525539
code, signature, err := strategy.GenerateDeviceCode(context.TODO())
526540
require.NoError(t, err)
541+
_, userCodeSignature, err := strategy.GenerateUserCode(context.TODO())
542+
require.NoError(t, err)
527543
areq.Form.Add("device_code", code)
528544

529-
require.NoError(t, store.CreateDeviceCodeSession(context.TODO(), signature, authreq))
545+
require.NoError(t, store.CreateDeviceAuthSession(context.TODO(), signature, userCodeSignature, authreq))
530546
},
531547
check: func(t *testing.T, aresp *fosite.AccessResponse) {
532548
assert.NotEmpty(t, aresp.AccessToken)

integration/helper_setup_test.go

+2-3
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,8 @@ var fositeStore = &storage.MemoryStore{
123123
AccessTokenRequestIDs: map[string]string{},
124124
RefreshTokenRequestIDs: map[string]string{},
125125
PARSessions: map[string]fosite.AuthorizeRequester{},
126-
DeviceCodes: map[string]fosite.Requester{},
127-
UserCodes: map[string]fosite.Requester{},
128-
DeviceCodesRequestIDs: map[string]string{},
126+
DeviceAuths: map[string]fosite.Requester{},
127+
DeviceCodesRequestIDs: map[string]storage.DeviceAuthPair{},
129128
UserCodesRequestIDs: map[string]string{},
130129
}
131130

0 commit comments

Comments
 (0)