Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions pkg/auth/clientdata.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"encoding/base64"
"encoding/json"
"errors"
"time"
)

const (
Expand All @@ -32,6 +33,7 @@ var (
ErrInvalidClientData = errors.New("invalid client data")
ErrInvalidPrivateKey = errors.New("invalid private key")
ErrInvalidPublicKey = errors.New("invalid public key")
ErrClientDataExpired = errors.New("client data has expired")
)

type SignatureAlgorithm string
Expand All @@ -40,6 +42,9 @@ const (
SignatureAlgorithmRS256 SignatureAlgorithm = "RS256"
)

// To allow overriding for test purposes
var nowFunc = time.Now

type ClientData struct {
// Mandatory user attributes
Identifier string `json:"identifier"`
Expand All @@ -62,6 +67,9 @@ type ClientData struct {
// SignatureAlgorithm is the algorithm used to sign the client data.
SignatureAlgorithm SignatureAlgorithm `json:"alg"`

// CreatedAt The datetime of when the object was created (RFC3339 format)
CreatedAt time.Time `json:"createdAt"`

b64data string
}

Expand All @@ -86,6 +94,10 @@ func DecodeFrom(b64data string) (*ClientData, error) {

// Verify verifies the signature of the client data using the provided public key.
func (c *ClientData) Verify(publicKey any, b64sig string) error {
if nowFunc().After(c.CreatedAt.Add(time.Minute)) {
return ErrClientDataExpired
}

switch c.SignatureAlgorithm {
case SignatureAlgorithmRS256:
signature, err := base64.RawURLEncoding.DecodeString(b64sig)
Expand All @@ -109,6 +121,8 @@ func (c *ClientData) Verify(publicKey any, b64sig string) error {
// Encode encodes the client data and signs it using the provided private key.
// Both values are returned as base64 URL encoded strings.
func (c *ClientData) Encode(privateKey any) (string, string, error) {
c.CreatedAt = nowFunc()

jsonString, err := json.Marshal(c)
if err != nil {
return "", "", err
Expand Down
64 changes: 41 additions & 23 deletions pkg/auth/clientdata_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/rsa"
"reflect"
"testing"
"time"

"github.com/openkcm/common-sdk/pkg/auth"
)
Expand Down Expand Up @@ -39,36 +40,51 @@ func TestEndToEnd(t *testing.T) {
SignatureAlgorithm: auth.SignatureAlgorithmRS256,
}

expiredClientData := defClientData
expiredClientData.CreatedAt = time.Now().Add(time.Hour)

// create the test cases
tests := []struct {
name string
clientData *auth.ClientData
privateKey any
publicKey any
wantError bool
wantError2 bool
wantError3 bool
name string
clientData *auth.ClientData
privateKey any
publicKey any
wantError bool
wantError2 bool
wantError3 bool
postDecodeNowFunc func() time.Time
}{
{
name: "invalid signature algorithm",
clientData: &auth.ClientData{},
wantError: true,
name: "invalid signature algorithm",
clientData: &auth.ClientData{},
wantError: true,
postDecodeNowFunc: time.Now,
}, {
name: "invalid private key",
clientData: defClientData,
privateKey: "not a private key",
wantError: true,
postDecodeNowFunc: time.Now,
}, {
name: "invalid private key",
clientData: defClientData,
privateKey: "not a private key",
wantError: true,
name: "invalid public key",
clientData: defClientData,
privateKey: rsaPrivateKey,
publicKey: "not a public key",
wantError3: true,
postDecodeNowFunc: time.Now,
}, {
name: "invalid public key",
clientData: defClientData,
privateKey: rsaPrivateKey,
publicKey: "not a public key",
wantError3: true,
name: "expired",
clientData: expiredClientData,
privateKey: rsaPrivateKey,
publicKey: rsaPublicKey,
wantError3: true,
postDecodeNowFunc: func() time.Time { return time.Now().Add(time.Second * 61) },
}, {
name: "ok",
clientData: defClientData,
privateKey: rsaPrivateKey,
publicKey: rsaPublicKey,
name: "ok",
clientData: defClientData,
privateKey: rsaPrivateKey,
publicKey: rsaPublicKey,
postDecodeNowFunc: time.Now,
},
}

Expand Down Expand Up @@ -103,6 +119,8 @@ func TestEndToEnd(t *testing.T) {
t.Error("client data does not match")
}

auth.SetNowFunc(tc.postDecodeNowFunc)

// Act Verify
err3 := clientData.Verify(tc.publicKey, b64sig)

Expand Down
7 changes: 7 additions & 0 deletions pkg/auth/export_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package auth

import "time"

func SetNowFunc(newNowFunc func() time.Time) {
nowFunc = newNowFunc
}
Loading