Skip to content

Commit bb5085e

Browse files
committed
refactor: reduce unmarshalStruct complexity
Note: this allow to reuse some logic in unmarshalParameter
1 parent 6842806 commit bb5085e

File tree

1 file changed

+167
-175
lines changed

1 file changed

+167
-175
lines changed

tpm2/reflect.go

Lines changed: 167 additions & 175 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,165 @@ func unmarshalArray(buf *bytes.Buffer, v reflect.Value) error {
455455
return nil
456456
}
457457

458+
// unmarshalStructField unmarshals a single field of a struct.
459+
// Returns nil if successful or if the field was skipped (e.g., optional field with zero size).
460+
func unmarshalStructField(buf *bytes.Buffer, v reflect.Value, i int) error {
461+
fieldType := v.Type().Field(i)
462+
fieldValue := v.Field(i)
463+
464+
if hasTag(fieldType, "skip") {
465+
return nil
466+
}
467+
468+
list := hasTag(fieldType, "list")
469+
if list && (fieldValue.Kind() != reflect.Slice) {
470+
return fmt.Errorf("field '%v' of struct '%v' had the 'list' tag but was not a slice",
471+
fieldType.Name, v.Type().Name())
472+
}
473+
// Slices of anything but byte/uint8 must have the 'list' tag.
474+
if !list && (fieldValue.Kind() == reflect.Slice) && (fieldType.Type.Elem().Kind() != reflect.Uint8) {
475+
return fmt.Errorf("field '%v' of struct '%v' was a slice of non-byte but did not have the 'list' tag",
476+
fieldType.Name, v.Type().Name())
477+
}
478+
479+
if hasTag(fieldType, "optional") {
480+
// Special case: Part 3 specifies some input/output
481+
// parameters as "optional", which means that they are
482+
// (2B-) sized fields that can be zero-length, even if the
483+
// enclosed type has no legal empty serialization.
484+
// When unmarshalling an optional field, test for zero size
485+
// and skip if empty.
486+
if buf.Len() >= 2 {
487+
if binary.BigEndian.Uint16(buf.Bytes()) == 0 {
488+
// Advance the buffer past the zero size and skip to the
489+
// next field of the struct.
490+
buf.Next(2)
491+
return nil
492+
}
493+
// If non-zero size, proceed to unmarshal the contents below.
494+
}
495+
}
496+
497+
// Handle nullable fields (for command parameters)
498+
if fieldValue.Kind() == reflect.Uint32 && hasTag(fieldType, "nullable") {
499+
var val uint32
500+
if err := binary.Read(buf, binary.BigEndian, &val); err != nil {
501+
return fmt.Errorf("reading nullable uint32 parameter: %w", err)
502+
}
503+
fieldValue.SetUint(uint64(val))
504+
return nil
505+
} else if fieldValue.Kind() == reflect.Uint16 && hasTag(fieldType, "nullable") {
506+
var val uint16
507+
if err := binary.Read(buf, binary.BigEndian, &val); err != nil {
508+
return fmt.Errorf("reading nullable uint16 parameter: %w", err)
509+
}
510+
fieldValue.SetUint(uint64(val))
511+
return nil
512+
}
513+
514+
sized := hasTag(fieldType, "sized")
515+
sized8 := hasTag(fieldType, "sized8")
516+
// If sized, unmarshal a size field first, then restrict
517+
// unmarshalling to the given size
518+
bufToReadFrom := buf
519+
if sized {
520+
var expectedSize uint16
521+
binary.Read(buf, binary.BigEndian, &expectedSize)
522+
sizedBufArray := make([]byte, int(expectedSize))
523+
n, err := buf.Read(sizedBufArray)
524+
if n != int(expectedSize) {
525+
return fmt.Errorf("ran out of data reading sized parameter '%v' inside struct of type '%v'",
526+
fieldType.Name, v.Type().Name())
527+
}
528+
if err != nil {
529+
return fmt.Errorf("error reading data for parameter '%v' inside struct of type '%v'",
530+
fieldType.Name, v.Type().Name())
531+
}
532+
bufToReadFrom = bytes.NewBuffer(sizedBufArray)
533+
}
534+
if sized8 {
535+
var expectedSize uint8
536+
binary.Read(buf, binary.BigEndian, &expectedSize)
537+
sizedBufArray := make([]byte, int(expectedSize))
538+
n, err := buf.Read(sizedBufArray)
539+
if n != int(expectedSize) {
540+
return fmt.Errorf("ran out of data reading sized parameter '%v' inside struct of type '%v'",
541+
fieldType.Name, v.Type().Name())
542+
}
543+
if err != nil {
544+
return fmt.Errorf("error reading data for parameter '%v' inside struct of type '%v'",
545+
fieldType.Name, v.Type().Name())
546+
}
547+
bufToReadFrom = bytes.NewBuffer(sizedBufArray)
548+
}
549+
550+
tagName, _ := tag(fieldType, "tag")
551+
if tagName != "" {
552+
// Make a pass to create a map of tag values
553+
// UInt64-valued fields with values greater than
554+
// MaxInt64 cannot be selectors.
555+
possibleSelectors := make(map[string]int64)
556+
for j := 0; j < v.NumField(); j++ {
557+
switch v.Field(j).Kind() {
558+
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
559+
possibleSelectors[v.Type().Field(j).Name] = v.Field(j).Int()
560+
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
561+
val := v.Field(j).Uint()
562+
if val <= math.MaxInt64 {
563+
possibleSelectors[v.Type().Field(j).Name] = int64(val)
564+
}
565+
}
566+
}
567+
// Check that the tagged value was present (and numeric
568+
// and smaller than MaxInt64)
569+
tagValue, ok := possibleSelectors[tagName]
570+
// Don't marshal anything if the tag value was TPM_ALG_NULL
571+
if tagValue == int64(TPMAlgNull) {
572+
return nil
573+
}
574+
if !ok {
575+
return fmt.Errorf("union tag '%v' for member '%v' of struct '%v' did not reference "+
576+
"a numeric field of in64-compatible value",
577+
tagName, fieldType.Name, v.Type().Name())
578+
}
579+
var uwh unmarshallableWithHint
580+
if fieldValue.CanAddr() && fieldValue.Addr().Type().AssignableTo(reflect.TypeOf(&uwh).Elem()) {
581+
u := fieldValue.Addr().Interface().(unmarshallableWithHint)
582+
contents, err := u.create(tagValue)
583+
if err != nil {
584+
return fmt.Errorf("unmarshalling field %v of struct of type '%v', %w", i, v.Type(), err)
585+
}
586+
err = unmarshal(buf, contents)
587+
if err != nil {
588+
return fmt.Errorf("unmarshalling field %v of struct of type '%v', %w", i, v.Type(), err)
589+
}
590+
} else if fieldValue.Type().AssignableTo(reflect.TypeOf(&uwh).Elem()) {
591+
u := fieldValue.Interface().(unmarshallableWithHint)
592+
contents, err := u.create(tagValue)
593+
if err != nil {
594+
return fmt.Errorf("unmarshalling field %v of struct of type '%v', %w", i, v.Type(), err)
595+
}
596+
err = unmarshal(buf, contents)
597+
if err != nil {
598+
return fmt.Errorf("unmarshalling field %v of struct of type '%v', %w", i, v.Type(), err)
599+
}
600+
}
601+
} else {
602+
if err := unmarshal(bufToReadFrom, fieldValue); err != nil {
603+
return fmt.Errorf("unmarshalling field %v of struct of type '%v', %w", i, v.Type(), err)
604+
}
605+
}
606+
607+
if sized || sized8 {
608+
if bufToReadFrom.Len() != 0 {
609+
return fmt.Errorf("extra data at the end of sized parameter '%v' inside struct of type '%v'",
610+
fieldType.Name, v.Type().Name())
611+
}
612+
}
613+
614+
return nil
615+
}
616+
458617
func unmarshalStruct(buf *bytes.Buffer, v reflect.Value) error {
459618
// Check if this is a bitwise-defined structure. This requires all the
460619
// exported members to be bitwise-defined.
@@ -487,133 +646,9 @@ func unmarshalStruct(buf *bytes.Buffer, v reflect.Value) error {
487646
if numBitwise > 0 {
488647
return unmarshalBitwise(buf, v)
489648
}
490-
for i := 0; i < v.NumField(); i++ {
491-
if hasTag(v.Type().Field(i), "skip") {
492-
continue
493-
}
494-
list := hasTag(v.Type().Field(i), "list")
495-
if list && (v.Field(i).Kind() != reflect.Slice) {
496-
return fmt.Errorf("field '%v' of struct '%v' had the 'list' tag but was not a slice",
497-
v.Type().Field(i).Name, v.Type().Name())
498-
}
499-
// Slices of anything but byte/uint8 must have the 'list' tag.
500-
if !list && (v.Field(i).Kind() == reflect.Slice) && (v.Type().Field(i).Type.Elem().Kind() != reflect.Uint8) {
501-
return fmt.Errorf("field '%v' of struct '%v' was a slice of non-byte but did not have the 'list' tag",
502-
v.Type().Field(i).Name, v.Type().Name())
503-
}
504-
if hasTag(v.Type().Field(i), "optional") {
505-
// Special case: Part 3 specifies some input/output
506-
// parameters as "optional", which means that they are
507-
// (2B-) sized fields that can be zero-length, even if the
508-
// enclosed type has no legal empty serialization.
509-
// When unmarshalling an optional field, test for zero size
510-
// and skip if empty.
511-
if buf.Len() < 2 {
512-
if binary.BigEndian.Uint16(buf.Bytes()) == 0 {
513-
// Advance the buffer past the zero size and skip to the
514-
// next field of the struct.
515-
buf.Next(2)
516-
continue
517-
}
518-
// If non-zero size, proceed to unmarshal the contents below.
519-
}
520-
}
521-
sized := hasTag(v.Type().Field(i), "sized")
522-
sized8 := hasTag(v.Type().Field(i), "sized8")
523-
// If sized, unmarshal a size field first, then restrict
524-
// unmarshalling to the given size
525-
bufToReadFrom := buf
526-
if sized {
527-
var expectedSize uint16
528-
binary.Read(buf, binary.BigEndian, &expectedSize)
529-
sizedBufArray := make([]byte, int(expectedSize))
530-
n, err := buf.Read(sizedBufArray)
531-
if n != int(expectedSize) {
532-
return fmt.Errorf("ran out of data reading sized parameter '%v' inside struct of type '%v'",
533-
v.Type().Field(i).Name, v.Type().Name())
534-
}
535-
if err != nil {
536-
return fmt.Errorf("error reading data for parameter '%v' inside struct of type '%v'",
537-
v.Type().Field(i).Name, v.Type().Name())
538-
}
539-
bufToReadFrom = bytes.NewBuffer(sizedBufArray)
540-
}
541-
if sized8 {
542-
var expectedSize uint8
543-
binary.Read(buf, binary.BigEndian, &expectedSize)
544-
sizedBufArray := make([]byte, int(expectedSize))
545-
n, err := buf.Read(sizedBufArray)
546-
if n != int(expectedSize) {
547-
return fmt.Errorf("ran out of data reading sized parameter '%v' inside struct of type '%v'",
548-
v.Type().Field(i).Name, v.Type().Name())
549-
}
550-
if err != nil {
551-
return fmt.Errorf("error reading data for parameter '%v' inside struct of type '%v'",
552-
v.Type().Field(i).Name, v.Type().Name())
553-
}
554-
bufToReadFrom = bytes.NewBuffer(sizedBufArray)
555-
}
556-
tag, _ := tag(v.Type().Field(i), "tag")
557-
if tag != "" {
558-
// Make a pass to create a map of tag values
559-
// UInt64-valued fields with values greater than
560-
// MaxInt64 cannot be selectors.
561-
possibleSelectors := make(map[string]int64)
562-
for j := 0; j < v.NumField(); j++ {
563-
switch v.Field(j).Kind() {
564-
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
565-
possibleSelectors[v.Type().Field(j).Name] = v.Field(j).Int()
566-
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
567-
val := v.Field(j).Uint()
568-
if val <= math.MaxInt64 {
569-
possibleSelectors[v.Type().Field(j).Name] = int64(val)
570-
}
571-
}
572-
}
573-
// Check that the tagged value was present (and numeric
574-
// and smaller than MaxInt64)
575-
tagValue, ok := possibleSelectors[tag]
576-
// Don't marshal anything if the tag value was TPM_ALG_NULL
577-
if tagValue == int64(TPMAlgNull) {
578-
continue
579-
}
580-
if !ok {
581-
return fmt.Errorf("union tag '%v' for member '%v' of struct '%v' did not reference "+
582-
"a numeric field of in64-compatible value",
583-
tag, v.Type().Field(i).Name, v.Type().Name())
584-
}
585-
var uwh unmarshallableWithHint
586-
if v.Field(i).CanAddr() && v.Field(i).Addr().Type().AssignableTo(reflect.TypeOf(&uwh).Elem()) {
587-
u := v.Field(i).Addr().Interface().(unmarshallableWithHint)
588-
contents, err := u.create(tagValue)
589-
if err != nil {
590-
return fmt.Errorf("unmarshalling field %v of struct of type '%v', %w", i, v.Type(), err)
591-
}
592-
err = unmarshal(buf, contents)
593-
if err != nil {
594-
return fmt.Errorf("unmarshalling field %v of struct of type '%v', %w", i, v.Type(), err)
595-
}
596-
} else if v.Field(i).Type().AssignableTo(reflect.TypeOf(&uwh).Elem()) {
597-
u := v.Field(i).Interface().(unmarshallableWithHint)
598-
contents, err := u.create(tagValue)
599-
if err != nil {
600-
return fmt.Errorf("unmarshalling field %v of struct of type '%v', %w", i, v.Type(), err)
601-
}
602-
err = unmarshal(buf, contents)
603-
if err != nil {
604-
return fmt.Errorf("unmarshalling field %v of struct of type '%v', %w", i, v.Type(), err)
605-
}
606-
}
607-
} else {
608-
if err := unmarshal(bufToReadFrom, v.Field(i)); err != nil {
609-
return fmt.Errorf("unmarshalling field %v of struct of type '%v', %w", i, v.Type(), err)
610-
}
611-
}
612-
if sized || sized8 {
613-
if bufToReadFrom.Len() != 0 {
614-
return fmt.Errorf("extra data at the end of sized parameter '%v' inside struct of type '%v'",
615-
v.Type().Field(i).Name, v.Type().Name())
616-
}
649+
for i := range v.NumField() {
650+
if err := unmarshalStructField(buf, v, i); err != nil {
651+
return err
617652
}
618653
}
619654
return nil
@@ -858,57 +893,14 @@ func marshalParameter[R any](buf *bytes.Buffer, cmd Command[R, *R], i int) error
858893
// Returns an error if the value is not unmarshallable or if there's insufficient data.
859894
func unmarshalParameter[C Command[R, *R], R any](buf *bytes.Buffer, cmd *C, i int) error {
860895
numHandles := len(taggedMembers(reflect.ValueOf(*cmd), "handle", false))
861-
if numHandles+i >= reflect.TypeOf(*cmd).NumField() {
896+
fieldIndex := numHandles + i
897+
if fieldIndex >= reflect.TypeOf(*cmd).NumField() {
862898
return fmt.Errorf("invalid parameter index %v", i)
863899
}
864-
parm := reflect.ValueOf(cmd).Elem().Field(numHandles + i)
865-
field := reflect.TypeOf(*cmd).Field(numHandles + i)
866900

867-
if hasTag(field, "optional") {
868-
// Special case: Part 3 specifies some input/output
869-
// parameters as "optional", which means that they are
870-
// (2B-) sized fields that can be zero-length, even if the
871-
// enclosed type has no legal empty serialization.
872-
// When unmarshalling an optional field, test for zero size
873-
// and skip if empty.
874-
if buf.Len() >= 2 {
875-
var checkBytes [2]byte
876-
tempBuf := *buf
877-
if err := binary.Read(&tempBuf, binary.BigEndian, &checkBytes); err != nil {
878-
return fmt.Errorf("reading optional parameter size: %w", err)
879-
}
880-
881-
if checkBytes == [2]byte{} {
882-
// This is a nil pointer, consume the bytes and leave the field as nil
883-
binary.Read(buf, binary.BigEndian, &checkBytes)
884-
return nil
885-
}
886-
// Fall through to unmarshal the contents normally
887-
} else {
888-
return fmt.Errorf("not enough data for optional parameter %d", i)
889-
}
890-
}
891-
892-
// Handle nullable fields during unmarshaling
893-
if parm.Kind() == reflect.Uint32 && hasTag(field, "nullable") {
894-
var val uint32
895-
if err := binary.Read(buf, binary.BigEndian, &val); err != nil {
896-
return fmt.Errorf("reading nullable uint32 parameter: %w", err)
897-
}
898-
// TPMRHNull is the default for nullable uint32 fields
899-
parm.SetUint(uint64(val))
900-
return nil
901-
} else if parm.Kind() == reflect.Uint16 && hasTag(field, "nullable") {
902-
var val uint16
903-
if err := binary.Read(buf, binary.BigEndian, &val); err != nil {
904-
return fmt.Errorf("reading nullable uint16 parameter: %w", err)
905-
}
906-
// TPMAlgNull is the default for nullable uint16 fields
907-
parm.SetUint(uint64(val))
908-
return nil
909-
}
910-
911-
return unmarshal(buf, parm)
901+
// Use unmarshalStructField to handle this field with all its tags
902+
cmdValue := reflect.ValueOf(cmd).Elem()
903+
return unmarshalStructField(buf, cmdValue, fieldIndex)
912904
}
913905

914906
// populateHandlesFromNames populates the handle fields of a command with NamedHandles

0 commit comments

Comments
 (0)