diff --git a/token/jwt/jwt.go b/token/jwt/jwt.go index 9c5aa5775..d2ab1d430 100644 --- a/token/jwt/jwt.go +++ b/token/jwt/jwt.go @@ -57,14 +57,14 @@ func (j *DefaultSigner) Generate(ctx context.Context, claims MapClaims, header M return generateToken(claims, header, jose.ES256, t) case jose.OpaqueSigner: switch tt := t.Public().Key.(type) { - case *rsa.PrivateKey: + case *rsa.PublicKey: alg := jose.RS256 if len(t.Algs()) > 0 { alg = t.Algs()[0] } return generateToken(claims, header, alg, t) - case *ecdsa.PrivateKey: + case *ecdsa.PublicKey: alg := jose.ES256 if len(t.Algs()) > 0 { alg = t.Algs()[0] diff --git a/token/jwt/jwt_test.go b/token/jwt/jwt_test.go index 1939d7bba..58475a27c 100644 --- a/token/jwt/jwt_test.go +++ b/token/jwt/jwt_test.go @@ -5,6 +5,11 @@ package jwt import ( "context" + "crypto" + "crypto/ecdsa" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" "fmt" "strings" "testing" @@ -24,6 +29,86 @@ var header = &Headers{ }, } +type mockOpaqueSigner struct { + publicKey *jose.JSONWebKey + privateKey interface{} + signer jose.Signer +} + +func newMockOpaqueSigner(key interface{}, alg jose.SignatureAlgorithm) (*mockOpaqueSigner, error) { + signer, err := jose.NewSigner(jose.SigningKey{Algorithm: alg, Key: key}, nil) + if err != nil { + return nil, err + } + + var publicKey *jose.JSONWebKey + switch k := key.(type) { + case *rsa.PrivateKey: + publicKey = &jose.JSONWebKey{ + Key: &k.PublicKey, + Algorithm: string(alg), + } + case *ecdsa.PrivateKey: + publicKey = &jose.JSONWebKey{ + Key: &k.PublicKey, + Algorithm: string(alg), + } + } + + return &mockOpaqueSigner{ + publicKey: publicKey, + privateKey: key, + signer: signer, + }, nil +} + +func (m *mockOpaqueSigner) Public() *jose.JSONWebKey { + return m.publicKey +} + +func (m *mockOpaqueSigner) Algs() []jose.SignatureAlgorithm { + if m.publicKey.Algorithm == "RS256" { + return []jose.SignatureAlgorithm{jose.RS256} + } + + if m.publicKey.Algorithm == "ES256" { + return []jose.SignatureAlgorithm{jose.ES256} + } + + return []jose.SignatureAlgorithm{} +} + +func (m *mockOpaqueSigner) SignPayload(payload []byte, alg jose.SignatureAlgorithm) ([]byte, error) { + // Use the stored private key to sign directly + switch alg { + case jose.RS256: + if rsaPrivKey, ok := m.privateKey.(*rsa.PrivateKey); ok { + // Hash the payload first for RSA signing + hash := sha256.Sum256(payload) + return rsaPrivKey.Sign(rand.Reader, hash[:], crypto.SHA256) + } + return nil, fmt.Errorf("expected RSA private key for RS256") + case jose.ES256: + if ecdsaPrivKey, ok := m.privateKey.(*ecdsa.PrivateKey); ok { + hash := sha256.Sum256(payload) + r, s, err := ecdsa.Sign(rand.Reader, ecdsaPrivKey, hash[:]) + if err != nil { + return nil, err + } + + // Convert to JWT format: R || S (32 bytes each for P-256) + keySize := 32 // P-256 uses 32-byte values + signature := make([]byte, 2*keySize) + r.FillBytes(signature[0:keySize]) + s.FillBytes(signature[keySize : 2*keySize]) + return signature, nil + } + return nil, fmt.Errorf("expected ECDSA private key for ES256") + default: + return nil, fmt.Errorf("unsupported algorithm: %s", alg) + } +} + func TestHash(t *testing.T) { for k, tc := range []struct { d string @@ -96,6 +181,17 @@ func TestGenerateJWT(t *testing.T) { key = gen.MustRSAKey() }, }, + { + d: "RS256JWTStrategy (Opaque Signer)", + strategy: &DefaultSigner{ + GetPrivateKey: func(_ context.Context) (interface{}, error) { + return newMockOpaqueSigner(key, jose.RS256) + }, + }, + resetKey: func(strategy Signer) { + key = gen.MustRSAKey() + }, + }, { d: "ES256JWTStrategy", strategy: &DefaultSigner{ @@ -112,7 +208,7 @@ func TestGenerateJWT(t *testing.T) { }, }, { - d: "ES256JWTStrategy", + d: "ES256JWTStrategy (Opaque Signer)", strategy: &DefaultSigner{ GetPrivateKey: func(_ context.Context) (interface{}, error) { return key, nil @@ -122,6 +218,17 @@ func TestGenerateJWT(t *testing.T) { key = gen.MustES256Key() }, }, + { + d: "ES256OpaqueSigner", + strategy: &DefaultSigner{ + GetPrivateKey: func(_ context.Context) (interface{}, error) { + return newMockOpaqueSigner(key, jose.ES256) + }, + }, + resetKey: func(strategy Signer) { + key = gen.MustES256Key() + }, + }, } { t.Run(fmt.Sprintf("case=%d/strategy=%s", k, tc.d), func(t *testing.T) { claims := &JWTClaims{