diff --git a/go.mod b/go.mod index 6ff1dab..b06ad71 100644 --- a/go.mod +++ b/go.mod @@ -19,7 +19,7 @@ require ( github.com/uptrace/bun/dialect/sqlitedialect v1.2.16 github.com/uptrace/bun/driver/sqliteshim v1.2.15 github.com/uptrace/bun/extra/bundebug v1.2.15 - github.com/veraison/corim v1.1.3-0.20251209103150-ef6dbb7ed63f + github.com/veraison/corim v1.1.3-0.20251212144809-26e0f2a5f59d github.com/veraison/eat v0.0.0-20210331113810-3da8a4dd42ff github.com/veraison/swid v1.1.1-0.20251003121634-fd1f7f1e1897 ) diff --git a/go.sum b/go.sum index 29218b0..522b1f8 100644 --- a/go.sum +++ b/go.sum @@ -129,8 +129,8 @@ github.com/uptrace/bun/driver/sqliteshim v1.2.15 h1:M/rZJSjOPV4OmfTVnDPtL+wJmdMT github.com/uptrace/bun/driver/sqliteshim v1.2.15/go.mod h1:YqwxFyvM992XOCpGJtXyKPkgkb+aZpIIMzGbpaw1hIk= github.com/uptrace/bun/extra/bundebug v1.2.15 h1:IY2Z/pVyVg0ApWnQ/pEnwe6BWxlDDATCz7IFZghutCs= github.com/uptrace/bun/extra/bundebug v1.2.15/go.mod h1:JuE+BT7NjTZ9UKr74eC8s9yZ9dnQCeufDwFRTC8w3Xo= -github.com/veraison/corim v1.1.3-0.20251209103150-ef6dbb7ed63f h1:ANVwskQLZ0YEzivFZqreGIftxakfd579fpOcjU8rHjo= -github.com/veraison/corim v1.1.3-0.20251209103150-ef6dbb7ed63f/go.mod h1:96PQ0lk+O9bzutKTDz66G2DaARYUp1BeR06EYwEwSH0= +github.com/veraison/corim v1.1.3-0.20251212144809-26e0f2a5f59d h1:ifSo+6zUb8I7yOBF1MgZflJwXmCYiZmiemd74VDcJW0= +github.com/veraison/corim v1.1.3-0.20251212144809-26e0f2a5f59d/go.mod h1:96PQ0lk+O9bzutKTDz66G2DaARYUp1BeR06EYwEwSH0= github.com/veraison/eat v0.0.0-20210331113810-3da8a4dd42ff h1:r6I2eJL/z8dp5flsQIKHMeDjyV6UO8If3MaVBLvTjF4= github.com/veraison/eat v0.0.0-20210331113810-3da8a4dd42ff/go.mod h1:+kxt8iuFiVvKRs2VQ1Ho7bbAScXAB/kHFFuP5Biw19I= github.com/veraison/go-cose v1.2.1 h1:Gj4x20D0YP79J2+cK3anjGEMwIkg2xX+TKVVGUXwNAc= diff --git a/pkg/model/extension.go b/pkg/model/extension.go index e5b66d5..5e1fa26 100644 --- a/pkg/model/extension.go +++ b/pkg/model/extension.go @@ -17,7 +17,7 @@ func CoRIMExtensionsFromCoRIM(origin corim.Extensions) ([]*ExtensionValue, error } func CoMIDExtensionsFromCoRIM(origin comid.Extensions) ([]*ExtensionValue, error) { - var ret []*ExtensionValue + var ret []*ExtensionValue // nolint: prealloc if origin.IsEmpty() { return ret, nil } @@ -78,6 +78,21 @@ func CoMIDExtensionsFromCoRIM(origin comid.Extensions) ([]*ExtensionValue, error ret = append(ret, &retVal) } + for k, v := range origin.Cached { + bytes, err := cbor.Marshal(v) + if err != nil { + return nil, fmt.Errorf("error CBOR encoding cached extension %s: %w", k, err) + } + + retVal := ExtensionValue{ + FieldName: "", // empty field name indicates cached value + JSONTag: k, + ValueBytes: bytes, + } + + ret = append(ret, &retVal) + } + return ret, nil } @@ -91,12 +106,24 @@ func CoMIDExtensionsToCoRIM(origin []*ExtensionValue) (comid.Extensions, error) return comid.Extensions{}, nil } - values := make([]any, 0, len(origin)) + values := make(map[string]any, len(origin)) fields := make([]reflect.StructField, 0, len(origin)) + cached := make(map[string]any, len(origin)) for _, origVal := range origin { var val any + if origVal.FieldName == "" { + // empty field name means this is a cached value + if err := cbor.Unmarshal(origVal.ValueBytes, &val); err != nil { + return comid.Extensions{}, fmt.Errorf( + "error decoding CBOR for %s: %w", origVal.JSONTag, err) + } + + cached[origVal.JSONTag] = val + continue + } + switch origVal.FieldKind { case reflect.String: val = origVal.ValueText @@ -147,17 +174,21 @@ func CoMIDExtensionsToCoRIM(origin []*ExtensionValue) (comid.Extensions, error) origVal.CBORTag, origVal.JSONTag)), }) - values = append(values, val) + values[origVal.FieldName] = val } structType := reflect.StructOf(fields) structPtr := reflect.New(structType) structValue := structPtr.Elem() - for i, origVal := range origin { + for _, origVal := range origin { + if origVal.FieldName == "" { + continue + } + field := structValue.FieldByName(origVal.FieldName) if field.IsValid() && field.CanSet() { - field.Set(reflect.ValueOf(values[i])) + field.Set(reflect.ValueOf(values[origVal.FieldName])) } else { return comid.Extensions{}, fmt.Errorf("could not set field %q", origVal.FieldName) } @@ -166,6 +197,10 @@ func CoMIDExtensionsToCoRIM(origin []*ExtensionValue) (comid.Extensions, error) var ret comid.Extensions ret.IMapValue = structPtr.Interface() + if len(cached) != 0 { + ret.Cached = cached + } + return ret, nil } diff --git a/pkg/model/extension_test.go b/pkg/model/extension_test.go index 95aaf26..0d56aff 100644 --- a/pkg/model/extension_test.go +++ b/pkg/model/extension_test.go @@ -37,10 +37,14 @@ func TestExtensionValue_round_trip(t *testing.T) { var original comid.Extensions original.Register(&extStruct) + original.Cached = map[string]any{ + "-1": uint64(7), + "fum": false, + } extVals, err := CoMIDExtensionsFromCoRIM(original) assert.NoError(t, err) - assert.Len(t, extVals, 8) + assert.Len(t, extVals, 10) for _, ev := range extVals { ev.OwnerID = 1 @@ -53,7 +57,7 @@ func TestExtensionValue_round_trip(t *testing.T) { err = db.NewSelect().Model(&resVals).Scan(ctx) assert.NoError(t, err) - assert.Len(t, resVals, 8) + assert.Len(t, resVals, 10) returnedExts, err := CoMIDExtensionsToCoRIM(resVals) assert.NoError(t, err) @@ -63,6 +67,8 @@ func TestExtensionValue_round_trip(t *testing.T) { val, err := returnedExts.Get("Qux") assert.NoError(t, err) assert.Equal(t, map[interface{}]interface{}{"Zap": true}, val) + + assert.Equal(t, original.Cached, returnedExts.Cached) } func TestExtensionValue_Select(t *testing.T) {