From c88a9da99c7be283a90a94200557954045a8bd87 Mon Sep 17 00:00:00 2001 From: Pranjal Kole Date: Tue, 21 Jan 2025 02:29:38 +0530 Subject: [PATCH 1/2] headers: validate CWT claims Signed-off-by: Pranjal Kole --- cwt_test.go | 6 ++++- headers.go | 71 +++++++++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 68 insertions(+), 9 deletions(-) diff --git a/cwt_test.go b/cwt_test.go index 3ee506f..2af4509 100644 --- a/cwt_test.go +++ b/cwt_test.go @@ -21,7 +21,11 @@ func ExampleCWTClaims() { cose.CWTClaimIssuer: "issuer.example", cose.CWTClaimSubject: "subject.example", } - msgToSign.Headers.Protected.SetCWTClaims(claims) + + claims, err := msgToSign.Headers.Protected.SetCWTClaims(claims) + if err != nil { + panic(err) + } msgToSign.Headers.Unprotected[cose.HeaderLabelKeyID] = []byte("1") diff --git a/headers.go b/headers.go index c4e980f..12454b5 100644 --- a/headers.go +++ b/headers.go @@ -115,15 +115,55 @@ func (h ProtectedHeader) SetType(typ any) (any, error) { // SetCWTClaims sets the CWT Claims value of the protected header. func (h ProtectedHeader) SetCWTClaims(claims CWTClaims) (CWTClaims, error) { - iss, hasIss := claims[1] - if hasIss && !canTstr(iss) { - return claims, errors.New("cwt claim: iss: require tstr") - } - sub, hasSub := claims[2] - if hasSub && !canTstr(sub) { - return claims, errors.New("cwt claim: sub: require tstr") + for name, _ := range claims { + switch name { + case 1: + iss, hasIss := claims[name] + if hasIss && !canTstr(iss) { + return claims, errors.New("cwt claim: iss: require tstr") + } + case 2: + sub, hasSub := claims[name] + if hasSub && !canTstr(sub) { + return claims, errors.New("cwt claim: sub: require tstr") + } + case 3: + aud, hasAud := claims[name] + if hasAud && !canTstr(aud) { + return claims, errors.New("cwt claim: aud: require tstr") + } + case 4: + exp, hasExp := claims[name] + if hasExp && !canInt(exp) && !canFloat(exp) { + return claims, errors.New("cwt claim: exp: require int or float") + } + case 5: + nbf, hasNbf := claims[name] + if hasNbf && !canInt(nbf) && !canFloat(nbf) { + return claims, errors.New("cwt claim: nbf: require int or float") + } + case 6: + iat, hasIat := claims[name] + if hasIat && !canInt(iat) && !canFloat(iat) { + return claims, errors.New("cwt claim: iat: require int or float") + } + case 7: + cti, hasCti := claims[name] + if hasCti && !canBstr(cti) { + return claims, errors.New("cwt claim: cti: require tstr") + } + case 8: + cnf, hasCnf := claims[name] + if hasCnf && !canMap(cnf) { + return claims, errors.New("cwt claim: cnf: require map") + } + case 9: + scope, hasScope := claims[name] + if hasScope && !canBstr(scope) && !canTstr(scope) { + return claims, errors.New("cwt claim: scope: require bstr or tstr") + } + } } - // TODO: validate claims, other claims h[HeaderLabelCWTClaims] = claims return claims, nil } @@ -620,6 +660,15 @@ func canInt(v any) bool { return false } +// canFloat reports whether v can be used as a CBOR float type +func canFloat(v any) bool { + switch v.(type) { + case float32, float64: + return true + } + return false +} + // canTstr reports whether v can be used as a CBOR tstr type. func canTstr(v any) bool { _, ok := v.(string) @@ -632,6 +681,12 @@ func canBstr(v any) bool { return ok } +// canMap reports whether v can be used as a CBOR map type. +func canMap(v any) bool { + _, ok := v.(map[any]any) + return ok +} + // normalizeLabel tries to cast label into a int64 or a string. // Returns (nil, false) if the label type is not valid. func normalizeLabel(label any) (any, bool) { From 2710c08f2ad5934188d259fa2f2ffbb4f93a4d6c Mon Sep 17 00:00:00 2001 From: Steve Lasker Date: Fri, 7 Nov 2025 08:05:07 +0200 Subject: [PATCH 2/2] Update headers.go Co-authored-by: Shiwei Zhang Signed-off-by: Steve Lasker --- headers.go | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/headers.go b/headers.go index 12454b5..03c5a7e 100644 --- a/headers.go +++ b/headers.go @@ -115,16 +115,14 @@ func (h ProtectedHeader) SetType(typ any) (any, error) { // SetCWTClaims sets the CWT Claims value of the protected header. func (h ProtectedHeader) SetCWTClaims(claims CWTClaims) (CWTClaims, error) { - for name, _ := range claims { + for name, value := range claims { switch name { - case 1: - iss, hasIss := claims[name] - if hasIss && !canTstr(iss) { + case CWTClaimIssuer: + if !canTstr(value) { return claims, errors.New("cwt claim: iss: require tstr") } - case 2: - sub, hasSub := claims[name] - if hasSub && !canTstr(sub) { + case CWTClaimSubject: + if !canTstr(value) { return claims, errors.New("cwt claim: sub: require tstr") } case 3: