diff --git a/algorithm.go b/algorithm.go index ac3702c..2d15d22 100644 --- a/algorithm.go +++ b/algorithm.go @@ -5,7 +5,7 @@ import ( "strconv" ) -// Algorithms supported by this library. +// Signature algorithms supported by this library. // // When using an algorithm which requires hashing, // make sure the associated hash function is linked to the binary. @@ -42,12 +42,9 @@ const ( // PureEdDSA by RFC 8152. AlgorithmEdDSA Algorithm = -8 - - // Reserved value. - AlgorithmReserved Algorithm = 0 ) -// Algorithms known, but not supported by this library. +// Signature algorithms known, but not supported by this library. // // Signers and Verifiers requiring the algorithms below are not // directly supported by this library. They need to be provided @@ -66,6 +63,21 @@ const ( AlgorithmRS512 Algorithm = -259 ) +// Hash algorithms by RFC 9054. +const ( + // SHA-256 by RFC 9054. + AlgorithmSHA256 Algorithm = -16 + + // SHA-384 by RFC 9054. + AlgorithmSHA384 Algorithm = -43 + + // SHA-512 by RFC 9054. + AlgorithmSHA512 Algorithm = -44 +) + +// AlgorithmReserved represents a reserved algorithm value by RFC 9053. +const AlgorithmReserved Algorithm = 0 + // Algorithm represents an IANA algorithm entry in the COSE Algorithms registry. // // # See Also @@ -102,6 +114,12 @@ func (a Algorithm) String() string { return "EdDSA" case AlgorithmReserved: return "Reserved" + case AlgorithmSHA256: + return "SHA-256" + case AlgorithmSHA384: + return "SHA-384" + case AlgorithmSHA512: + return "SHA-512" default: return "Algorithm(" + strconv.FormatInt(int64(a), 10) + ")" } @@ -111,11 +129,11 @@ func (a Algorithm) String() string { // library. func (a Algorithm) hashFunc() crypto.Hash { switch a { - case AlgorithmPS256, AlgorithmES256: + case AlgorithmPS256, AlgorithmES256, AlgorithmSHA256: return crypto.SHA256 - case AlgorithmPS384, AlgorithmES384: + case AlgorithmPS384, AlgorithmES384, AlgorithmSHA384: return crypto.SHA384 - case AlgorithmPS512, AlgorithmES512: + case AlgorithmPS512, AlgorithmES512, AlgorithmSHA512: return crypto.SHA512 default: return 0 diff --git a/algorithm_test.go b/algorithm_test.go index e2985cf..ced0c20 100644 --- a/algorithm_test.go +++ b/algorithm_test.go @@ -26,6 +26,9 @@ func TestAlgorithm_String(t *testing.T) { {AlgorithmES512, "ES512"}, {AlgorithmEdDSA, "EdDSA"}, {AlgorithmReserved, "Reserved"}, + {AlgorithmSHA256, "SHA-256"}, + {AlgorithmSHA384, "SHA-384"}, + {AlgorithmSHA512, "SHA-512"}, {7, "Algorithm(7)"}, } for _, tt := range tests { @@ -37,6 +40,36 @@ func TestAlgorithm_String(t *testing.T) { } } +func TestAlgorithm_hashFunc(t *testing.T) { + tests := []struct { + alg Algorithm + want crypto.Hash + }{ + {AlgorithmPS256, crypto.SHA256}, + {AlgorithmPS384, crypto.SHA384}, + {AlgorithmPS512, crypto.SHA512}, + {AlgorithmRS256, 0}, // crypto.SHA256 but not supported as intended + {AlgorithmRS384, 0}, // crypto.SHA384 but not supported as intended + {AlgorithmRS512, 0}, // crypto.SHA512 but not supported as intended + {AlgorithmES256, crypto.SHA256}, + {AlgorithmES384, crypto.SHA384}, + {AlgorithmES512, crypto.SHA512}, + {AlgorithmEdDSA, 0}, + {AlgorithmReserved, 0}, + {AlgorithmSHA256, crypto.SHA256}, + {AlgorithmSHA384, crypto.SHA384}, + {AlgorithmSHA512, crypto.SHA512}, + {7, 0}, + } + for _, tt := range tests { + t.Run(tt.alg.String(), func(t *testing.T) { + if got := tt.alg.hashFunc(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("Algorithm.hashFunc() = %v, want %v", got, tt.want) + } + }) + } +} + func TestAlgorithm_computeHash(t *testing.T) { // run tests data := []byte("hello world") diff --git a/example_test.go b/example_test.go index c941b65..204ba2b 100644 --- a/example_test.go +++ b/example_test.go @@ -480,3 +480,66 @@ func ExampleCountersignature() { // signature countersignature verified // verification error as expected } + +// This example demonstrates signing and verifying COSE Hash Envelope. +// +// Reference: https://www.ietf.org/archive/id/draft-ietf-cose-hash-envelope-05.html +// +// Notice: The COSE Hash Envelope API is EXPERIMENTAL and may be changed or +// removed in a later release. +func Example_hashEnvelope() { + // create message to be signed + digested := sha512.Sum512([]byte("hello world")) + payload := cose.HashEnvelopePayload{ + HashAlgorithm: cose.AlgorithmSHA512, + HashValue: digested[:], + PreimageContentType: "text/plain", + Location: "urn:example:location", + } + + // create a signer + privateKey, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + if err != nil { + panic(err) + } + signer, err := cose.NewSigner(cose.AlgorithmES512, privateKey) + if err != nil { + panic(err) + } + + // sign message + sig, err := cose.SignHashEnvelope(rand.Reader, signer, cose.Headers{ + Protected: cose.ProtectedHeader{ + cose.HeaderLabelAlgorithm: cose.AlgorithmES512, + }, + }, payload) + if err != nil { + panic(err) + } + fmt.Println("message signed") + + // create a verifier from a trusted public key + publicKey := privateKey.Public() + verifier, err := cose.NewVerifier(cose.AlgorithmES512, publicKey) + if err != nil { + panic(err) + } + + // verify message + msg, err := cose.VerifyHashEnvelope(verifier, sig) + if err != nil { + panic(err) + } + fmt.Println("message verified") + + // check payload + fmt.Printf("payload hash: %v: %x\n", msg.Headers.Protected[cose.HeaderLabelPayloadHashAlgorithm], msg.Payload) + fmt.Println("payload content type:", msg.Headers.Protected[cose.HeaderLabelPayloadPreimageContentType]) + fmt.Println("payload location:", msg.Headers.Protected[cose.HeaderLabelPayloadLocation]) + // Output: + // message signed + // message verified + // payload hash: SHA-512: 309ecc489c12d6eb4cc40f50c902f2b4d0ed77ee511a7c7a9bcd3ca86d4cd86f989dd35bc5ff499670da34255b45b0cfd830e81f605dcf7dc5542e93ae9cd76f + // payload content type: text/plain + // payload location: urn:example:location +} diff --git a/hash_envelope.go b/hash_envelope.go new file mode 100644 index 0000000..040eab1 --- /dev/null +++ b/hash_envelope.go @@ -0,0 +1,207 @@ +package cose + +import ( + "errors" + "fmt" + "io" + "maps" +) + +// HashEnvelopePayload indicates the payload of a Hash_Envelope object. +// It is used by the [SignHashEnvelope] function. +// +// # Experimental +// +// Notice: The COSE Hash Envelope API is EXPERIMENTAL and may be changed or +// removed in a later release. +type HashEnvelopePayload struct { + // HashAlgorithm is the hash algorithm used to produce the hash value. + HashAlgorithm Algorithm + + // HashValue is the hash value of the payload. + HashValue []byte + + // PreimageContentType is the content type of the data that has been hashed. + // The value is either an unsigned integer (RFC 7252 Section 12.3) or a + // string (RFC 9110 Section 8.3). + // This field is optional. + // + // References: + // - https://www.iana.org/assignments/core-parameters/core-parameters.xhtml + // - https://www.iana.org/assignments/media-types/media-types.xhtml + PreimageContentType any // uint / string + + // Location is the location of the hash value in the payload. + // This field is optional. + Location string +} + +// SignHashEnvelope signs a [Sign1Message] using the provided [Signer] and +// produces a Hash_Envelope object. +// +// Hash_Envelope_Protected_Header = { +// ? &(alg: 1) => int, +// &(payload_hash_alg: 258) => int +// &(payload_preimage_content_type: 259) => uint / tstr +// ? &(payload_location: 260) => tstr +// * int / tstr => any +// } +// +// Hash_Envelope_Unprotected_Header = { +// * int / tstr => any +// } +// +// Hash_Envelope_as_COSE_Sign1 = [ +// protected : bstr .cbor Hash_Envelope_Protected_Header, +// unprotected : Hash_Envelope_Unprotected_Header, +// payload: bstr / nil, +// signature : bstr +// ] +// +// Hash_Envelope = #6.18(Hash_Envelope_as_COSE_Sign1) +// +// Reference: https://www.ietf.org/archive/id/draft-ietf-cose-hash-envelope-05.html +// +// # Experimental +// +// Notice: The COSE Hash Envelope API is EXPERIMENTAL and may be changed or +// removed in a later release. +func SignHashEnvelope(rand io.Reader, signer Signer, headers Headers, payload HashEnvelopePayload) ([]byte, error) { + if err := validateHash(payload.HashAlgorithm, payload.HashValue); err != nil { + return nil, err + } + + headers.Protected = setHashEnvelopeProtectedHeader(headers.Protected, &payload) + headers.RawProtected = nil + if err := validateHashEnvelopeHeaders(&headers); err != nil { + return nil, err + } + + return Sign1(rand, signer, headers, payload.HashValue, nil) +} + +// VerifyHashEnvelope verifies a Hash_Envelope object using the provided +// [Verifier]. +// It returns the decoded [Sign1Message] if the verification is successful. +// +// # Experimental +// +// Notice: The COSE Hash Envelope API is EXPERIMENTAL and may be changed or +// removed in a later release. +func VerifyHashEnvelope(verifier Verifier, envelope []byte) (*Sign1Message, error) { + // parse and validate the Hash_Envelope object + var message Sign1Message + if err := message.UnmarshalCBOR(envelope); err != nil { + return nil, err + } + if err := validateHashEnvelopeHeaders(&message.Headers); err != nil { + return nil, err + } + + // verify the Hash_Envelope object + if err := message.Verify(nil, verifier); err != nil { + return nil, err + } + + // cast to type Algorithm + hashAlgorithm, err := message.Headers.Protected.PayloadHashAlgorithm() + if err != nil { + return nil, err + } + message.Headers.Protected[HeaderLabelPayloadHashAlgorithm] = hashAlgorithm + + // validate the hash value + if err := validateHash(hashAlgorithm, message.Payload); err != nil { + return nil, err + } + + return &message, nil +} + +// validateHash checks the validity of the known hash. +func validateHash(alg Algorithm, value []byte) error { + hash := alg.hashFunc() + if hash == 0 { + return nil // no check on unsupported hash algorithms + } + if size := hash.Size(); size != len(value) { + return fmt.Errorf("%v: size mismatch: expected %d, got %d", alg, size, len(value)) + } + return nil +} + +// setHashEnvelopeProtectedHeader sets the protected header for a Hash_Envelope +// object. +func setHashEnvelopeProtectedHeader(base ProtectedHeader, payload *HashEnvelopePayload) ProtectedHeader { + header := maps.Clone(base) + if header == nil { + header = make(ProtectedHeader) + } + header[HeaderLabelPayloadHashAlgorithm] = payload.HashAlgorithm + if payload.PreimageContentType != nil { + header[HeaderLabelPayloadPreimageContentType] = payload.PreimageContentType + } + if payload.Location != "" { + header[HeaderLabelPayloadLocation] = payload.Location + } + return header +} + +// validateHashEnvelopeHeaders validates the headers of a Hash_Envelope object. +// See https://www.ietf.org/archive/id/draft-ietf-cose-hash-envelope-05.html +// section 4 for more details. +func validateHashEnvelopeHeaders(headers *Headers) error { + var foundPayloadHashAlgorithm bool + for label, value := range headers.Protected { + // Validate that all header labels are integers or strings. + // Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-1.4 + label, ok := normalizeLabel(label) + if !ok { + return errors.New("header label: require int / tstr type") + } + + switch label { + case HeaderLabelContentType: + return errors.New("protected header parameter: content type: not allowed") + case HeaderLabelPayloadHashAlgorithm: + _, isAlg := value.(Algorithm) + if !isAlg && !canInt(value) { + return errors.New("protected header parameter: payload hash alg: require int type") + } + foundPayloadHashAlgorithm = true + case HeaderLabelPayloadPreimageContentType: + if !canUint(value) && !canTstr(value) { + return errors.New("protected header parameter: payload preimage content type: require uint / tstr type") + } + case HeaderLabelPayloadLocation: + if !canTstr(value) { + return errors.New("protected header parameter: payload location: require tstr type") + } + } + } + if !foundPayloadHashAlgorithm { + return errors.New("protected header parameter: payload hash alg: required") + } + + for label := range headers.Unprotected { + // Validate that all header labels are integers or strings. + // Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-1.4 + label, ok := normalizeLabel(label) + if !ok { + return errors.New("header label: require int / tstr type") + } + + switch label { + case HeaderLabelContentType: + return errors.New("unprotected header parameter: content type: not allowed") + case HeaderLabelPayloadHashAlgorithm: + return errors.New("unprotected header parameter: payload hash alg: not allowed") + case HeaderLabelPayloadPreimageContentType: + return errors.New("unprotected header parameter: payload preimage content type: not allowed") + case HeaderLabelPayloadLocation: + return errors.New("unprotected header parameter: payload location: not allowed") + } + } + + return nil +} diff --git a/hash_envelope_test.go b/hash_envelope_test.go new file mode 100644 index 0000000..cd5bf20 --- /dev/null +++ b/hash_envelope_test.go @@ -0,0 +1,422 @@ +package cose + +import ( + "bytes" + "crypto/rand" + "crypto/sha256" + "maps" + "reflect" + "testing" +) + +func TestSignHashEnvelope(t *testing.T) { + // generate key and set up signer / verifier + alg := AlgorithmES256 + key := generateTestECDSAKey(t) + signer, err := NewSigner(alg, key) + if err != nil { + t.Fatalf("NewSigner() error = %v", err) + } + verifier, err := NewVerifier(alg, key.Public()) + if err != nil { + t.Fatalf("NewVerifier() error = %v", err) + } + payload := []byte("hello world") + payloadAlg := AlgorithmSHA256 + payloadSHA256 := sha256.Sum256(payload) + payloadHash := payloadSHA256[:] + + tests := []struct { + name string + headers Headers + payload HashEnvelopePayload + wantHeaders Headers + wantErr string + }{ + { + name: "minimal signing", + payload: HashEnvelopePayload{ + HashAlgorithm: payloadAlg, + HashValue: payloadHash, + }, + wantHeaders: Headers{ + Protected: ProtectedHeader{ + HeaderLabelAlgorithm: alg, + HeaderLabelPayloadHashAlgorithm: payloadAlg, + }, + }, + }, + { + name: "with preimage content type (int)", + payload: HashEnvelopePayload{ + HashAlgorithm: payloadAlg, + HashValue: payloadHash, + PreimageContentType: 0, + }, + wantHeaders: Headers{ + Protected: ProtectedHeader{ + HeaderLabelAlgorithm: alg, + HeaderLabelPayloadHashAlgorithm: payloadAlg, + HeaderLabelPayloadPreimageContentType: int64(0), + }, + }, + }, + { + name: "with preimage content type (tstr)", + payload: HashEnvelopePayload{ + HashAlgorithm: payloadAlg, + HashValue: payloadHash, + PreimageContentType: "text/plain", + }, + wantHeaders: Headers{ + Protected: ProtectedHeader{ + HeaderLabelAlgorithm: alg, + HeaderLabelPayloadHashAlgorithm: payloadAlg, + HeaderLabelPayloadPreimageContentType: "text/plain", + }, + }, + }, + { + name: "with payload location", + payload: HashEnvelopePayload{ + HashAlgorithm: payloadAlg, + HashValue: payloadHash, + Location: "urn:example:location", + }, + wantHeaders: Headers{ + Protected: ProtectedHeader{ + HeaderLabelAlgorithm: alg, + HeaderLabelPayloadHashAlgorithm: payloadAlg, + HeaderLabelPayloadLocation: "urn:example:location", + }, + }, + }, + { + name: "full signing with base headers", + headers: Headers{ + Protected: ProtectedHeader{ + HeaderLabelAlgorithm: alg, + }, + Unprotected: UnprotectedHeader{ + HeaderLabelKeyID: []byte("42"), + }, + }, + payload: HashEnvelopePayload{ + HashAlgorithm: payloadAlg, + HashValue: payloadHash, + PreimageContentType: "text/plain", + Location: "urn:example:location", + }, + wantHeaders: Headers{ + Protected: ProtectedHeader{ + HeaderLabelAlgorithm: alg, + HeaderLabelPayloadHashAlgorithm: payloadAlg, + HeaderLabelPayloadPreimageContentType: "text/plain", + HeaderLabelPayloadLocation: "urn:example:location", + }, + Unprotected: UnprotectedHeader{ + HeaderLabelKeyID: []byte("42"), + }, + }, + }, + { + name: "unsupported hash algorithm", + payload: HashEnvelopePayload{ + HashAlgorithm: Algorithm(-15), // SHA-256/64 + HashValue: payloadHash, + }, + wantHeaders: Headers{ + Protected: ProtectedHeader{ + HeaderLabelAlgorithm: alg, + HeaderLabelPayloadHashAlgorithm: Algorithm(-15), // SHA-256/64 + }, + }, + }, + { + name: "bad hash value", + payload: HashEnvelopePayload{ + HashAlgorithm: payloadAlg, + }, + wantErr: "SHA-256: size mismatch: expected 32, got 0", + }, + { + name: "invalid preimage content type", + payload: HashEnvelopePayload{ + HashAlgorithm: payloadAlg, + HashValue: payloadHash, + PreimageContentType: -1, + }, + wantErr: "protected header parameter: payload preimage content type: require uint / tstr type", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := SignHashEnvelope(rand.Reader, signer, tt.headers, tt.payload) + if err != nil { + if tt.wantErr == "" || err.Error() != tt.wantErr { + t.Fatalf("SignHashEnvelope() error = %v, wantErr %s", err, tt.wantErr) + } + return + } + if tt.wantErr != "" { + t.Fatalf("SignHashEnvelope() error = %v, wantErr %s", err, tt.wantErr) + } + msg, err := VerifyHashEnvelope(verifier, got) + if err != nil { + t.Fatalf("VerifyHashEnvelope() error = %v", err) + } + if !maps.EqualFunc(msg.Headers.Protected, tt.wantHeaders.Protected, reflect.DeepEqual) { + t.Errorf("SignHashEnvelope() Protected Header = %v, want %v", msg.Headers.Protected, tt.wantHeaders.Protected) + } + if !maps.EqualFunc(msg.Headers.Unprotected, tt.wantHeaders.Unprotected, reflect.DeepEqual) { + t.Errorf("SignHashEnvelope() Unprotected Header = %v, want %v", msg.Headers.Unprotected, tt.wantHeaders.Unprotected) + } + if !bytes.Equal(msg.Payload, tt.payload.HashValue) { + t.Errorf("SignHashEnvelope() Payload = %v, want %v", msg.Payload, tt.payload.HashValue) + } + }) + } +} + +func TestVerifyHashEnvelope(t *testing.T) { + // generate key and set up signer / verifier + alg := AlgorithmES256 + key := generateTestECDSAKey(t) + signer, err := NewSigner(alg, key) + if err != nil { + t.Fatalf("NewSigner() error = %v", err) + } + verifier, err := NewVerifier(alg, key.Public()) + if err != nil { + t.Fatalf("NewVerifier() error = %v", err) + } + payload := []byte("hello world") + payloadAlg := AlgorithmSHA256 + payloadSHA256 := sha256.Sum256(payload) + payloadHash := payloadSHA256[:] + + tests := []struct { + name string + envelope []byte + message *Sign1Message + wantErr string + }{ + { + name: "valid envelope", + message: &Sign1Message{ + Headers: Headers{ + Protected: ProtectedHeader{ + HeaderLabelAlgorithm: alg, + HeaderLabelPayloadHashAlgorithm: payloadAlg, + }, + }, + Payload: payloadHash, + }, + }, + { + name: "nil envelope", + wantErr: "cbor: invalid COSE_Sign1_Tagged object", + }, + { + name: "empty envelope", + envelope: []byte{}, + wantErr: "cbor: invalid COSE_Sign1_Tagged object", + }, + { + name: "not a Hash_Envelope object", + message: &Sign1Message{ + Headers: Headers{ + Protected: ProtectedHeader{ + HeaderLabelAlgorithm: alg, + }, + }, + Payload: payloadHash, + }, + wantErr: "protected header parameter: payload hash alg: required", + }, + + { + name: "payload hash algorithm in the unprotected header", + message: &Sign1Message{ + Headers: Headers{ + Protected: ProtectedHeader{ + HeaderLabelPayloadHashAlgorithm: payloadAlg, + HeaderLabelAlgorithm: alg, + }, + Unprotected: UnprotectedHeader{ + HeaderLabelPayloadHashAlgorithm: payloadAlg, + }, + }, + Payload: payloadHash, + }, + wantErr: "unprotected header parameter: payload hash alg: not allowed", + }, + { + name: "invalid payload hash algorithm", + message: &Sign1Message{ + Headers: Headers{ + Protected: ProtectedHeader{ + HeaderLabelAlgorithm: alg, + HeaderLabelPayloadHashAlgorithm: "SHA-256", + }, + }, + Payload: payloadHash, + }, + wantErr: "protected header parameter: payload hash alg: require int type", + }, + { + name: "invalid preimage content type in the protected header", + message: &Sign1Message{ + Headers: Headers{ + Protected: ProtectedHeader{ + HeaderLabelAlgorithm: alg, + HeaderLabelPayloadHashAlgorithm: payloadAlg, + HeaderLabelPayloadPreimageContentType: -1, + }, + }, + Payload: payloadHash, + }, + wantErr: "protected header parameter: payload preimage content type: require uint / tstr type", + }, + { + name: "preimage content type present in the unprotected header", + message: &Sign1Message{ + Headers: Headers{ + Protected: ProtectedHeader{ + HeaderLabelAlgorithm: alg, + HeaderLabelPayloadHashAlgorithm: payloadAlg, + }, + Unprotected: UnprotectedHeader{ + HeaderLabelPayloadPreimageContentType: "text/plain", + }, + }, + Payload: payloadHash, + }, + wantErr: "unprotected header parameter: payload preimage content type: not allowed", + }, + { + name: "payload location present in the unprotected header", + message: &Sign1Message{ + Headers: Headers{ + Protected: ProtectedHeader{ + HeaderLabelAlgorithm: alg, + HeaderLabelPayloadHashAlgorithm: payloadAlg, + }, + Unprotected: UnprotectedHeader{ + HeaderLabelPayloadLocation: "urn:example:location", + }, + }, + Payload: payloadHash, + }, + wantErr: "unprotected header parameter: payload location: not allowed", + }, + { + name: "invalid payload location in the protected header", + message: &Sign1Message{ + Headers: Headers{ + Protected: ProtectedHeader{ + HeaderLabelAlgorithm: alg, + HeaderLabelPayloadHashAlgorithm: payloadAlg, + HeaderLabelPayloadLocation: 0, + }, + }, + Payload: payloadHash, + }, + wantErr: "protected header parameter: payload location: require tstr type", + }, + { + name: "content type present in the protected header", + message: &Sign1Message{ + Headers: Headers{ + Protected: ProtectedHeader{ + HeaderLabelAlgorithm: alg, + HeaderLabelContentType: "text/plain", + HeaderLabelPayloadHashAlgorithm: payloadAlg, + }, + }, + Payload: payloadHash, + }, + wantErr: "protected header parameter: content type: not allowed", + }, + { + name: "content type present in the unprotected header", + message: &Sign1Message{ + Headers: Headers{ + Protected: ProtectedHeader{ + HeaderLabelAlgorithm: alg, + HeaderLabelPayloadHashAlgorithm: payloadAlg, + }, + Unprotected: UnprotectedHeader{ + HeaderLabelContentType: "text/plain", + }, + }, + Payload: payloadHash, + }, + wantErr: "unprotected header parameter: content type: not allowed", + }, + { + name: "bad signature", + message: &Sign1Message{ + Headers: Headers{ + Protected: ProtectedHeader{ + HeaderLabelAlgorithm: alg, + HeaderLabelPayloadHashAlgorithm: payloadAlg, + }, + Unprotected: UnprotectedHeader{ + HeaderLabelKeyID: []byte("42"), + }, + }, + Payload: payloadHash, + Signature: []byte("bad signature"), + }, + wantErr: "verification error", + }, + { + name: "bad hash value", + message: &Sign1Message{ + Headers: Headers{ + Protected: ProtectedHeader{ + HeaderLabelAlgorithm: alg, + HeaderLabelPayloadHashAlgorithm: payloadAlg, + }, + Unprotected: UnprotectedHeader{ + HeaderLabelKeyID: []byte("42"), + }, + }, + Payload: []byte("bad hash value"), + }, + wantErr: "SHA-256: size mismatch: expected 32, got 14", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + envelope := tt.envelope + if tt.message != nil { + message := tt.message + if message.Signature == nil { + if err := message.Sign(rand.Reader, nil, signer); err != nil { + t.Fatalf("Sign1Message.Sign() error = %v", err) + } + } + var err error + envelope, err = message.MarshalCBOR() + if err != nil { + t.Fatalf("Sign1Message.MarshalCBOR() error = %v", err) + } + } + msg, err := VerifyHashEnvelope(verifier, envelope) + if err != nil { + if tt.wantErr == "" || err.Error() != tt.wantErr { + t.Fatalf("VerifyHashEnvelope() error = %v, wantErr %s", err, tt.wantErr) + } + return + } + if tt.wantErr != "" { + t.Fatalf("VerifyHashEnvelope() error = %v, wantErr %s", err, tt.wantErr) + } + if msg == nil { + t.Fatalf("VerifyHashEnvelope() message = nil, want not nil") + } + }) + } +} diff --git a/headers.go b/headers.go index c4e980f..acffbe2 100644 --- a/headers.go +++ b/headers.go @@ -31,6 +31,18 @@ const ( HeaderLabelX5U int64 = 35 ) +// Temporary COSE Header labels registered in the IANA "COSE Header Parameters" +// registry. +// These labels are not intended to be used in production code and are subject +// to change without notice. +// +// Reference: https://www.iana.org/assignments/cose/cose.xhtml#header-parameters +const ( + HeaderLabelPayloadHashAlgorithm int64 = 258 // registered 2025-03-05, expires 2026-03-05 + HeaderLabelPayloadPreimageContentType int64 = 259 // registered 2025-03-05, expires 2026-03-05 + HeaderLabelPayloadLocation int64 = 260 // registered 2025-03-05, expires 2026-03-05 +) + // ProtectedHeader contains parameters that are to be cryptographically // protected. type ProtectedHeader map[any]any @@ -132,7 +144,7 @@ func (h ProtectedHeader) SetCWTClaims(claims CWTClaims) (CWTClaims, error) { func (h ProtectedHeader) Algorithm() (Algorithm, error) { value, ok := h[HeaderLabelAlgorithm] if !ok { - return 0, ErrAlgorithmNotFound + return AlgorithmReserved, ErrAlgorithmNotFound } switch alg := value.(type) { case Algorithm: @@ -148,7 +160,37 @@ func (h ProtectedHeader) Algorithm() (Algorithm, error) { case int64: return Algorithm(alg), nil case string: - return AlgorithmReserved, fmt.Errorf("Algorithm(%q)", alg) + return AlgorithmReserved, fmt.Errorf("Algorithm(%q): %w", alg, ErrAlgorithmNotSupported) + default: + return AlgorithmReserved, ErrInvalidAlgorithm + } +} + +// PayloadHashAlgorithm gets the payload hash algorithm value from the protected +// header. +// +// # Experimental +// +// Notice: The COSE Hash Envelope API is EXPERIMENTAL and may be changed or +// removed in a later release. +func (h ProtectedHeader) PayloadHashAlgorithm() (Algorithm, error) { + value, ok := h[HeaderLabelPayloadHashAlgorithm] + if !ok { + return AlgorithmReserved, ErrAlgorithmNotFound + } + switch alg := value.(type) { + case Algorithm: + return alg, nil + case int: + return Algorithm(alg), nil + case int8: + return Algorithm(alg), nil + case int16: + return Algorithm(alg), nil + case int32: + return Algorithm(alg), nil + case int64: + return Algorithm(alg), nil default: return AlgorithmReserved, ErrInvalidAlgorithm } diff --git a/headers_test.go b/headers_test.go index efb11f9..b9a09dc 100644 --- a/headers_test.go +++ b/headers_test.go @@ -524,7 +524,7 @@ func TestProtectedHeader_Algorithm(t *testing.T) { h: ProtectedHeader{ HeaderLabelAlgorithm: "foo", }, - wantErr: errors.New("Algorithm(\"foo\")"), + wantErr: errors.New(`Algorithm("foo"): algorithm not supported`), }, { name: "invalid algorithm", @@ -548,6 +548,101 @@ func TestProtectedHeader_Algorithm(t *testing.T) { } } +func TestProtectedHeader_PayloadHashAlgorithm(t *testing.T) { + tests := []struct { + name string + h ProtectedHeader + want Algorithm + wantErr error + }{ + { + name: "algorithm", + h: ProtectedHeader{ + HeaderLabelPayloadHashAlgorithm: AlgorithmES256, + }, + want: AlgorithmES256, + }, + { + name: "int", + h: ProtectedHeader{ + HeaderLabelPayloadHashAlgorithm: int(AlgorithmES256), + }, + want: AlgorithmES256, + }, + { + name: "int8", + h: ProtectedHeader{ + HeaderLabelPayloadHashAlgorithm: int8(AlgorithmES256), + }, + want: AlgorithmES256, + }, + { + name: "int16", + h: ProtectedHeader{ + HeaderLabelPayloadHashAlgorithm: int16(AlgorithmES256), + }, + want: AlgorithmES256, + }, + { + name: "int32", + h: ProtectedHeader{ + HeaderLabelPayloadHashAlgorithm: int32(AlgorithmES256), + }, + want: AlgorithmES256, + }, + { + name: "int64", + h: ProtectedHeader{ + HeaderLabelPayloadHashAlgorithm: int64(AlgorithmES256), + }, + want: AlgorithmES256, + }, + { + name: "nil header", + h: nil, + wantErr: ErrAlgorithmNotFound, + }, + { + name: "empty header", + h: ProtectedHeader{}, + wantErr: ErrAlgorithmNotFound, + }, + { + name: "missing algorithm header", + h: ProtectedHeader{ + "foo": "bar", + }, + wantErr: ErrAlgorithmNotFound, + }, + { + name: "algorithm in string type is not allowed", + h: ProtectedHeader{ + HeaderLabelPayloadHashAlgorithm: "foo", + }, + wantErr: ErrInvalidAlgorithm, + }, + { + name: "invalid algorithm", + h: ProtectedHeader{ + HeaderLabelPayloadHashAlgorithm: 2.5, + }, + wantErr: ErrInvalidAlgorithm, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.h.PayloadHashAlgorithm() + if tt.wantErr != nil && err.Error() != tt.wantErr.Error() { + t.Errorf("ProtectedHeader.PayloadHashAlgorithm() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("ProtectedHeader.PayloadHashAlgorithm() = %v, want %v", got, tt.want) + } + }) + } +} + func TestProtectedHeader_Critical(t *testing.T) { tests := []struct { name string