Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions token/jwt/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,14 @@ func (j *DefaultSigner) Generate(ctx context.Context, claims MapClaims, header M
return generateToken(claims, header, jose.ES256, t)
case jose.OpaqueSigner:
switch tt := t.Public().Key.(type) {
case *rsa.PrivateKey:
case *rsa.PublicKey:
alg := jose.RS256
if len(t.Algs()) > 0 {
alg = t.Algs()[0]
}

return generateToken(claims, header, alg, t)
case *ecdsa.PrivateKey:
case *ecdsa.PublicKey:
alg := jose.ES256
if len(t.Algs()) > 0 {
alg = t.Algs()[0]
Expand Down
109 changes: 108 additions & 1 deletion token/jwt/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ package jwt

import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"fmt"
"strings"
"testing"
Expand All @@ -24,6 +29,86 @@ var header = &Headers{
},
}

type mockOpaqueSigner struct {
publicKey *jose.JSONWebKey
privateKey interface{}
signer jose.Signer
}

func newMockOpaqueSigner(key interface{}, alg jose.SignatureAlgorithm) (*mockOpaqueSigner, error) {
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: alg, Key: key}, nil)
if err != nil {
return nil, err
}

var publicKey *jose.JSONWebKey
switch k := key.(type) {
case *rsa.PrivateKey:
publicKey = &jose.JSONWebKey{
Key: &k.PublicKey,
Algorithm: string(alg),
}
case *ecdsa.PrivateKey:
publicKey = &jose.JSONWebKey{
Key: &k.PublicKey,
Algorithm: string(alg),
}
}

return &mockOpaqueSigner{
publicKey: publicKey,
privateKey: key,
signer: signer,
}, nil
}

func (m *mockOpaqueSigner) Public() *jose.JSONWebKey {
return m.publicKey
}

func (m *mockOpaqueSigner) Algs() []jose.SignatureAlgorithm {
if m.publicKey.Algorithm == "RS256" {
return []jose.SignatureAlgorithm{jose.RS256}
}

if m.publicKey.Algorithm == "ES256" {
return []jose.SignatureAlgorithm{jose.ES256}
}

return []jose.SignatureAlgorithm{}
}

func (m *mockOpaqueSigner) SignPayload(payload []byte, alg jose.SignatureAlgorithm) ([]byte, error) {
// Use the stored private key to sign directly
switch alg {
case jose.RS256:
if rsaPrivKey, ok := m.privateKey.(*rsa.PrivateKey); ok {
// Hash the payload first for RSA signing
hash := sha256.Sum256(payload)
return rsaPrivKey.Sign(rand.Reader, hash[:], crypto.SHA256)
}
return nil, fmt.Errorf("expected RSA private key for RS256")
case jose.ES256:
if ecdsaPrivKey, ok := m.privateKey.(*ecdsa.PrivateKey); ok {
hash := sha256.Sum256(payload)
r, s, err := ecdsa.Sign(rand.Reader, ecdsaPrivKey, hash[:])
if err != nil {
return nil, err
}

// Convert to JWT format: R || S (32 bytes each for P-256)
keySize := 32 // P-256 uses 32-byte values
signature := make([]byte, 2*keySize)
r.FillBytes(signature[0:keySize])
s.FillBytes(signature[keySize : 2*keySize])
return signature, nil
}
return nil, fmt.Errorf("expected ECDSA private key for ES256")
default:
return nil, fmt.Errorf("unsupported algorithm: %s", alg)
}
}

func TestHash(t *testing.T) {
for k, tc := range []struct {
d string
Expand Down Expand Up @@ -96,6 +181,17 @@ func TestGenerateJWT(t *testing.T) {
key = gen.MustRSAKey()
},
},
{
d: "RS256JWTStrategy (Opaque Signer)",
strategy: &DefaultSigner{
GetPrivateKey: func(_ context.Context) (interface{}, error) {
return newMockOpaqueSigner(key, jose.RS256)
},
},
resetKey: func(strategy Signer) {
key = gen.MustRSAKey()
},
},
{
d: "ES256JWTStrategy",
strategy: &DefaultSigner{
Expand All @@ -112,7 +208,7 @@ func TestGenerateJWT(t *testing.T) {
},
},
{
d: "ES256JWTStrategy",
d: "ES256JWTStrategy (Opaque Signer)",
strategy: &DefaultSigner{
GetPrivateKey: func(_ context.Context) (interface{}, error) {
return key, nil
Expand All @@ -122,6 +218,17 @@ func TestGenerateJWT(t *testing.T) {
key = gen.MustES256Key()
},
},
{
d: "ES256OpaqueSigner",
strategy: &DefaultSigner{
GetPrivateKey: func(_ context.Context) (interface{}, error) {
return newMockOpaqueSigner(key, jose.ES256)
},
},
resetKey: func(strategy Signer) {
key = gen.MustES256Key()
},
},
} {
t.Run(fmt.Sprintf("case=%d/strategy=%s", k, tc.d), func(t *testing.T) {
claims := &JWTClaims{
Expand Down