diff --git a/evidence.go b/evidence.go index 9f68a6c..cc5d3a5 100644 --- a/evidence.go +++ b/evidence.go @@ -5,7 +5,10 @@ package psatoken import ( "crypto" + "crypto/ecdsa" + "crypto/ed25519" "crypto/rand" + "crypto/rsa" _ "crypto/sha256" // used hash algorithms need to be imported explicitly "errors" "fmt" @@ -147,7 +150,17 @@ func (e *Evidence) Verify(pk crypto.PublicKey) error { algo, err := protected.Algorithm() if err != nil { - return fmt.Errorf("unable to get verification algorithm: %w", err) + // If algorithm is not found in protected headers, this might be a token + // created by compile_token or other tools that don't set the algorithm + // in protected headers. In this case, we should try to infer the algorithm + // from the public key type or use a default algorithm. + + // Try to infer algorithm from public key type + inferredAlgo, inferErr := inferAlgorithmFromPublicKey(pk) + if inferErr != nil { + return fmt.Errorf("unable to get verification algorithm: %w (and failed to infer: %v)", err, inferErr) + } + algo = inferredAlgo } verifier, err := cose.NewVerifier(algo, pk) @@ -185,6 +198,40 @@ func (e *Evidence) doSign(signer cose.Signer) ([]byte, error) { return wrap, nil } +// InferAlgorithmFromPublicKey attempts to infer the COSE algorithm from the public key type. +// This is used as a fallback when the algorithm is not present in the protected headers. +// This function is exported for testing purposes. +func InferAlgorithmFromPublicKey(pk crypto.PublicKey) (cose.Algorithm, error) { + return inferAlgorithmFromPublicKey(pk) +} + +// inferAlgorithmFromPublicKey is the internal implementation of algorithm inference. +func inferAlgorithmFromPublicKey(pk crypto.PublicKey) (cose.Algorithm, error) { + switch key := pk.(type) { + case *ecdsa.PublicKey: + // For ECDSA keys, try to determine the curve and map to appropriate COSE algorithm + switch key.Curve.Params().BitSize { + case 256: + return cose.AlgorithmES256, nil + case 384: + return cose.AlgorithmES384, nil + case 521: + return cose.AlgorithmES512, nil + default: + return cose.AlgorithmES256, nil // Default to ES256 for unknown curves + } + case ed25519.PublicKey: + return cose.AlgorithmEdDSA, nil + case *rsa.PublicKey: + // Default to PS256 for RSA keys + return cose.AlgorithmPS256, nil + default: + // If we can't determine the key type, default to ES256 as it's most common + // for PSA tokens + return cose.AlgorithmES256, nil + } +} + // MarshalJSON encodes the PSA claims-set to JSON func (e *Evidence) MarshalJSON() ([]byte, error) { return EncodeClaimsToJSON(e.Claims) diff --git a/evidence_test.go b/evidence_test.go index 68c9a1b..74a9cfc 100644 --- a/evidence_test.go +++ b/evidence_test.go @@ -339,3 +339,54 @@ func TestEvidence_SignUnvalidated(t *testing.T) { err = EvidenceOut.Verify(pk) assert.EqualError(t, err, "signature verification failed: verification error") } + +// TestIssue18AlgorithmInference tests the fix for issue #18 +// This test verifies that the algorithm inference functionality works correctly +// when the algorithm is not present in the protected headers (e.g., from compile_token) +func TestIssue18AlgorithmInference(t *testing.T) { + // Test the algorithm inference function with different key types + testCases := []struct { + name string + keyData string + expectedAlg string + }{ + { + name: "ECDSA P-256 key", + keyData: testECKeyA, + expectedAlg: "ES256", + }, + { + name: "TFM ECDSA key", + keyData: testTFMECKey, + expectedAlg: "ES256", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + pk := pubKeyFromJWK(t, tc.keyData) + + inferredAlg, err := InferAlgorithmFromPublicKey(pk) + assert.NoError(t, err, "Should be able to infer algorithm") + assert.Equal(t, tc.expectedAlg, inferredAlg.String(), "Should infer correct algorithm") + }) + } + + // Test that verification still works normally (regression test) + tokenSigner := signerFromJWK(t, testECKeyA) + claims := mustBuildValidP2Claims(t, false) + + var evidence Evidence + err := evidence.SetClaims(claims) + require.NoError(t, err) + + cwt, err := evidence.ValidateAndSign(tokenSigner) + require.NoError(t, err) + + evidenceOut, err := DecodeAndValidateEvidenceFromCOSE(cwt) + require.NoError(t, err) + + pk := pubKeyFromJWK(t, testECKeyA) + err = evidenceOut.Verify(pk) + assert.NoError(t, err, "Normal verification should still work after the fix") +}