diff --git a/pkg/auth/clientdata.go b/pkg/auth/clientdata.go index 727eef4..d2bdbee 100644 --- a/pkg/auth/clientdata.go +++ b/pkg/auth/clientdata.go @@ -19,6 +19,7 @@ import ( "encoding/base64" "encoding/json" "errors" + "time" ) const ( @@ -32,6 +33,7 @@ var ( ErrInvalidClientData = errors.New("invalid client data") ErrInvalidPrivateKey = errors.New("invalid private key") ErrInvalidPublicKey = errors.New("invalid public key") + ErrClientDataExpired = errors.New("client data has expired") ) type SignatureAlgorithm string @@ -40,6 +42,9 @@ const ( SignatureAlgorithmRS256 SignatureAlgorithm = "RS256" ) +// To allow overriding for test purposes +var nowFunc = time.Now + type ClientData struct { // Mandatory user attributes Identifier string `json:"identifier"` @@ -62,6 +67,9 @@ type ClientData struct { // SignatureAlgorithm is the algorithm used to sign the client data. SignatureAlgorithm SignatureAlgorithm `json:"alg"` + // CreatedAt The datetime of when the object was created (RFC3339 format) + CreatedAt time.Time `json:"createdAt"` + b64data string } @@ -86,6 +94,10 @@ func DecodeFrom(b64data string) (*ClientData, error) { // Verify verifies the signature of the client data using the provided public key. func (c *ClientData) Verify(publicKey any, b64sig string) error { + if nowFunc().After(c.CreatedAt.Add(time.Minute)) { + return ErrClientDataExpired + } + switch c.SignatureAlgorithm { case SignatureAlgorithmRS256: signature, err := base64.RawURLEncoding.DecodeString(b64sig) @@ -109,6 +121,8 @@ func (c *ClientData) Verify(publicKey any, b64sig string) error { // Encode encodes the client data and signs it using the provided private key. // Both values are returned as base64 URL encoded strings. func (c *ClientData) Encode(privateKey any) (string, string, error) { + c.CreatedAt = nowFunc() + jsonString, err := json.Marshal(c) if err != nil { return "", "", err diff --git a/pkg/auth/clientdata_test.go b/pkg/auth/clientdata_test.go index 2708134..3798155 100644 --- a/pkg/auth/clientdata_test.go +++ b/pkg/auth/clientdata_test.go @@ -5,6 +5,7 @@ import ( "crypto/rsa" "reflect" "testing" + "time" "github.com/openkcm/common-sdk/pkg/auth" ) @@ -39,36 +40,51 @@ func TestEndToEnd(t *testing.T) { SignatureAlgorithm: auth.SignatureAlgorithmRS256, } + expiredClientData := defClientData + expiredClientData.CreatedAt = time.Now().Add(time.Hour) + // create the test cases tests := []struct { - name string - clientData *auth.ClientData - privateKey any - publicKey any - wantError bool - wantError2 bool - wantError3 bool + name string + clientData *auth.ClientData + privateKey any + publicKey any + wantError bool + wantError2 bool + wantError3 bool + postDecodeNowFunc func() time.Time }{ { - name: "invalid signature algorithm", - clientData: &auth.ClientData{}, - wantError: true, + name: "invalid signature algorithm", + clientData: &auth.ClientData{}, + wantError: true, + postDecodeNowFunc: time.Now, + }, { + name: "invalid private key", + clientData: defClientData, + privateKey: "not a private key", + wantError: true, + postDecodeNowFunc: time.Now, }, { - name: "invalid private key", - clientData: defClientData, - privateKey: "not a private key", - wantError: true, + name: "invalid public key", + clientData: defClientData, + privateKey: rsaPrivateKey, + publicKey: "not a public key", + wantError3: true, + postDecodeNowFunc: time.Now, }, { - name: "invalid public key", - clientData: defClientData, - privateKey: rsaPrivateKey, - publicKey: "not a public key", - wantError3: true, + name: "expired", + clientData: expiredClientData, + privateKey: rsaPrivateKey, + publicKey: rsaPublicKey, + wantError3: true, + postDecodeNowFunc: func() time.Time { return time.Now().Add(time.Second * 61) }, }, { - name: "ok", - clientData: defClientData, - privateKey: rsaPrivateKey, - publicKey: rsaPublicKey, + name: "ok", + clientData: defClientData, + privateKey: rsaPrivateKey, + publicKey: rsaPublicKey, + postDecodeNowFunc: time.Now, }, } @@ -103,6 +119,8 @@ func TestEndToEnd(t *testing.T) { t.Error("client data does not match") } + auth.SetNowFunc(tc.postDecodeNowFunc) + // Act Verify err3 := clientData.Verify(tc.publicKey, b64sig) diff --git a/pkg/auth/export_test.go b/pkg/auth/export_test.go new file mode 100644 index 0000000..664c04d --- /dev/null +++ b/pkg/auth/export_test.go @@ -0,0 +1,7 @@ +package auth + +import "time" + +func SetNowFunc(newNowFunc func() time.Time) { + nowFunc = newNowFunc +}