Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions tpm2/marshalling.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,85 @@ func (b *boxed[T]) unmarshal(buf *bytes.Buffer) error {
b.Contents = new(T)
return unmarshal(buf, reflect.ValueOf(b.Contents))
}

// MarshalCommandResponse marshals both command and response.
func MarshalCommandResponse[C Command[R, *R], R any](cmd C, rsp *R) (cmdData []byte, rspData []byte, err error) {
cmdData, err = MarshalCommand(cmd)
if err != nil {
return nil, nil, fmt.Errorf("marshalling command: %w", err)
}
rspData, err = MarshalResponse(rsp)
if err != nil {
return nil, nil, fmt.Errorf("marshalling response: %w", err)
}
return cmdData, rspData, nil
}

// UnmarshalCommandResponse unmarshals both command and response.
func UnmarshalCommandResponse[C Command[R, *R], R any](cmdData []byte, rspData []byte) (cmd C, rsp *R, err error) {
cmd, err = UnmarshalCommand[C, R](cmdData)
if err != nil {
return cmd, rsp, fmt.Errorf("unmarshalling command: %w", err)
}
rsp, err = UnmarshalResponse[R](rspData)
if err != nil {
return cmd, rsp, fmt.Errorf("unmarshalling response: %w", err)
}
return cmd, rsp, nil
}

// MarshalCommand marshals a TPM command.
func MarshalCommand[C Command[R, *R], R any](cmd C) ([]byte, error) {
var buf bytes.Buffer
params := taggedMembers(reflect.ValueOf(cmd), "handle", true)
for i := range len(params) {
if err := marshalParameter(&buf, cmd, i); err != nil {
return nil, fmt.Errorf("marshalling command's parameter: %w", err)
}
}
return buf.Bytes(), nil
}

// UnmarshalCommand unmarshals a TPM command.
func UnmarshalCommand[C Command[R, *R], R any](data []byte) (C, error) {
var cmd C
if data == nil {
return cmd, fmt.Errorf("data cannot be nil")
}
buf := bytes.NewBuffer(data)
params := taggedMembers(reflect.ValueOf(cmd), "handle", true)
for i := range len(params) {
if err := unmarshalParameter(buf, &cmd, i); err != nil {
return cmd, fmt.Errorf("unmarshalling command's parameter: %w", err)
}
}
return cmd, nil
}

// MarshalResponse marshals a TPM response.
func MarshalResponse[R any](rsp *R) ([]byte, error) {
var buf bytes.Buffer
parameters := taggedMembers(reflect.ValueOf(rsp).Elem(), "handle", true)
for i, parameter := range parameters {
if err := marshal(&buf, parameter); err != nil {
return nil, fmt.Errorf("marshalling response parameter %d: %w", i, err)
}
}
return buf.Bytes(), nil
}

// UnmarshalResponse unmarshals a TPM response.
func UnmarshalResponse[R any](data []byte) (*R, error) {
var rsp R
if data == nil {
return nil, fmt.Errorf("data cannot be nil")
}
buf := bytes.NewBuffer(data)
parameters := taggedMembers(reflect.ValueOf(&rsp).Elem(), "handle", true)
for i, parameter := range parameters {
if err := unmarshal(buf, parameter); err != nil {
return nil, fmt.Errorf("unmarshalling response parameter %d: %w", i, err)
}
}
return &rsp, nil
}
48 changes: 48 additions & 0 deletions tpm2/marshalling_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ package tpm2

import (
"bytes"
"reflect"
"testing"

"github.com/google/go-tpm/tpm2/transport/simulator"
)

func TestMarshal2B(t *testing.T) {
Expand Down Expand Up @@ -154,3 +157,48 @@ func TestMarshalT(t *testing.T) {
t.Errorf("want %x\ngot %x", pubBytes, pub2Bytes)
}
}

func TestMarshalCommandResponse(t *testing.T) {
thetpm, err := simulator.OpenSimulator()
if err != nil {
t.Fatalf("could not connect to TPM simulator: %v", err)
}
defer thetpm.Close()

getCmd := GetCapability{
Capability: TPMCapTPMProperties,
Property: uint32(TPMPTFamilyIndicator),
PropertyCount: 1,
}
capabilityRsp, err := getCmd.Execute(thetpm)
if err != nil {
t.Fatalf("executing GetCapability: %v", err)
}

cmdParamsBytes, err := MarshalCommand(getCmd)
if err != nil {
t.Fatalf("MarshalCommand failed: %v", err)
}

unmarshalCmd, err := UnmarshalCommand[GetCapability](cmdParamsBytes)
if err != nil {
t.Fatalf("UnmarshalCommand failed: %v", err)
}

if !reflect.DeepEqual(getCmd, unmarshalCmd) {
t.Errorf("Commands do not match \nwant: %+v\ngot: %+v", getCmd, unmarshalCmd)
}

respParamsBytes, err := MarshalResponse(capabilityRsp)
if err != nil {
t.Fatalf("MarshalResponse failed: %v", err)
}

unmarshalRsp, err := UnmarshalResponse[GetCapabilityResponse](respParamsBytes)
if err != nil {
t.Fatalf("UnmarshalResponse failed: %v", err)
}
if !reflect.DeepEqual(capabilityRsp, unmarshalRsp) {
t.Errorf("Responses do not match \nwant: %+v\ngot: %+v", capabilityRsp, unmarshalRsp)
}
}
57 changes: 57 additions & 0 deletions tpm2/reflect.go
Original file line number Diff line number Diff line change
Expand Up @@ -854,6 +854,63 @@ func marshalParameter[R any](buf *bytes.Buffer, cmd Command[R, *R], i int) error
return marshal(buf, parm)
}

// unmarshalParameter will deserialize the given parameter of the command from the buffer.
// Returns an error if the value is not unmarshallable or if there's insufficient data.
func unmarshalParameter[C Command[R, *R], R any](buf *bytes.Buffer, cmd *C, i int) error {
numHandles := len(taggedMembers(reflect.ValueOf(*cmd), "handle", false))
if numHandles+i >= reflect.TypeOf(*cmd).NumField() {
return fmt.Errorf("invalid parameter index %v", i)
}
parm := reflect.ValueOf(cmd).Elem().Field(numHandles + i)
field := reflect.TypeOf(*cmd).Field(numHandles + i)

if hasTag(field, "optional") {
// Special case: Part 3 specifies some input/output
// parameters as "optional", which means that they are
// (2B-) sized fields that can be zero-length, even if the
// enclosed type has no legal empty serialization.
// When unmarshalling an optional field, test for zero size
// and skip if empty.
if buf.Len() >= 2 {
var checkBytes [2]byte
tempBuf := *buf
if err := binary.Read(&tempBuf, binary.BigEndian, &checkBytes); err != nil {
return fmt.Errorf("reading optional parameter size: %w", err)
}

if checkBytes == [2]byte{} {
// This is a nil pointer, consume the bytes and leave the field as nil
binary.Read(buf, binary.BigEndian, &checkBytes)
return nil
}
// Fall through to unmarshal the contents normally
} else {
return fmt.Errorf("not enough data for optional parameter %d", i)
}
}

// Handle nullable fields during unmarshaling
if parm.Kind() == reflect.Uint32 && hasTag(field, "nullable") {
var val uint32
if err := binary.Read(buf, binary.BigEndian, &val); err != nil {
return fmt.Errorf("reading nullable uint32 parameter: %w", err)
}
// TPMRHNull is the default for nullable uint32 fields
parm.SetUint(uint64(val))
return nil
} else if parm.Kind() == reflect.Uint16 && hasTag(field, "nullable") {
var val uint16
if err := binary.Read(buf, binary.BigEndian, &val); err != nil {
return fmt.Errorf("reading nullable uint16 parameter: %w", err)
}
// TPMAlgNull is the default for nullable uint16 fields
parm.SetUint(uint64(val))
return nil
}

return unmarshal(buf, parm)
}

// cmdParameters returns the parameters area of the command.
// The first parameter may be encrypted by one of the sessions.
func cmdParameters[R any](cmd Command[R, *R], sess []Session) ([]byte, error) {
Expand Down
143 changes: 143 additions & 0 deletions tpm2/test/audit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,146 @@ func TestAuditSession(t *testing.T) {
}

}

func TestReplayAuditSession(t *testing.T) {
thetpm, err := simulator.OpenSimulator()
if err != nil {
t.Fatalf("could not connect to TPM simulator: %v", err)
}
defer thetpm.Close()

// Create the audit session
sess, cleanup, err := HMACSession(thetpm, TPMAlgSHA256, 16, Audit())
if err != nil {
t.Fatalf("%v", err)
}
defer cleanup()

// Create the AK for audit
createAKCmd := CreatePrimary{
PrimaryHandle: TPMRHOwner,
InPublic: New2B(TPMTPublic{
Type: TPMAlgECC,
NameAlg: TPMAlgSHA256,
ObjectAttributes: TPMAObject{
FixedTPM: true,
STClear: false,
FixedParent: true,
SensitiveDataOrigin: true,
UserWithAuth: true,
AdminWithPolicy: false,
NoDA: true,
EncryptedDuplication: false,
Restricted: true,
Decrypt: false,
SignEncrypt: true,
},
Parameters: NewTPMUPublicParms(
TPMAlgECC,
&TPMSECCParms{
Scheme: TPMTECCScheme{
Scheme: TPMAlgECDSA,
Details: NewTPMUAsymScheme(
TPMAlgECDSA,
&TPMSSigSchemeECDSA{
HashAlg: TPMAlgSHA256,
},
),
},
CurveID: TPMECCNistP256,
},
),
},
),
}
createAKRsp, err := createAKCmd.Execute(thetpm)
if err != nil {
t.Fatalf("%v", err)
}
defer func() {
// Flush the AK
flush := FlushContext{FlushHandle: createAKRsp.ObjectHandle}
if _, err := flush.Execute(thetpm); err != nil {
t.Errorf("%v", err)
}
}()

audit, err := NewAudit(TPMAlgSHA256)
if err != nil {
t.Fatalf("%v", err)
}

// mimic an audit log by storing commands and responses
auditLogs := struct {
commands [][]byte
responses [][]byte
}{}

// Call GetCapability a bunch of times with the audit session and make sure it extends like
// we expect it to.
props := []TPMPT{
TPMPTFamilyIndicator,
TPMPTLevel,
TPMPTRevision,
TPMPTDayofYear,
TPMPTYear,
TPMPTManufacturer,
}
for _, prop := range props {
getCmd := GetCapability{
Capability: TPMCapTPMProperties,
Property: uint32(prop),
PropertyCount: 1,
}
getRsp, err := getCmd.Execute(thetpm, sess)
if err != nil {
t.Fatalf("%v", err)
}
cmdBytes, rspBytes, err := MarshalCommandResponse(getCmd, getRsp)
if err != nil {
t.Fatalf("%v", err)
}
auditLogs.commands = append(auditLogs.commands, cmdBytes)
auditLogs.responses = append(auditLogs.responses, rspBytes)
}
// Get the audit digest signed by the AK
getAuditCmd := GetSessionAuditDigest{
PrivacyAdminHandle: TPMRHEndorsement,
SignHandle: NamedHandle{
Handle: createAKRsp.ObjectHandle,
Name: createAKRsp.Name,
},
SessionHandle: sess.Handle(),
QualifyingData: TPM2BData{Buffer: []byte("foobar")},
}
getAuditRsp, err := getAuditCmd.Execute(thetpm)
if err != nil {
t.Fatalf("%v", err)
}
attest, err := getAuditRsp.AuditInfo.Contents()
if err != nil {
t.Fatalf("%v", err)
}
aud, err := attest.Attested.SessionAudit()
if err != nil {
t.Fatalf("%v", err)
}
want := aud.SessionDigest.Buffer

for i := range len(props) {
cmd, rsp, err := UnmarshalCommandResponse[GetCapability, GetCapabilityResponse](
auditLogs.commands[i],
auditLogs.responses[i],
)
if err != nil {
t.Fatalf("%v", err)
}
if err := AuditCommand(audit, cmd, rsp); err != nil {
t.Fatalf("%v", err)
}
}
got := audit.Digest()
if !bytes.Equal(want, got) {
t.Errorf("unexpected audit value:\ngot %x\nwant %x", got, want)
}
}