Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 48 additions & 1 deletion evidence.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ package psatoken

import (
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rand"
"crypto/rsa"
_ "crypto/sha256" // used hash algorithms need to be imported explicitly
"errors"
"fmt"
Expand Down Expand Up @@ -147,7 +150,17 @@ func (e *Evidence) Verify(pk crypto.PublicKey) error {

algo, err := protected.Algorithm()
if err != nil {
return fmt.Errorf("unable to get verification algorithm: %w", err)
// If algorithm is not found in protected headers, this might be a token
// created by compile_token or other tools that don't set the algorithm
// in protected headers. In this case, we should try to infer the algorithm
// from the public key type or use a default algorithm.

// Try to infer algorithm from public key type
inferredAlgo, inferErr := inferAlgorithmFromPublicKey(pk)
if inferErr != nil {
return fmt.Errorf("unable to get verification algorithm: %w (and failed to infer: %v)", err, inferErr)
}
algo = inferredAlgo
}

verifier, err := cose.NewVerifier(algo, pk)
Expand Down Expand Up @@ -185,6 +198,40 @@ func (e *Evidence) doSign(signer cose.Signer) ([]byte, error) {
return wrap, nil
}

// InferAlgorithmFromPublicKey attempts to infer the COSE algorithm from the public key type.
// This is used as a fallback when the algorithm is not present in the protected headers.
// This function is exported for testing purposes.
func InferAlgorithmFromPublicKey(pk crypto.PublicKey) (cose.Algorithm, error) {
return inferAlgorithmFromPublicKey(pk)
}

// inferAlgorithmFromPublicKey is the internal implementation of algorithm inference.
func inferAlgorithmFromPublicKey(pk crypto.PublicKey) (cose.Algorithm, error) {
switch key := pk.(type) {
case *ecdsa.PublicKey:
// For ECDSA keys, try to determine the curve and map to appropriate COSE algorithm
switch key.Curve.Params().BitSize {
case 256:
return cose.AlgorithmES256, nil
case 384:
return cose.AlgorithmES384, nil
case 521:
return cose.AlgorithmES512, nil
default:
return cose.AlgorithmES256, nil // Default to ES256 for unknown curves
}
case ed25519.PublicKey:
return cose.AlgorithmEdDSA, nil
case *rsa.PublicKey:
// Default to PS256 for RSA keys
return cose.AlgorithmPS256, nil
default:
// If we can't determine the key type, default to ES256 as it's most common
// for PSA tokens
return cose.AlgorithmES256, nil
}
}

// MarshalJSON encodes the PSA claims-set to JSON
func (e *Evidence) MarshalJSON() ([]byte, error) {
return EncodeClaimsToJSON(e.Claims)
Expand Down
51 changes: 51 additions & 0 deletions evidence_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,3 +339,54 @@ func TestEvidence_SignUnvalidated(t *testing.T) {
err = EvidenceOut.Verify(pk)
assert.EqualError(t, err, "signature verification failed: verification error")
}

// TestIssue18AlgorithmInference tests the fix for issue #18
// This test verifies that the algorithm inference functionality works correctly
// when the algorithm is not present in the protected headers (e.g., from compile_token)
func TestIssue18AlgorithmInference(t *testing.T) {
// Test the algorithm inference function with different key types
testCases := []struct {
name string
keyData string
expectedAlg string
}{
{
name: "ECDSA P-256 key",
keyData: testECKeyA,
expectedAlg: "ES256",
},
{
name: "TFM ECDSA key",
keyData: testTFMECKey,
expectedAlg: "ES256",
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
pk := pubKeyFromJWK(t, tc.keyData)

inferredAlg, err := InferAlgorithmFromPublicKey(pk)
assert.NoError(t, err, "Should be able to infer algorithm")
assert.Equal(t, tc.expectedAlg, inferredAlg.String(), "Should infer correct algorithm")
})
}

// Test that verification still works normally (regression test)
tokenSigner := signerFromJWK(t, testECKeyA)
claims := mustBuildValidP2Claims(t, false)

var evidence Evidence
err := evidence.SetClaims(claims)
require.NoError(t, err)

cwt, err := evidence.ValidateAndSign(tokenSigner)
require.NoError(t, err)

evidenceOut, err := DecodeAndValidateEvidenceFromCOSE(cwt)
require.NoError(t, err)

pk := pubKeyFromJWK(t, testECKeyA)
err = evidenceOut.Verify(pk)
assert.NoError(t, err, "Normal verification should still work after the fix")
}