diff --git a/channeldb/migration/lnwire21/msat.go b/channeldb/migration/lnwire21/msat.go index 7473d72c82..47c6762850 100644 --- a/channeldb/migration/lnwire21/msat.go +++ b/channeldb/migration/lnwire21/msat.go @@ -2,8 +2,10 @@ package lnwire import ( "fmt" + "io" "github.com/btcsuite/btcd/btcutil" + "github.com/lightningnetwork/lnd/tlv" ) const ( @@ -49,3 +51,39 @@ func (m MilliSatoshi) String() string { } // TODO(roasbeef): extend with arithmetic operations? + +// Record returns a TLV record that can be used to encode/decode a MilliSatoshi +// to/from a TLV stream. +func (m *MilliSatoshi) Record() tlv.Record { + return tlv.MakeDynamicRecord( + 0, m, tlv.SizeBigSize(m), encodeMilliSatoshis, + decodeMilliSatoshis, + ) +} +func encodeMilliSatoshis(w io.Writer, val interface{}, buf *[8]byte) error { + if v, ok := val.(*MilliSatoshi); ok { + bigSize := uint64(*v) + + return tlv.EBigSize(w, &bigSize, buf) + } + + return tlv.NewTypeForEncodingErr(val, "lnwire.MilliSatoshi") +} + +func decodeMilliSatoshis(r io.Reader, val interface{}, buf *[8]byte, + l uint64) error { + + if v, ok := val.(*MilliSatoshi); ok { + var bigSize uint64 + err := tlv.DBigSize(r, &bigSize, buf, l) + if err != nil { + return err + } + + *v = MilliSatoshi(bigSize) + + return nil + } + + return tlv.NewTypeForDecodingErr(val, "lnwire.MilliSatoshi", l, l) +} diff --git a/channeldb/migration/lnwire21/true_boolean.go b/channeldb/migration/lnwire21/true_boolean.go new file mode 100644 index 0000000000..3cde34263c --- /dev/null +++ b/channeldb/migration/lnwire21/true_boolean.go @@ -0,0 +1,37 @@ +package lnwire + +import ( + "io" + + "github.com/lightningnetwork/lnd/tlv" +) + +// TrueBoolean is a record that indicates true or false using the presence of +// the record. If the record is absent, it indicates false. If it is present, +// it indicates true. +type TrueBoolean struct{} + +// Record returns the tlv record for the boolean entry. +func (b *TrueBoolean) Record() tlv.Record { + return tlv.MakeStaticRecord( + 0, b, 0, booleanEncoder, booleanDecoder, + ) +} + +func booleanEncoder(_ io.Writer, val interface{}, _ *[8]byte) error { + if _, ok := val.(*TrueBoolean); ok { + return nil + } + + return tlv.NewTypeForEncodingErr(val, "TrueBoolean") +} + +func booleanDecoder(_ io.Reader, val interface{}, _ *[8]byte, + l uint64) error { + + if _, ok := val.(*TrueBoolean); ok && (l == 0 || l == 1) { + return nil + } + + return tlv.NewTypeForEncodingErr(val, "TrueBoolean") +} diff --git a/channeldb/migration32/migration_test.go b/channeldb/migration32/migration_test.go index 1ce6016ed4..703c32a72c 100644 --- a/channeldb/migration32/migration_test.go +++ b/channeldb/migration32/migration_test.go @@ -9,6 +9,7 @@ import ( lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21" "github.com/lightningnetwork/lnd/channeldb/migtest" "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/tlv" ) var ( @@ -24,175 +25,213 @@ var ( _ = pubKeyY.SetByteSlice(pubkeyBytes) pubkey = btcec.NewPublicKey(new(btcec.FieldVal).SetInt(4), pubKeyY) - paymentResultCommon1 = paymentResultCommon{ + customRecord = map[uint64][]byte{ + 65536: {4, 2, 2}, + } + + resultOld1 = paymentResultOld{ id: 0, timeFwd: time.Unix(0, 1), timeReply: time.Unix(0, 2), success: false, failureSourceIdx: &failureIndex, failure: &lnwire.FailFeeInsufficient{}, + route: &Route{ + TotalTimeLock: 100, + TotalAmount: 400, + SourcePubKey: testPub, + Hops: []*Hop{ + // A hop with MPP, AMP and custom + // records. + { + PubKeyBytes: testPub, + ChannelID: 100, + OutgoingTimeLock: 300, + AmtToForward: 500, + MPP: &MPP{ + paymentAddr: [32]byte{4, 5}, + totalMsat: 900, + }, + AMP: &{ + rootShare: [32]byte{0, 0}, + setID: [32]byte{5, 5, 5}, + childIndex: 90, + }, + CustomRecords: customRecord, + Metadata: []byte{6, 7, 7}, + }, + // A legacy hop. + { + PubKeyBytes: testPub, + ChannelID: 800, + OutgoingTimeLock: 4, + AmtToForward: 4, + LegacyPayload: true, + }, + // A hop with a blinding key. + { + PubKeyBytes: testPub, + ChannelID: 800, + OutgoingTimeLock: 4, + AmtToForward: 4, + BlindingPoint: pubkey, + EncryptedData: []byte{1, 2, 3}, + TotalAmtMsat: 600, + }, + // A hop with a blinding key and custom + // records. + { + PubKeyBytes: testPub, + ChannelID: 800, + OutgoingTimeLock: 4, + AmtToForward: 4, + CustomRecords: customRecord, + BlindingPoint: pubkey, + EncryptedData: []byte{1, 2, 3}, + TotalAmtMsat: 600, + }, + }, + }, } - paymentResultCommon2 = paymentResultCommon{ + resultOld2 = paymentResultOld{ id: 2, timeFwd: time.Unix(0, 4), timeReply: time.Unix(0, 7), success: true, + route: &Route{ + TotalTimeLock: 101, + TotalAmount: 401, + SourcePubKey: testPub2, + Hops: []*Hop{ + { + PubKeyBytes: testPub, + ChannelID: 800, + OutgoingTimeLock: 4, + AmtToForward: 4, + BlindingPoint: pubkey, + EncryptedData: []byte{1, 2, 3}, + CustomRecords: customRecord, + TotalAmtMsat: 600, + }, + }, + }, } -) -// TestMigrateMCRouteSerialisation tests that the MigrateMCRouteSerialisation -// migration function correctly migrates the MC store from using the old route -// encoding to using the newer, more minimal route encoding. -func TestMigrateMCRouteSerialisation(t *testing.T) { - customRecord := map[uint64][]byte{ - 65536: {4, 2, 2}, + //nolint:lll + resultNew1Hop1 = &mcHop{ + channelID: tlv.NewPrimitiveRecord[tlv.TlvType0, uint64](100), + pubKeyBytes: tlv.NewRecordT[tlv.TlvType1](testPub), + amtToFwd: tlv.NewPrimitiveRecord[tlv.TlvType2, lnwire.MilliSatoshi](500), + hasCustomRecords: tlv.SomeRecordT( + tlv.ZeroRecordT[tlv.TlvType4, lnwire.TrueBoolean](), + ), } - resultsOld := []*paymentResultOld{ - { - paymentResultCommon: paymentResultCommon1, - route: &Route{ - TotalTimeLock: 100, - TotalAmount: 400, - SourcePubKey: testPub, - Hops: []*Hop{ - // A hop with MPP, AMP and custom - // records. - { - PubKeyBytes: testPub, - ChannelID: 100, - OutgoingTimeLock: 300, - AmtToForward: 500, - MPP: &MPP{ - paymentAddr: [32]byte{ - 4, 5, - }, - totalMsat: 900, - }, - AMP: &{ - rootShare: [32]byte{ - 0, 0, - }, - setID: [32]byte{ - 5, 5, 5, - }, - childIndex: 90, - }, - CustomRecords: customRecord, - Metadata: []byte{6, 7, 7}, - }, - // A legacy hop. - { - PubKeyBytes: testPub, - ChannelID: 800, - OutgoingTimeLock: 4, - AmtToForward: 4, - LegacyPayload: true, - }, - // A hop with a blinding key. - { - PubKeyBytes: testPub, - ChannelID: 800, - OutgoingTimeLock: 4, - AmtToForward: 4, - BlindingPoint: pubkey, - EncryptedData: []byte{ - 1, 2, 3, - }, - TotalAmtMsat: 600, - }, - // A hop with a blinding key and custom - // records. - { - PubKeyBytes: testPub, - ChannelID: 800, - OutgoingTimeLock: 4, - AmtToForward: 4, - CustomRecords: customRecord, - BlindingPoint: pubkey, - EncryptedData: []byte{ - 1, 2, 3, - }, - TotalAmtMsat: 600, - }, - }, - }, - }, - { - paymentResultCommon: paymentResultCommon2, - route: &Route{ - TotalTimeLock: 101, - TotalAmount: 401, - SourcePubKey: testPub2, - Hops: []*Hop{ - { - PubKeyBytes: testPub, - ChannelID: 800, - OutgoingTimeLock: 4, - AmtToForward: 4, - BlindingPoint: pubkey, - EncryptedData: []byte{ - 1, 2, 3, - }, - TotalAmtMsat: 600, - }, - }, - }, - }, + //nolint:lll + resultNew1Hop2 = &mcHop{ + channelID: tlv.NewPrimitiveRecord[tlv.TlvType0, uint64](800), + pubKeyBytes: tlv.NewRecordT[tlv.TlvType1](testPub), + amtToFwd: tlv.NewPrimitiveRecord[tlv.TlvType2, lnwire.MilliSatoshi](4), } - expectedResultsNew := []*paymentResultNew{ - { - paymentResultCommon: paymentResultCommon1, - route: &mcRoute{ - sourcePubKey: testPub, - totalAmount: 400, - hops: []*mcHop{ - { - channelID: 100, - pubKeyBytes: testPub, - amtToFwd: 500, - hasCustomRecords: true, - }, - { - channelID: 800, - pubKeyBytes: testPub, - amtToFwd: 4, - }, - { - channelID: 800, - pubKeyBytes: testPub, - amtToFwd: 4, - hasBlindingPoint: true, - }, - { - channelID: 800, - pubKeyBytes: testPub, - amtToFwd: 4, - hasBlindingPoint: true, - hasCustomRecords: true, - }, - }, - }, - }, - { - paymentResultCommon: paymentResultCommon2, - route: &mcRoute{ - sourcePubKey: testPub2, - totalAmount: 401, - hops: []*mcHop{ - { - channelID: 800, - pubKeyBytes: testPub, - amtToFwd: 4, - hasBlindingPoint: true, - }, - }, - }, - }, + //nolint:lll + resultNew1Hop3 = &mcHop{ + channelID: tlv.NewPrimitiveRecord[tlv.TlvType0, uint64](800), + pubKeyBytes: tlv.NewRecordT[tlv.TlvType1](testPub), + amtToFwd: tlv.NewPrimitiveRecord[tlv.TlvType2, lnwire.MilliSatoshi](4), + hasBlindingPoint: tlv.SomeRecordT( + tlv.ZeroRecordT[tlv.TlvType3, lnwire.TrueBoolean](), + ), } + //nolint:lll + resultNew1Hop4 = &mcHop{ + channelID: tlv.NewPrimitiveRecord[tlv.TlvType0, uint64](800), + pubKeyBytes: tlv.NewRecordT[tlv.TlvType1](testPub), + amtToFwd: tlv.NewPrimitiveRecord[tlv.TlvType2, lnwire.MilliSatoshi](4), + hasCustomRecords: tlv.SomeRecordT( + tlv.ZeroRecordT[tlv.TlvType4, lnwire.TrueBoolean](), + ), + hasBlindingPoint: tlv.SomeRecordT( + tlv.ZeroRecordT[tlv.TlvType3, lnwire.TrueBoolean](), + ), + } + + //nolint:lll + resultNew2Hop1 = &mcHop{ + channelID: tlv.NewPrimitiveRecord[tlv.TlvType0, uint64](800), + pubKeyBytes: tlv.NewRecordT[tlv.TlvType1](testPub), + amtToFwd: tlv.NewPrimitiveRecord[tlv.TlvType2, lnwire.MilliSatoshi](4), + hasCustomRecords: tlv.SomeRecordT( + tlv.ZeroRecordT[tlv.TlvType4, lnwire.TrueBoolean](), + ), + hasBlindingPoint: tlv.SomeRecordT( + tlv.ZeroRecordT[tlv.TlvType3, lnwire.TrueBoolean](), + ), + } + + //nolint:lll + resultNew1 = paymentResultNew{ + id: 0, + timeFwd: tlv.NewPrimitiveRecord[tlv.TlvType0]( + uint64(time.Unix(0, 1).UnixNano()), + ), + timeReply: tlv.NewPrimitiveRecord[tlv.TlvType1]( + uint64(time.Unix(0, 2).UnixNano()), + ), + failure: tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType3]( + *newPaymentFailure( + &failureIndex, + &lnwire.FailFeeInsufficient{}, + ), + ), + ), + route: tlv.NewRecordT[tlv.TlvType2](mcRoute{ + sourcePubKey: tlv.NewRecordT[tlv.TlvType0](testPub), + totalAmount: tlv.NewRecordT[tlv.TlvType1, lnwire.MilliSatoshi](400), + hops: tlv.NewRecordT[tlv.TlvType2, mcHops](mcHops{ + resultNew1Hop1, + resultNew1Hop2, + resultNew1Hop3, + resultNew1Hop4, + }), + }), + } + + //nolint:lll + resultNew2 = paymentResultNew{ + id: 2, + timeFwd: tlv.NewPrimitiveRecord[tlv.TlvType0, uint64]( + uint64(time.Unix(0, 4).UnixNano()), + ), + timeReply: tlv.NewPrimitiveRecord[tlv.TlvType1, uint64]( + uint64(time.Unix(0, 7).UnixNano()), + ), + route: tlv.NewRecordT[tlv.TlvType2](mcRoute{ + sourcePubKey: tlv.NewRecordT[tlv.TlvType0](testPub2), + totalAmount: tlv.NewRecordT[tlv.TlvType1, lnwire.MilliSatoshi](401), + hops: tlv.NewRecordT[tlv.TlvType2](mcHops{ + resultNew2Hop1, + }), + }), + } +) + +// TestMigrateMCRouteSerialisation tests that the MigrateMCRouteSerialisation +// migration function correctly migrates the MC store from using the old route +// encoding to using the newer, more minimal route encoding. +func TestMigrateMCRouteSerialisation(t *testing.T) { + var ( + resultsOld = []*paymentResultOld{ + &resultOld1, &resultOld2, + } + expectedResultsNew = []*paymentResultNew{ + &resultNew1, &resultNew2, + } + ) + // Prime the database with some mission control data that uses the // old route encoding. before := func(tx kvdb.RwTx) error { diff --git a/channeldb/migration32/mission_control_store.go b/channeldb/migration32/mission_control_store.go index 3953cd1f25..3ac9d6114c 100644 --- a/channeldb/migration32/mission_control_store.go +++ b/channeldb/migration32/mission_control_store.go @@ -8,6 +8,8 @@ import ( "github.com/btcsuite/btcd/wire" lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21" + "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/tlv" ) const ( @@ -22,30 +24,22 @@ var ( resultsKey = []byte("missioncontrol-results") ) -// paymentResultCommon holds the fields that are shared by the old and new -// payment result encoding. -type paymentResultCommon struct { +// paymentResultOld is the information that becomes available when a payment +// attempt completes. +type paymentResultOld struct { id uint64 timeFwd, timeReply time.Time + route *Route success bool failureSourceIdx *int failure lnwire.FailureMessage } -// paymentResultOld is the information that becomes available when a payment -// attempt completes. -type paymentResultOld struct { - paymentResultCommon - route *Route -} - // deserializeOldResult deserializes a payment result using the old encoding. func deserializeOldResult(k, v []byte) (*paymentResultOld, error) { // Parse payment id. result := paymentResultOld{ - paymentResultCommon: paymentResultCommon{ - id: byteOrder.Uint64(k[8:]), - }, + id: byteOrder.Uint64(k[8:]), } r := bytes.NewReader(v) @@ -99,67 +93,563 @@ func deserializeOldResult(k, v []byte) (*paymentResultOld, error) { // convertPaymentResult converts a paymentResultOld to a paymentResultNew. func convertPaymentResult(old *paymentResultOld) *paymentResultNew { - return &paymentResultNew{ - paymentResultCommon: old.paymentResultCommon, - route: extractMCRoute(old.route), + var failure *paymentFailure + if !old.success { + failure = newPaymentFailure(old.failureSourceIdx, old.failure) } + + return newPaymentResult( + old.id, extractMCRoute(old.route), old.timeFwd, old.timeReply, + failure, + ) +} + +// newPaymentResult constructs a new paymentResult. +func newPaymentResult(id uint64, rt *mcRoute, timeFwd, timeReply time.Time, + failure *paymentFailure) *paymentResultNew { + + result := &paymentResultNew{ + id: id, + timeFwd: tlv.NewPrimitiveRecord[tlv.TlvType0]( + uint64(timeFwd.UnixNano()), + ), + timeReply: tlv.NewPrimitiveRecord[tlv.TlvType1]( + uint64(timeReply.UnixNano()), + ), + route: tlv.NewRecordT[tlv.TlvType2](*rt), + } + + if failure != nil { + result.failure = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType3](*failure), + ) + } + + return result } // paymentResultNew is the information that becomes available when a payment // attempt completes. type paymentResultNew struct { - paymentResultCommon - route *mcRoute + id uint64 + timeFwd tlv.RecordT[tlv.TlvType0, uint64] + timeReply tlv.RecordT[tlv.TlvType1, uint64] + route tlv.RecordT[tlv.TlvType2, mcRoute] + + // failure holds information related to the failure of a payment. The + // presence of this record indicates a payment failure. The absence of + // this record indicates a successful payment. + failure tlv.OptionalRecordT[tlv.TlvType3, paymentFailure] +} + +// paymentFailure represents the presence of a payment failure. It may or may +// not include additional information about said failure. +type paymentFailure struct { + info tlv.OptionalRecordT[tlv.TlvType0, paymentFailureInfo] +} + +// newPaymentFailure constructs a new paymentFailure struct. If the source +// index is nil, then an empty paymentFailure is returned. This represents a +// failure with unknown details. Otherwise, the index and failure message are +// used to populate the info field of the paymentFailure. +func newPaymentFailure(sourceIdx *int, + failureMsg lnwire.FailureMessage) *paymentFailure { + + if sourceIdx == nil { + return &paymentFailure{} + } + + info := paymentFailureInfo{ + sourceIdx: tlv.NewPrimitiveRecord[tlv.TlvType0]( + uint8(*sourceIdx), + ), + msg: tlv.NewRecordT[tlv.TlvType1](failureMessage{failureMsg}), + } + + return &paymentFailure{ + info: tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType0](info)), + } +} + +// Record returns a TLV record that can be used to encode/decode a +// paymentFailure to/from a TLV stream. +func (r *paymentFailure) Record() tlv.Record { + recordSize := func() uint64 { + var ( + b bytes.Buffer + buf [8]byte + ) + if err := encodePaymentFailure(&b, r, &buf); err != nil { + panic(err) + } + + return uint64(len(b.Bytes())) + } + + return tlv.MakeDynamicRecord( + 0, r, recordSize, encodePaymentFailure, decodePaymentFailure, + ) +} + +func encodePaymentFailure(w io.Writer, val interface{}, _ *[8]byte) error { + if v, ok := val.(*paymentFailure); ok { + var recordProducers []tlv.RecordProducer + v.info.WhenSome( + func(r tlv.RecordT[tlv.TlvType0, paymentFailureInfo]) { + recordProducers = append(recordProducers, &r) + }, + ) + + return lnwire.EncodeRecordsTo( + w, lnwire.ProduceRecordsSorted(recordProducers...), + ) + } + + return tlv.NewTypeForEncodingErr(val, "routing.paymentFailure") +} + +func decodePaymentFailure(r io.Reader, val interface{}, _ *[8]byte, + l uint64) error { + + if v, ok := val.(*paymentFailure); ok { + var h paymentFailure + + info := tlv.ZeroRecordT[tlv.TlvType0, paymentFailureInfo]() + typeMap, err := lnwire.DecodeRecords( + r, lnwire.ProduceRecordsSorted(&info)..., + ) + if err != nil { + return err + } + + if _, ok := typeMap[h.info.TlvType()]; ok { + h.info = tlv.SomeRecordT(info) + } + + *v = h + + return nil + } + + return tlv.NewTypeForDecodingErr(val, "routing.paymentFailure", l, l) +} + +// paymentFailureInfo holds additional information about a payment failure. +type paymentFailureInfo struct { + sourceIdx tlv.RecordT[tlv.TlvType0, uint8] + msg tlv.RecordT[tlv.TlvType1, failureMessage] +} + +// Record returns a TLV record that can be used to encode/decode a +// paymentFailureInfo to/from a TLV stream. +func (r *paymentFailureInfo) Record() tlv.Record { + recordSize := func() uint64 { + var ( + b bytes.Buffer + buf [8]byte + ) + if err := encodePaymentFailureInfo(&b, r, &buf); err != nil { + panic(err) + } + + return uint64(len(b.Bytes())) + } + + return tlv.MakeDynamicRecord( + 0, r, recordSize, encodePaymentFailureInfo, + decodePaymentFailureInfo, + ) +} + +func encodePaymentFailureInfo(w io.Writer, val interface{}, _ *[8]byte) error { + if v, ok := val.(*paymentFailureInfo); ok { + return lnwire.EncodeRecordsTo( + w, lnwire.ProduceRecordsSorted( + &v.sourceIdx, &v.msg, + ), + ) + } + + return tlv.NewTypeForEncodingErr(val, "routing.paymentFailureInfo") +} + +func decodePaymentFailureInfo(r io.Reader, val interface{}, _ *[8]byte, + l uint64) error { + + if v, ok := val.(*paymentFailureInfo); ok { + var h paymentFailureInfo + + _, err := lnwire.DecodeRecords( + r, + lnwire.ProduceRecordsSorted(&h.sourceIdx, &h.msg)..., + ) + if err != nil { + return err + } + + *v = h + + return nil + } + + return tlv.NewTypeForDecodingErr( + val, "routing.paymentFailureInfo", l, l, + ) +} + +type failureMessage struct { + lnwire.FailureMessage +} + +// Record returns a TLV record that can be used to encode/decode a list of +// failureMessage to/from a TLV stream. +func (r *failureMessage) Record() tlv.Record { + recordSize := func() uint64 { + var ( + b bytes.Buffer + buf [8]byte + ) + if err := encodeFailureMessage(&b, r, &buf); err != nil { + panic(err) + } + + return uint64(len(b.Bytes())) + } + + return tlv.MakeDynamicRecord( + 0, r, recordSize, encodeFailureMessage, decodeFailureMessage, + ) +} + +func encodeFailureMessage(w io.Writer, val interface{}, _ *[8]byte) error { + if v, ok := val.(*failureMessage); ok { + var b bytes.Buffer + err := lnwire.EncodeFailureMessage(&b, v.FailureMessage, 0) + if err != nil { + return err + } + + _, err = w.Write(b.Bytes()) + + return err + } + + return tlv.NewTypeForEncodingErr(val, "routing.failureMessage") +} + +func decodeFailureMessage(r io.Reader, val interface{}, _ *[8]byte, + l uint64) error { + + if v, ok := val.(*failureMessage); ok { + msg, err := lnwire.DecodeFailureMessage(r, 0) + if err != nil { + return err + } + + *v = failureMessage{ + FailureMessage: msg, + } + + return nil + } + + return tlv.NewTypeForDecodingErr(val, "routing.failureMessage", l, l) } // extractMCRoute extracts the fields required by MC from the Route struct to // create the more minimal mcRoute struct. -func extractMCRoute(route *Route) *mcRoute { +func extractMCRoute(r *Route) *mcRoute { return &mcRoute{ - sourcePubKey: route.SourcePubKey, - totalAmount: route.TotalAmount, - hops: extractMCHops(route.Hops), + sourcePubKey: tlv.NewRecordT[tlv.TlvType0](r.SourcePubKey), + totalAmount: tlv.NewRecordT[tlv.TlvType1](r.TotalAmount), + hops: tlv.NewRecordT[tlv.TlvType2]( + extractMCHops(r.Hops), + ), } } // extractMCHops extracts the Hop fields that MC actually uses from a slice of // Hops. -func extractMCHops(hops []*Hop) []*mcHop { - mcHops := make([]*mcHop, len(hops)) - for i, hop := range hops { - mcHops[i] = extractMCHop(hop) - } - - return mcHops +func extractMCHops(hops []*Hop) mcHops { + return fn.Map(extractMCHop, hops) } // extractMCHop extracts the Hop fields that MC actually uses from a Hop. func extractMCHop(hop *Hop) *mcHop { - return &mcHop{ - channelID: hop.ChannelID, - pubKeyBytes: hop.PubKeyBytes, - amtToFwd: hop.AmtToForward, - hasBlindingPoint: hop.BlindingPoint != nil, - hasCustomRecords: len(hop.CustomRecords) > 0, + h := mcHop{ + channelID: tlv.NewPrimitiveRecord[tlv.TlvType0, uint64]( + hop.ChannelID, + ), + pubKeyBytes: tlv.NewRecordT[tlv.TlvType1, Vertex]( + hop.PubKeyBytes, + ), + amtToFwd: tlv.NewRecordT[tlv.TlvType2, lnwire.MilliSatoshi]( + hop.AmtToForward, + ), + } + + if hop.BlindingPoint != nil { + h.hasBlindingPoint = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType3, lnwire.TrueBoolean]( + lnwire.TrueBoolean{}, + ), + ) + } + + if len(hop.CustomRecords) != 0 { + h.hasCustomRecords = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType4, lnwire.TrueBoolean]( + lnwire.TrueBoolean{}, + ), + ) } + + return &h } // mcRoute holds the bare minimum info about a payment attempt route that MC // requires. type mcRoute struct { - sourcePubKey Vertex - totalAmount lnwire.MilliSatoshi - hops []*mcHop + sourcePubKey tlv.RecordT[tlv.TlvType0, Vertex] + totalAmount tlv.RecordT[tlv.TlvType1, lnwire.MilliSatoshi] + hops tlv.RecordT[tlv.TlvType2, mcHops] +} + +// Record returns a TLV record that can be used to encode/decode an mcRoute +// to/from a TLV stream. +func (r *mcRoute) Record() tlv.Record { + recordSize := func() uint64 { + var ( + b bytes.Buffer + buf [8]byte + ) + if err := encodeMCRoute(&b, r, &buf); err != nil { + panic(err) + } + + return uint64(len(b.Bytes())) + } + + return tlv.MakeDynamicRecord( + 0, r, recordSize, encodeMCRoute, decodeMCRoute, + ) +} + +func encodeMCRoute(w io.Writer, val interface{}, _ *[8]byte) error { + if v, ok := val.(*mcRoute); ok { + return serializeRoute(w, v) + } + + return tlv.NewTypeForEncodingErr(val, "routing.mcRoute") +} + +func decodeMCRoute(r io.Reader, val interface{}, _ *[8]byte, l uint64) error { + if v, ok := val.(*mcRoute); ok { + route, err := deserializeRoute(io.LimitReader(r, int64(l))) + if err != nil { + return err + } + + *v = *route + + return nil + } + + return tlv.NewTypeForDecodingErr(val, "routing.mcRoute", l, l) +} + +// mcHops is a list of mcHop records. +type mcHops []*mcHop + +// Record returns a TLV record that can be used to encode/decode a list of +// mcHop to/from a TLV stream. +func (h *mcHops) Record() tlv.Record { + recordSize := func() uint64 { + var ( + b bytes.Buffer + buf [8]byte + ) + if err := encodeMCHops(&b, h, &buf); err != nil { + panic(err) + } + + return uint64(len(b.Bytes())) + } + + return tlv.MakeDynamicRecord( + 0, h, recordSize, encodeMCHops, decodeMCHops, + ) +} + +func encodeMCHops(w io.Writer, val interface{}, buf *[8]byte) error { + if v, ok := val.(*mcHops); ok { + // Encode the number of hops as a var int. + if err := tlv.WriteVarInt(w, uint64(len(*v)), buf); err != nil { + return err + } + + // With that written out, we'll now encode the entries + // themselves as a sub-TLV record, which includes its _own_ + // inner length prefix. + for _, hop := range *v { + var hopBytes bytes.Buffer + if err := serializeNewHop(&hopBytes, hop); err != nil { + return err + } + + // We encode the record with a varint length followed by + // the _raw_ TLV bytes. + tlvLen := uint64(len(hopBytes.Bytes())) + if err := tlv.WriteVarInt(w, tlvLen, buf); err != nil { + return err + } + + if _, err := w.Write(hopBytes.Bytes()); err != nil { + return err + } + } + + return nil + } + + return tlv.NewTypeForEncodingErr(val, "routing.mcHops") +} + +func decodeMCHops(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { + if v, ok := val.(*mcHops); ok { + // First, we'll decode the varint that encodes how many hops + // are encoded in the stream. + numHops, err := tlv.ReadVarInt(r, buf) + if err != nil { + return err + } + + // Now that we know how many records we'll need to read, we can + // iterate and read them all out in series. + for i := uint64(0); i < numHops; i++ { + // Read out the varint that encodes the size of this + // inner TLV record. + hopSize, err := tlv.ReadVarInt(r, buf) + if err != nil { + return err + } + + // Using this information, we'll create a new limited + // reader that'll return an EOF once the end has been + // reached so the stream stops consuming bytes. + innerTlvReader := &io.LimitedReader{ + R: r, + N: int64(hopSize), + } + + hop, err := deserializeNewHop(innerTlvReader) + if err != nil { + return err + } + + *v = append(*v, hop) + } + + return nil + } + + return tlv.NewTypeForDecodingErr(val, "routing.mcHops", l, l) +} + +// serializeRoute serializes a mcRoute and writes the resulting bytes to the +// given io.Writer. +func serializeRoute(w io.Writer, r *mcRoute) error { + records := lnwire.ProduceRecordsSorted( + &r.sourcePubKey, + &r.totalAmount, + &r.hops, + ) + + return lnwire.EncodeRecordsTo(w, records) +} + +// deserializeRoute deserializes the mcRoute from the given io.Reader. +func deserializeRoute(r io.Reader) (*mcRoute, error) { + var rt mcRoute + records := lnwire.ProduceRecordsSorted( + &rt.sourcePubKey, + &rt.totalAmount, + &rt.hops, + ) + + _, err := lnwire.DecodeRecords(r, records...) + if err != nil { + return nil, err + } + + return &rt, nil +} + +// deserializeNewHop deserializes the mcHop from the given io.Reader. +func deserializeNewHop(r io.Reader) (*mcHop, error) { + var ( + h mcHop + blinding = tlv.ZeroRecordT[tlv.TlvType3, lnwire.TrueBoolean]() + custom = tlv.ZeroRecordT[tlv.TlvType4, lnwire.TrueBoolean]() + ) + records := lnwire.ProduceRecordsSorted( + &h.channelID, + &h.pubKeyBytes, + &h.amtToFwd, + &blinding, + &custom, + ) + + typeMap, err := lnwire.DecodeRecords(r, records...) + if err != nil { + return nil, err + } + + if _, ok := typeMap[h.hasBlindingPoint.TlvType()]; ok { + h.hasBlindingPoint = tlv.SomeRecordT(blinding) + } + + if _, ok := typeMap[h.hasCustomRecords.TlvType()]; ok { + h.hasCustomRecords = tlv.SomeRecordT(custom) + } + + return &h, nil +} + +// serializeNewHop serializes a mcHop and writes the resulting bytes to the +// given io.Writer. +func serializeNewHop(w io.Writer, h *mcHop) error { + recordProducers := []tlv.RecordProducer{ + &h.channelID, + &h.pubKeyBytes, + &h.amtToFwd, + } + + h.hasBlindingPoint.WhenSome(func( + hasBlinding tlv.RecordT[tlv.TlvType3, lnwire.TrueBoolean]) { + + recordProducers = append(recordProducers, &hasBlinding) + }) + + h.hasCustomRecords.WhenSome(func( + hasCustom tlv.RecordT[tlv.TlvType4, lnwire.TrueBoolean]) { + + recordProducers = append(recordProducers, &hasCustom) + }) + + return lnwire.EncodeRecordsTo( + w, lnwire.ProduceRecordsSorted(recordProducers...), + ) } // mcHop holds the bare minimum info about a payment attempt route hop that MC // requires. type mcHop struct { - channelID uint64 - pubKeyBytes Vertex - amtToFwd lnwire.MilliSatoshi - hasBlindingPoint bool - hasCustomRecords bool + channelID tlv.RecordT[tlv.TlvType0, uint64] + pubKeyBytes tlv.RecordT[tlv.TlvType1, Vertex] + amtToFwd tlv.RecordT[tlv.TlvType2, lnwire.MilliSatoshi] + hasBlindingPoint tlv.OptionalRecordT[tlv.TlvType3, lnwire.TrueBoolean] + hasCustomRecords tlv.OptionalRecordT[tlv.TlvType4, lnwire.TrueBoolean] } // serializeOldResult serializes a payment result and returns a key and value @@ -225,48 +715,30 @@ func getResultKeyOld(rp *paymentResultOld) []byte { // serializeNewResult serializes a payment result and returns a key and value // byte slice to insert into the bucket. func serializeNewResult(rp *paymentResultNew) ([]byte, []byte, error) { - // Write timestamps, success status, failure source index and route. - var b bytes.Buffer - - var dbFailureSourceIdx int32 - if rp.failureSourceIdx == nil { - dbFailureSourceIdx = unknownFailureSourceIdx - } else { - dbFailureSourceIdx = int32(*rp.failureSourceIdx) + recordProducers := []tlv.RecordProducer{ + &rp.timeFwd, + &rp.timeReply, + &rp.route, } - err := WriteElements( - &b, - uint64(rp.timeFwd.UnixNano()), - uint64(rp.timeReply.UnixNano()), - rp.success, dbFailureSourceIdx, + rp.failure.WhenSome( + func(failure tlv.RecordT[tlv.TlvType3, paymentFailure]) { + recordProducers = append(recordProducers, &failure) + }, ) - if err != nil { - return nil, nil, err - } - if err := serializeMCRoute(&b, rp.route); err != nil { - return nil, nil, err - } + // Compose key that identifies this result. + key := getResultKeyNew(rp) - // Write failure. If there is no failure message, write an empty - // byte slice. - var failureBytes bytes.Buffer - if rp.failure != nil { - err := lnwire.EncodeFailureMessage(&failureBytes, rp.failure, 0) - if err != nil { - return nil, nil, err - } - } - err = wire.WriteVarBytes(&b, 0, failureBytes.Bytes()) + var buff bytes.Buffer + err := lnwire.EncodeRecordsTo( + &buff, lnwire.ProduceRecordsSorted(recordProducers...), + ) if err != nil { return nil, nil, err } - // Compose key that identifies this result. - key := getResultKeyNew(rp) - - return key, b.Bytes(), nil + return key, buff.Bytes(), err } // getResultKeyNew returns a byte slice representing a unique key for this @@ -278,43 +750,9 @@ func getResultKeyNew(rp *paymentResultNew) []byte { // key. This allows importing mission control data from an external // source without key collisions and keeps the records sorted // chronologically. - byteOrder.PutUint64(keyBytes[:], uint64(rp.timeReply.UnixNano())) + byteOrder.PutUint64(keyBytes[:], rp.timeReply.Val) byteOrder.PutUint64(keyBytes[8:], rp.id) - copy(keyBytes[16:], rp.route.sourcePubKey[:]) + copy(keyBytes[16:], rp.route.Val.sourcePubKey.Val[:]) return keyBytes[:] } - -// serializeMCRoute serializes an mcRoute and writes the bytes to the given -// io.Writer. -func serializeMCRoute(w io.Writer, r *mcRoute) error { - if err := WriteElements( - w, r.totalAmount, r.sourcePubKey[:], - ); err != nil { - return err - } - - if err := WriteElements(w, uint32(len(r.hops))); err != nil { - return err - } - - for _, h := range r.hops { - if err := serializeNewHop(w, h); err != nil { - return err - } - } - - return nil -} - -// serializeMCRoute serializes an mcHop and writes the bytes to the given -// io.Writer. -func serializeNewHop(w io.Writer, h *mcHop) error { - return WriteElements(w, - h.pubKeyBytes[:], - h.channelID, - h.amtToFwd, - h.hasBlindingPoint, - h.hasCustomRecords, - ) -} diff --git a/channeldb/migration32/route.go b/channeldb/migration32/route.go index a4d40a45cb..a35338e504 100644 --- a/channeldb/migration32/route.go +++ b/channeldb/migration32/route.go @@ -29,6 +29,32 @@ const VertexSize = 33 // public key. type Vertex [VertexSize]byte +// Record returns a TLV record that can be used to encode/decode a Vertex +// to/from a TLV stream. +func (v *Vertex) Record() tlv.Record { + return tlv.MakeStaticRecord( + 0, v, VertexSize, encodeVertex, decodeVertex, + ) +} + +func encodeVertex(w io.Writer, val interface{}, _ *[8]byte) error { + if b, ok := val.(*Vertex); ok { + _, err := w.Write(b[:]) + return err + } + + return tlv.NewTypeForEncodingErr(val, "Vertex") +} + +func decodeVertex(r io.Reader, val interface{}, _ *[8]byte, l uint64) error { + if b, ok := val.(*Vertex); ok { + _, err := io.ReadFull(r, b[:]) + return err + } + + return tlv.NewTypeForDecodingErr(val, "Vertex", l, VertexSize) +} + // Route represents a path through the channel graph which runs over one or // more channels in succession. This struct carries all the information // required to craft the Sphinx onion packet, and send the payment along the diff --git a/docs/release-notes/release-notes-0.19.0.md b/docs/release-notes/release-notes-0.19.0.md index d98bc534f8..6337dd89fb 100644 --- a/docs/release-notes/release-notes-0.19.0.md +++ b/docs/release-notes/release-notes-0.19.0.md @@ -122,7 +122,8 @@ * [Migrate the mission control store](https://github.com/lightningnetwork/lnd/pull/8911) to use a more - minimal encoding for payment attempt routes. + minimal encoding for payment attempt routes as well as use [pure TLV + encoding](https://github.com/lightningnetwork/lnd/pull/9167). * [Migrate the mission control store](https://github.com/lightningnetwork/lnd/pull/9001) so that results are diff --git a/lnwire/channel_update_2.go b/lnwire/channel_update_2.go index 79a76aad61..5c6d240958 100644 --- a/lnwire/channel_update_2.go +++ b/lnwire/channel_update_2.go @@ -411,7 +411,7 @@ func decodeDisableFlags(r io.Reader, val interface{}, buf *[8]byte, } // TrueBoolean is a record that indicates true or false using the presence of -// the record. If the record is absent, it indicates false. If it is presence, +// the record. If the record is absent, it indicates false. If it is present, // it indicates true. type TrueBoolean struct{} diff --git a/routing/missioncontrol.go b/routing/missioncontrol.go index be8ce7c9f6..20aa049073 100644 --- a/routing/missioncontrol.go +++ b/routing/missioncontrol.go @@ -1,8 +1,10 @@ package routing import ( + "bytes" "errors" "fmt" + "io" "sync" "time" @@ -16,6 +18,7 @@ import ( "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" + "github.com/lightningnetwork/lnd/tlv" ) const ( @@ -261,12 +264,39 @@ type MissionControlPairSnapshot struct { // paymentResult is the information that becomes available when a payment // attempt completes. type paymentResult struct { - id uint64 - timeFwd, timeReply time.Time - route *mcRoute - success bool - failureSourceIdx *int - failure lnwire.FailureMessage + id uint64 + timeFwd tlv.RecordT[tlv.TlvType0, uint64] + timeReply tlv.RecordT[tlv.TlvType1, uint64] + route tlv.RecordT[tlv.TlvType2, mcRoute] + + // failure holds information related to the failure of a payment. The + // presence of this record indicates a payment failure. The absence of + // this record indicates a successful payment. + failure tlv.OptionalRecordT[tlv.TlvType3, paymentFailure] +} + +// newPaymentResult constructs a new paymentResult. +func newPaymentResult(id uint64, rt *mcRoute, timeFwd, timeReply time.Time, + failure *paymentFailure) *paymentResult { + + result := &paymentResult{ + id: id, + timeFwd: tlv.NewPrimitiveRecord[tlv.TlvType0]( + uint64(timeFwd.UnixNano()), + ), + timeReply: tlv.NewPrimitiveRecord[tlv.TlvType1]( + uint64(timeReply.UnixNano()), + ), + route: tlv.NewRecordT[tlv.TlvType2](*rt), + } + + if failure != nil { + result.failure = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType3](*failure), + ) + } + + return result } // NewMissionController returns a new instance of MissionController. @@ -590,15 +620,10 @@ func (m *MissionControl) ReportPaymentFail(paymentID uint64, rt *route.Route, timestamp := m.cfg.clock.Now() - result := &paymentResult{ - success: false, - timeFwd: timestamp, - timeReply: timestamp, - id: paymentID, - failureSourceIdx: failureSourceIdx, - failure: failure, - route: extractMCRoute(rt), - } + result := newPaymentResult( + paymentID, extractMCRoute(rt), timestamp, timestamp, + newPaymentFailure(failureSourceIdx, failure), + ) return m.processPaymentResult(result) } @@ -610,15 +635,12 @@ func (m *MissionControl) ReportPaymentSuccess(paymentID uint64, timestamp := m.cfg.clock.Now() - result := &paymentResult{ - timeFwd: timestamp, - timeReply: timestamp, - id: paymentID, - success: true, - route: extractMCRoute(rt), - } + result := newPaymentResult( + paymentID, extractMCRoute(rt), timestamp, timestamp, nil, + ) _, err := m.processPaymentResult(result) + return err } @@ -646,14 +668,11 @@ func (m *MissionControl) applyPaymentResult( result *paymentResult) *channeldb.FailureReason { // Interpret result. - i := interpretResult( - result.route, result.success, result.failureSourceIdx, - result.failure, - ) + i := interpretResult(&result.route.Val, result.failure.ValOpt()) if i.policyFailure != nil { if m.state.requestSecondChance( - result.timeReply, + time.Unix(0, int64(result.timeReply.Val)), i.policyFailure.From, i.policyFailure.To, ) { return nil @@ -681,7 +700,10 @@ func (m *MissionControl) applyPaymentResult( m.log.Debugf("Reporting node failure to Mission Control: "+ "node=%v", *i.nodeFailure) - m.state.setAllFail(*i.nodeFailure, result.timeReply) + m.state.setAllFail( + *i.nodeFailure, + time.Unix(0, int64(result.timeReply.Val)), + ) } for pair, pairResult := range i.pairResults { @@ -698,7 +720,9 @@ func (m *MissionControl) applyPaymentResult( } m.state.setLastPairResult( - pair.From, pair.To, result.timeReply, &pairResult, false, + pair.From, pair.To, + time.Unix(0, int64(result.timeReply.Val)), &pairResult, + false, ) } @@ -803,3 +827,158 @@ func (n *namespacedDB) purge() error { return err }, func() {}) } + +// paymentFailure represents the presence of a payment failure. It may or may +// not include additional information about said failure. +type paymentFailure struct { + info tlv.OptionalRecordT[tlv.TlvType0, paymentFailureInfo] +} + +// newPaymentFailure constructs a new paymentFailure struct. If the source +// index is nil, then an empty paymentFailure is returned. This represents a +// failure with unknown details. Otherwise, the index and failure message are +// used to populate the info field of the paymentFailure. +func newPaymentFailure(sourceIdx *int, + failureMsg lnwire.FailureMessage) *paymentFailure { + + if sourceIdx == nil { + return &paymentFailure{} + } + + info := paymentFailureInfo{ + sourceIdx: tlv.NewPrimitiveRecord[tlv.TlvType0]( + uint8(*sourceIdx), + ), + msg: tlv.NewRecordT[tlv.TlvType1](failureMessage{failureMsg}), + } + + return &paymentFailure{ + info: tlv.SomeRecordT(tlv.NewRecordT[tlv.TlvType0](info)), + } +} + +// Record returns a TLV record that can be used to encode/decode a +// paymentFailure to/from a TLV stream. +func (r *paymentFailure) Record() tlv.Record { + recordSize := func() uint64 { + var ( + b bytes.Buffer + buf [8]byte + ) + if err := encodePaymentFailure(&b, r, &buf); err != nil { + panic(err) + } + + return uint64(len(b.Bytes())) + } + + return tlv.MakeDynamicRecord( + 0, r, recordSize, encodePaymentFailure, decodePaymentFailure, + ) +} + +func encodePaymentFailure(w io.Writer, val interface{}, _ *[8]byte) error { + if v, ok := val.(*paymentFailure); ok { + var recordProducers []tlv.RecordProducer + v.info.WhenSome( + func(r tlv.RecordT[tlv.TlvType0, paymentFailureInfo]) { + recordProducers = append(recordProducers, &r) + }, + ) + + return lnwire.EncodeRecordsTo( + w, lnwire.ProduceRecordsSorted(recordProducers...), + ) + } + + return tlv.NewTypeForEncodingErr(val, "routing.paymentFailure") +} + +func decodePaymentFailure(r io.Reader, val interface{}, _ *[8]byte, + l uint64) error { + + if v, ok := val.(*paymentFailure); ok { + var h paymentFailure + + info := tlv.ZeroRecordT[tlv.TlvType0, paymentFailureInfo]() + typeMap, err := lnwire.DecodeRecords( + r, lnwire.ProduceRecordsSorted(&info)..., + ) + if err != nil { + return err + } + + if _, ok := typeMap[h.info.TlvType()]; ok { + h.info = tlv.SomeRecordT(info) + } + + *v = h + + return nil + } + + return tlv.NewTypeForDecodingErr(val, "routing.paymentFailure", l, l) +} + +// paymentFailureInfo holds additional information about a payment failure. +type paymentFailureInfo struct { + sourceIdx tlv.RecordT[tlv.TlvType0, uint8] + msg tlv.RecordT[tlv.TlvType1, failureMessage] +} + +// Record returns a TLV record that can be used to encode/decode a +// paymentFailureInfo to/from a TLV stream. +func (r *paymentFailureInfo) Record() tlv.Record { + recordSize := func() uint64 { + var ( + b bytes.Buffer + buf [8]byte + ) + if err := encodePaymentFailureInfo(&b, r, &buf); err != nil { + panic(err) + } + + return uint64(len(b.Bytes())) + } + + return tlv.MakeDynamicRecord( + 0, r, recordSize, encodePaymentFailureInfo, + decodePaymentFailureInfo, + ) +} + +func encodePaymentFailureInfo(w io.Writer, val interface{}, _ *[8]byte) error { + if v, ok := val.(*paymentFailureInfo); ok { + return lnwire.EncodeRecordsTo( + w, lnwire.ProduceRecordsSorted( + &v.sourceIdx, &v.msg, + ), + ) + } + + return tlv.NewTypeForEncodingErr(val, "routing.paymentFailureInfo") +} + +func decodePaymentFailureInfo(r io.Reader, val interface{}, _ *[8]byte, + l uint64) error { + + if v, ok := val.(*paymentFailureInfo); ok { + var h paymentFailureInfo + + _, err := lnwire.DecodeRecords( + r, + lnwire.ProduceRecordsSorted(&h.sourceIdx, &h.msg)..., + ) + if err != nil { + return err + } + + *v = h + + return nil + } + + return tlv.NewTypeForDecodingErr( + val, "routing.paymentFailureInfo", l, l, + ) +} diff --git a/routing/missioncontrol_store.go b/routing/missioncontrol_store.go index c40f24697a..ceb4740e44 100644 --- a/routing/missioncontrol_store.go +++ b/routing/missioncontrol_store.go @@ -6,14 +6,12 @@ import ( "encoding/binary" "fmt" "io" - "math" "sync" "time" - "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" ) var ( @@ -26,12 +24,6 @@ var ( byteOrder = binary.BigEndian ) -const ( - // unknownFailureSourceIdx is the database encoding of an unknown error - // source. - unknownFailureSourceIdx = -1 -) - // missionControlDB is an interface that defines the database methods that a // single missionControlStore has access to. It allows the missionControlStore // to be unaware of the overall DB structure and restricts its access to the DB @@ -168,190 +160,146 @@ func (b *missionControlStore) fetchAll() ([]*paymentResult, error) { // serializeResult serializes a payment result and returns a key and value byte // slice to insert into the bucket. func serializeResult(rp *paymentResult) ([]byte, []byte, error) { - // Write timestamps, success status, failure source index and route. - var b bytes.Buffer - - var dbFailureSourceIdx int32 - if rp.failureSourceIdx == nil { - dbFailureSourceIdx = unknownFailureSourceIdx - } else { - dbFailureSourceIdx = int32(*rp.failureSourceIdx) + recordProducers := []tlv.RecordProducer{ + &rp.timeFwd, + &rp.timeReply, + &rp.route, } - err := channeldb.WriteElements( - &b, - uint64(rp.timeFwd.UnixNano()), - uint64(rp.timeReply.UnixNano()), - rp.success, dbFailureSourceIdx, + rp.failure.WhenSome( + func(failure tlv.RecordT[tlv.TlvType3, paymentFailure]) { + recordProducers = append(recordProducers, &failure) + }, ) - if err != nil { - return nil, nil, err - } - if err := serializeRoute(&b, rp.route); err != nil { - return nil, nil, err - } + // Compose key that identifies this result. + key := getResultKey(rp) - // Write failure. If there is no failure message, write an empty - // byte slice. - var failureBytes bytes.Buffer - if rp.failure != nil { - err := lnwire.EncodeFailureMessage(&failureBytes, rp.failure, 0) - if err != nil { - return nil, nil, err - } - } - err = wire.WriteVarBytes(&b, 0, failureBytes.Bytes()) + var buff bytes.Buffer + err := lnwire.EncodeRecordsTo( + &buff, lnwire.ProduceRecordsSorted(recordProducers...), + ) if err != nil { return nil, nil, err } - // Compose key that identifies this result. - key := getResultKey(rp) - - return key, b.Bytes(), nil + return key, buff.Bytes(), err } -// deserializeRoute deserializes the mcRoute from the given io.Reader. -func deserializeRoute(r io.Reader) (*mcRoute, error) { - var rt mcRoute - if err := channeldb.ReadElements(r, &rt.totalAmount); err != nil { - return nil, err +// deserializeResult deserializes a payment result. +func deserializeResult(k, v []byte) (*paymentResult, error) { + // Parse payment id. + result := paymentResult{ + id: byteOrder.Uint64(k[8:]), } - var pub []byte - if err := channeldb.ReadElements(r, &pub); err != nil { - return nil, err + failure := tlv.ZeroRecordT[tlv.TlvType3, paymentFailure]() + recordProducers := []tlv.RecordProducer{ + &result.timeFwd, + &result.timeReply, + &result.route, + &failure, } - copy(rt.sourcePubKey[:], pub) - var numHops uint32 - if err := channeldb.ReadElements(r, &numHops); err != nil { + r := bytes.NewReader(v) + typeMap, err := lnwire.DecodeRecords( + r, lnwire.ProduceRecordsSorted(recordProducers...)..., + ) + if err != nil { return nil, err } - var hops []*mcHop - for i := uint32(0); i < numHops; i++ { - hop, err := deserializeHop(r) - if err != nil { - return nil, err - } - hops = append(hops, hop) + if _, ok := typeMap[result.failure.TlvType()]; ok { + result.failure = tlv.SomeRecordT(failure) } - rt.hops = hops - return &rt, nil + return &result, nil } -// deserializeHop deserializes the mcHop from the given io.Reader. -func deserializeHop(r io.Reader) (*mcHop, error) { - var h mcHop +// serializeRoute serializes a mcRoute and writes the resulting bytes to the +// given io.Writer. +func serializeRoute(w io.Writer, r *mcRoute) error { + records := lnwire.ProduceRecordsSorted( + &r.sourcePubKey, + &r.totalAmount, + &r.hops, + ) - var pub []byte - if err := channeldb.ReadElements(r, &pub); err != nil { - return nil, err - } - copy(h.pubKeyBytes[:], pub) + return lnwire.EncodeRecordsTo(w, records) +} + +// deserializeRoute deserializes the mcRoute from the given io.Reader. +func deserializeRoute(r io.Reader) (*mcRoute, error) { + var rt mcRoute + records := lnwire.ProduceRecordsSorted( + &rt.sourcePubKey, + &rt.totalAmount, + &rt.hops, + ) - if err := channeldb.ReadElements(r, - &h.channelID, &h.amtToFwd, &h.hasBlindingPoint, - &h.hasCustomRecords, - ); err != nil { + _, err := lnwire.DecodeRecords(r, records...) + if err != nil { return nil, err } - return &h, nil + return &rt, nil } -// serializeRoute serializes a mcRoute and writes the resulting bytes to the -// given io.Writer. -func serializeRoute(w io.Writer, r *mcRoute) error { - err := channeldb.WriteElements(w, r.totalAmount, r.sourcePubKey[:]) +// deserializeHop deserializes the mcHop from the given io.Reader. +func deserializeHop(r io.Reader) (*mcHop, error) { + var ( + h mcHop + blinding = tlv.ZeroRecordT[tlv.TlvType3, lnwire.TrueBoolean]() + custom = tlv.ZeroRecordT[tlv.TlvType4, lnwire.TrueBoolean]() + ) + records := lnwire.ProduceRecordsSorted( + &h.channelID, + &h.pubKeyBytes, + &h.amtToFwd, + &blinding, + &custom, + ) + + typeMap, err := lnwire.DecodeRecords(r, records...) if err != nil { - return err + return nil, err } - if err := channeldb.WriteElements(w, uint32(len(r.hops))); err != nil { - return err + if _, ok := typeMap[h.hasBlindingPoint.TlvType()]; ok { + h.hasBlindingPoint = tlv.SomeRecordT(blinding) } - for _, h := range r.hops { - if err := serializeHop(w, h); err != nil { - return err - } + if _, ok := typeMap[h.hasCustomRecords.TlvType()]; ok { + h.hasCustomRecords = tlv.SomeRecordT(custom) } - return nil + return &h, nil } // serializeHop serializes a mcHop and writes the resulting bytes to the given // io.Writer. func serializeHop(w io.Writer, h *mcHop) error { - return channeldb.WriteElements(w, - h.pubKeyBytes[:], - h.channelID, - h.amtToFwd, - h.hasBlindingPoint, - h.hasCustomRecords, - ) -} - -// deserializeResult deserializes a payment result. -func deserializeResult(k, v []byte) (*paymentResult, error) { - // Parse payment id. - result := paymentResult{ - id: byteOrder.Uint64(k[8:]), + recordProducers := []tlv.RecordProducer{ + &h.channelID, + &h.pubKeyBytes, + &h.amtToFwd, } - r := bytes.NewReader(v) - - // Read timestamps, success status and failure source index. - var ( - timeFwd, timeReply uint64 - dbFailureSourceIdx int32 - ) + h.hasBlindingPoint.WhenSome(func( + hasBlinding tlv.RecordT[tlv.TlvType3, lnwire.TrueBoolean]) { - err := channeldb.ReadElements( - r, &timeFwd, &timeReply, &result.success, &dbFailureSourceIdx, - ) - if err != nil { - return nil, err - } - - // Convert time stamps to local time zone for consistent logging. - result.timeFwd = time.Unix(0, int64(timeFwd)).Local() - result.timeReply = time.Unix(0, int64(timeReply)).Local() + recordProducers = append(recordProducers, &hasBlinding) + }) - // Convert from unknown index magic number to nil value. - if dbFailureSourceIdx != unknownFailureSourceIdx { - failureSourceIdx := int(dbFailureSourceIdx) - result.failureSourceIdx = &failureSourceIdx - } + h.hasCustomRecords.WhenSome(func( + hasCustom tlv.RecordT[tlv.TlvType4, lnwire.TrueBoolean]) { - // Read route. - route, err := deserializeRoute(r) - if err != nil { - return nil, err - } - result.route = route + recordProducers = append(recordProducers, &hasCustom) + }) - // Read failure. - failureBytes, err := wire.ReadVarBytes( - r, 0, math.MaxUint16, "failure", + return lnwire.EncodeRecordsTo( + w, lnwire.ProduceRecordsSorted(recordProducers...), ) - if err != nil { - return nil, err - } - if len(failureBytes) > 0 { - result.failure, err = lnwire.DecodeFailureMessage( - bytes.NewReader(failureBytes), 0, - ) - if err != nil { - return nil, err - } - } - - return &result, nil } // AddResult adds a new result to the db. @@ -580,9 +528,70 @@ func getResultKey(rp *paymentResult) []byte { // key. This allows importing mission control data from an external // source without key collisions and keeps the records sorted // chronologically. - byteOrder.PutUint64(keyBytes[:], uint64(rp.timeReply.UnixNano())) + byteOrder.PutUint64(keyBytes[:], rp.timeReply.Val) byteOrder.PutUint64(keyBytes[8:], rp.id) - copy(keyBytes[16:], rp.route.sourcePubKey[:]) + copy(keyBytes[16:], rp.route.Val.sourcePubKey.Val[:]) return keyBytes[:] } + +// failureMessage wraps the lnwire.FailureMessage interface such that we can +// apply a Record method and use the failureMessage in a TLV encoded type. +type failureMessage struct { + lnwire.FailureMessage +} + +// Record returns a TLV record that can be used to encode/decode a list of +// failureMessage to/from a TLV stream. +func (r *failureMessage) Record() tlv.Record { + recordSize := func() uint64 { + var ( + b bytes.Buffer + buf [8]byte + ) + if err := encodeFailureMessage(&b, r, &buf); err != nil { + panic(err) + } + + return uint64(len(b.Bytes())) + } + + return tlv.MakeDynamicRecord( + 0, r, recordSize, encodeFailureMessage, decodeFailureMessage, + ) +} + +func encodeFailureMessage(w io.Writer, val interface{}, _ *[8]byte) error { + if v, ok := val.(*failureMessage); ok { + var b bytes.Buffer + err := lnwire.EncodeFailureMessage(&b, v.FailureMessage, 0) + if err != nil { + return err + } + + _, err = w.Write(b.Bytes()) + + return err + } + + return tlv.NewTypeForEncodingErr(val, "routing.failureMessage") +} + +func decodeFailureMessage(r io.Reader, val interface{}, _ *[8]byte, + l uint64) error { + + if v, ok := val.(*failureMessage); ok { + msg, err := lnwire.DecodeFailureMessage(r, 0) + if err != nil { + return err + } + + *v = failureMessage{ + FailureMessage: msg, + } + + return nil + } + + return tlv.NewTypeForDecodingErr(val, "routing.failureMessage", l, l) +} diff --git a/routing/missioncontrol_store_test.go b/routing/missioncontrol_store_test.go index 2dbfc11214..f1788a96a5 100644 --- a/routing/missioncontrol_store_test.go +++ b/routing/missioncontrol_store_test.go @@ -11,27 +11,25 @@ import ( "github.com/lightningnetwork/lnd/lntest/wait" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" + "github.com/lightningnetwork/lnd/tlv" "github.com/stretchr/testify/require" ) const testMaxRecords = 2 -var ( - // mcStoreTestRoute is a test route for the mission control store tests. - mcStoreTestRoute = mcRoute{ - totalAmount: lnwire.MilliSatoshi(5), - sourcePubKey: route.Vertex{1}, - hops: []*mcHop{ - { - pubKeyBytes: route.Vertex{2}, - channelID: 4, - amtToFwd: lnwire.MilliSatoshi(7), - hasCustomRecords: true, - hasBlindingPoint: false, - }, +// mcStoreTestRoute is a test route for the mission control store tests. +var mcStoreTestRoute = extractMCRoute(&route.Route{ + TotalAmount: lnwire.MilliSatoshi(5), + SourcePubKey: route.Vertex{1}, + Hops: []*route.Hop{ + { + PubKeyBytes: route.Vertex{2}, + ChannelID: 4, + AmtToForward: lnwire.MilliSatoshi(7), + CustomRecords: make(map[uint64][]byte), }, - } -) + }, +}) // mcStoreTestHarness is the harness for a MissonControlStore test. type mcStoreTestHarness struct { @@ -84,28 +82,31 @@ func TestMissionControlStore(t *testing.T) { failureSourceIdx := 1 - result1 := paymentResult{ - route: &mcStoreTestRoute, - failure: lnwire.NewFailIncorrectDetails(100, 1000), - failureSourceIdx: &failureSourceIdx, - id: 99, - timeReply: testTime, - timeFwd: testTime.Add(-time.Minute), - } + result1 := newPaymentResult( + 99, mcStoreTestRoute, testTime, testTime, + newPaymentFailure( + &failureSourceIdx, + lnwire.NewFailIncorrectDetails(100, 1000), + ), + ) - result2 := result1 - result2.timeReply = result1.timeReply.Add(time.Hour) - result2.timeFwd = result1.timeReply.Add(time.Hour) - result2.id = 2 + result2 := newPaymentResult( + 2, mcStoreTestRoute, testTime.Add(time.Hour), + testTime.Add(time.Hour), + newPaymentFailure( + &failureSourceIdx, + lnwire.NewFailIncorrectDetails(100, 1000), + ), + ) // Store result. - store.AddResult(&result2) + store.AddResult(result2) // Store again to test idempotency. - store.AddResult(&result2) + store.AddResult(result2) // Store second result which has an earlier timestamp. - store.AddResult(&result1) + store.AddResult(result1) require.NoError(t, store.storeResults()) results, err = store.fetchAll() @@ -113,8 +114,8 @@ func TestMissionControlStore(t *testing.T) { require.Len(t, results, 2) // Check that results are stored in chronological order. - require.Equal(t, &result1, results[0]) - require.Equal(t, &result2, results[1]) + require.Equal(t, result1, results[0]) + require.Equal(t, result2, results[1]) // Recreate store to test pruning. store, err = newMissionControlStore( @@ -124,12 +125,20 @@ func TestMissionControlStore(t *testing.T) { // Add a newer result which failed due to mpp timeout. result3 := result1 - result3.timeReply = result1.timeReply.Add(2 * time.Hour) - result3.timeFwd = result1.timeReply.Add(2 * time.Hour) + result3.timeReply = tlv.NewPrimitiveRecord[tlv.TlvType1]( + uint64(testTime.Add(2 * time.Hour).UnixNano()), + ) + result3.timeFwd = tlv.NewPrimitiveRecord[tlv.TlvType0]( + uint64(testTime.Add(2 * time.Hour).UnixNano()), + ) result3.id = 3 - result3.failure = &lnwire.FailMPPTimeout{} + result3.failure = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType3](*newPaymentFailure( + &failureSourceIdx, &lnwire.FailMPPTimeout{}, + )), + ) - store.AddResult(&result3) + store.AddResult(result3) require.NoError(t, store.storeResults()) // Check that results are pruned. @@ -137,8 +146,25 @@ func TestMissionControlStore(t *testing.T) { require.NoError(t, err) require.Len(t, results, 2) - require.Equal(t, &result2, results[0]) - require.Equal(t, &result3, results[1]) + require.Equal(t, result2, results[0]) + require.Equal(t, result3, results[1]) + + // Also demonstrate the persistence of a success result. + result4 := newPaymentResult( + 5, mcStoreTestRoute, testTime.Add(3*time.Hour), + testTime.Add(3*time.Hour), nil, + ) + store.AddResult(result4) + require.NoError(t, store.storeResults()) + + // We should still only have 2 results. + results, err = store.fetchAll() + require.NoError(t, err) + require.Len(t, results, 2) + + // The two latest results should have been returned. + require.Equal(t, result3, results[0]) + require.Equal(t, result4, results[1]) } // TestMissionControlStoreFlushing asserts the periodic flushing of the store @@ -156,14 +182,11 @@ func TestMissionControlStoreFlushing(t *testing.T) { ) nextResult := func() *paymentResult { lastID += 1 - return &paymentResult{ - route: &mcStoreTestRoute, - failure: failureDetails, - failureSourceIdx: &failureSourceIdx, - id: lastID, - timeReply: testTime, - timeFwd: testTime.Add(-time.Minute), - } + return newPaymentResult( + lastID, mcStoreTestRoute, testTime.Add(-time.Hour), + testTime, + newPaymentFailure(&failureSourceIdx, failureDetails), + ) } // Helper to assert the number of results is correct. @@ -260,14 +283,14 @@ func BenchmarkMissionControlStoreFlushing(b *testing.B) { var lastID uint64 for i := 0; i < testMaxRecords; i++ { lastID++ - result := &paymentResult{ - route: &mcStoreTestRoute, - failure: failureDetails, - failureSourceIdx: &failureSourceIdx, - id: lastID, - timeReply: testTime, - timeFwd: testTimeFwd, - } + result := newPaymentResult( + lastID, mcStoreTestRoute, testTimeFwd, + testTime, + newPaymentFailure( + &failureSourceIdx, + failureDetails, + ), + ) store.AddResult(result) } @@ -278,13 +301,14 @@ func BenchmarkMissionControlStoreFlushing(b *testing.B) { // Create the additional results. results := make([]*paymentResult, tc) for i := 0; i < len(results); i++ { - results[i] = &paymentResult{ - route: &mcStoreTestRoute, - failure: failureDetails, - failureSourceIdx: &failureSourceIdx, - timeReply: testTime, - timeFwd: testTimeFwd, - } + results[i] = newPaymentResult( + 0, mcStoreTestRoute, testTimeFwd, + testTime, + newPaymentFailure( + &failureSourceIdx, + failureDetails, + ), + ) } // Run the actual benchmark. diff --git a/routing/result_interpretation.go b/routing/result_interpretation.go index 4da4134214..089213d65e 100644 --- a/routing/result_interpretation.go +++ b/routing/result_interpretation.go @@ -1,11 +1,15 @@ package routing import ( + "bytes" "fmt" + "io" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" + "github.com/lightningnetwork/lnd/tlv" ) // Instantiate variables to allow taking a reference from the failure reason. @@ -76,63 +80,73 @@ type interpretedResult struct { // interpretResult interprets a payment outcome and returns an object that // contains information required to update mission control. -func interpretResult(rt *mcRoute, success bool, failureSrcIdx *int, - failure lnwire.FailureMessage) *interpretedResult { +func interpretResult(rt *mcRoute, + failure fn.Option[paymentFailure]) *interpretedResult { i := &interpretedResult{ pairResults: make(map[DirectedNodePair]pairResult), } - if success { + return fn.ElimOption(failure, func() *interpretedResult { i.processSuccess(rt) - } else { - i.processFail(rt, failureSrcIdx, failure) - } - return i + + return i + }, func(info paymentFailure) *interpretedResult { + i.processFail(rt, info) + + return i + }) } // processSuccess processes a successful payment attempt. func (i *interpretedResult) processSuccess(route *mcRoute) { // For successes, all nodes must have acted in the right way. Therefore // we mark all of them with a success result. - i.successPairRange(route, 0, len(route.hops)-1) + i.successPairRange(route, 0, len(route.hops.Val)-1) } // processFail processes a failed payment attempt. -func (i *interpretedResult) processFail(rt *mcRoute, errSourceIdx *int, - failure lnwire.FailureMessage) { - - if errSourceIdx == nil { +func (i *interpretedResult) processFail(rt *mcRoute, failure paymentFailure) { + if failure.info.IsNone() { i.processPaymentOutcomeUnknown(rt) return } + var ( + idx int + failMsg lnwire.FailureMessage + ) + + failure.info.WhenSome( + func(r tlv.RecordT[tlv.TlvType0, paymentFailureInfo]) { + idx = int(r.Val.sourceIdx.Val) + failMsg = r.Val.msg.Val.FailureMessage + }, + ) + // If the payment was to a blinded route and we received an error from // after the introduction point, handle this error separately - there // has been a protocol violation from the introduction node. This // penalty applies regardless of the error code that is returned. introIdx, isBlinded := introductionPointIndex(rt) - if isBlinded && introIdx < *errSourceIdx { - i.processPaymentOutcomeBadIntro(rt, introIdx, *errSourceIdx) + if isBlinded && introIdx < idx { + i.processPaymentOutcomeBadIntro(rt, introIdx, idx) return } - switch *errSourceIdx { - + switch idx { // We are the source of the failure. case 0: - i.processPaymentOutcomeSelf(rt, failure) + i.processPaymentOutcomeSelf(rt, failMsg) // A failure from the final hop was received. - case len(rt.hops): - i.processPaymentOutcomeFinal(rt, failure) + case len(rt.hops.Val): + i.processPaymentOutcomeFinal(rt, failMsg) // An intermediate hop failed. Interpret the outcome, update reputation // and try again. default: - i.processPaymentOutcomeIntermediate( - rt, *errSourceIdx, failure, - ) + i.processPaymentOutcomeIntermediate(rt, idx, failMsg) } } @@ -158,7 +172,7 @@ func (i *interpretedResult) processPaymentOutcomeBadIntro(route *mcRoute, // a final failure reason because the recipient can't process the // payment (independent of the introduction failing to convert the // error, we can't complete the payment if the last hop fails). - if errSourceIdx == len(route.hops) { + if errSourceIdx == len(route.hops.Val) { i.finalFailureReason = &reasonError } } @@ -178,7 +192,7 @@ func (i *interpretedResult) processPaymentOutcomeSelf(rt *mcRoute, i.failNode(rt, 1) // If this was a payment to a direct peer, we can stop trying. - if len(rt.hops) == 1 { + if len(rt.hops.Val) == 1 { i.finalFailureReason = &reasonError } @@ -188,7 +202,7 @@ func (i *interpretedResult) processPaymentOutcomeSelf(rt *mcRoute, // available in the link has been updated. default: log.Warnf("Routing failure for local channel %v occurred", - rt.hops[0].channelID) + rt.hops.Val[0].channelID) } } @@ -196,7 +210,7 @@ func (i *interpretedResult) processPaymentOutcomeSelf(rt *mcRoute, func (i *interpretedResult) processPaymentOutcomeFinal(route *mcRoute, failure lnwire.FailureMessage) { - n := len(route.hops) + n := len(route.hops.Val) failNode := func() { i.failNode(route, n) @@ -396,8 +410,8 @@ func (i *interpretedResult) processPaymentOutcomeIntermediate(route *mcRoute, // Set the node pair for which a channel update may be out of // date. The second chance logic uses the policyFailure field. i.policyFailure = &DirectedNodePair{ - From: route.hops[errorSourceIdx-1].pubKeyBytes, - To: route.hops[errorSourceIdx].pubKeyBytes, + From: route.hops.Val[errorSourceIdx-1].pubKeyBytes.Val, + To: route.hops.Val[errorSourceIdx].pubKeyBytes.Val, } reportOutgoing() @@ -425,8 +439,8 @@ func (i *interpretedResult) processPaymentOutcomeIntermediate(route *mcRoute, // Set the node pair for which a channel update may be out of // date. The second chance logic uses the policyFailure field. i.policyFailure = &DirectedNodePair{ - From: route.hops[errorSourceIdx-1].pubKeyBytes, - To: route.hops[errorSourceIdx].pubKeyBytes, + From: route.hops.Val[errorSourceIdx-1].pubKeyBytes.Val, + To: route.hops.Val[errorSourceIdx].pubKeyBytes.Val, } // We report incoming channel. If a second pair is granted in @@ -500,14 +514,14 @@ func (i *interpretedResult) processPaymentOutcomeIntermediate(route *mcRoute, // Note that if LND is extended to support multiple blinded // routes, this will terminate the payment without re-trying // the other routes. - if introIdx == len(route.hops)-1 { + if introIdx == len(route.hops.Val)-1 { i.finalFailureReason = &reasonError } else { // If there are other hops between the recipient and // introduction node, then we just penalize the last // hop in the blinded route to minimize the storage of // results for ephemeral keys. - i.failPairBalance(route, len(route.hops)-1) + i.failPairBalance(route, len(route.hops.Val)-1) } // In all other cases, we penalize the reporting node. These are all @@ -522,8 +536,8 @@ func (i *interpretedResult) processPaymentOutcomeIntermediate(route *mcRoute, // (i.e., that we consider our own node to be at index zero). A boolean is // returned to indicate whether the route contains a blinded portion at all. func introductionPointIndex(route *mcRoute) (int, bool) { - for i, hop := range route.hops { - if hop.hasBlindingPoint { + for i, hop := range route.hops.Val { + if hop.hasBlindingPoint.IsSome() { return i + 1, true } } @@ -534,7 +548,7 @@ func introductionPointIndex(route *mcRoute) (int, bool) { // processPaymentOutcomeUnknown processes a payment outcome for which no failure // message or source is available. func (i *interpretedResult) processPaymentOutcomeUnknown(route *mcRoute) { - n := len(route.hops) + n := len(route.hops.Val) // If this is a direct payment, the destination must be at fault. if n == 1 { @@ -551,52 +565,204 @@ func (i *interpretedResult) processPaymentOutcomeUnknown(route *mcRoute) { // extractMCRoute extracts the fields required by MC from the Route struct to // create the more minimal mcRoute struct. -func extractMCRoute(route *route.Route) *mcRoute { +func extractMCRoute(r *route.Route) *mcRoute { return &mcRoute{ - sourcePubKey: route.SourcePubKey, - totalAmount: route.TotalAmount, - hops: extractMCHops(route.Hops), + sourcePubKey: tlv.NewRecordT[tlv.TlvType0](r.SourcePubKey), + totalAmount: tlv.NewRecordT[tlv.TlvType1](r.TotalAmount), + hops: tlv.NewRecordT[tlv.TlvType2]( + extractMCHops(r.Hops), + ), } } // extractMCHops extracts the Hop fields that MC actually uses from a slice of // Hops. -func extractMCHops(hops []*route.Hop) []*mcHop { - mcHops := make([]*mcHop, len(hops)) - for i, hop := range hops { - mcHops[i] = extractMCHop(hop) - } - - return mcHops +func extractMCHops(hops []*route.Hop) mcHops { + return fn.Map(extractMCHop, hops) } // extractMCHop extracts the Hop fields that MC actually uses from a Hop. func extractMCHop(hop *route.Hop) *mcHop { - return &mcHop{ - channelID: hop.ChannelID, - pubKeyBytes: hop.PubKeyBytes, - amtToFwd: hop.AmtToForward, - hasBlindingPoint: hop.BlindingPoint != nil, - hasCustomRecords: len(hop.CustomRecords) > 0, + h := mcHop{ + channelID: tlv.NewPrimitiveRecord[tlv.TlvType0]( + hop.ChannelID, + ), + pubKeyBytes: tlv.NewRecordT[tlv.TlvType1](hop.PubKeyBytes), + amtToFwd: tlv.NewRecordT[tlv.TlvType2](hop.AmtToForward), } + + if hop.BlindingPoint != nil { + h.hasBlindingPoint = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType3](lnwire.TrueBoolean{}), + ) + } + + if hop.CustomRecords != nil { + h.hasCustomRecords = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType4](lnwire.TrueBoolean{}), + ) + } + + return &h } // mcRoute holds the bare minimum info about a payment attempt route that MC // requires. type mcRoute struct { - sourcePubKey route.Vertex - totalAmount lnwire.MilliSatoshi - hops []*mcHop + sourcePubKey tlv.RecordT[tlv.TlvType0, route.Vertex] + totalAmount tlv.RecordT[tlv.TlvType1, lnwire.MilliSatoshi] + hops tlv.RecordT[tlv.TlvType2, mcHops] +} + +// Record returns a TLV record that can be used to encode/decode an mcRoute +// to/from a TLV stream. +func (r *mcRoute) Record() tlv.Record { + recordSize := func() uint64 { + var ( + b bytes.Buffer + buf [8]byte + ) + if err := encodeMCRoute(&b, r, &buf); err != nil { + panic(err) + } + + return uint64(len(b.Bytes())) + } + + return tlv.MakeDynamicRecord( + 0, r, recordSize, encodeMCRoute, decodeMCRoute, + ) +} + +func encodeMCRoute(w io.Writer, val interface{}, _ *[8]byte) error { + if v, ok := val.(*mcRoute); ok { + return serializeRoute(w, v) + } + + return tlv.NewTypeForEncodingErr(val, "routing.mcRoute") +} + +func decodeMCRoute(r io.Reader, val interface{}, _ *[8]byte, l uint64) error { + if v, ok := val.(*mcRoute); ok { + route, err := deserializeRoute(io.LimitReader(r, int64(l))) + if err != nil { + return err + } + + *v = *route + + return nil + } + + return tlv.NewTypeForDecodingErr(val, "routing.mcRoute", l, l) +} + +// mcHops is a list of mcHop records. +type mcHops []*mcHop + +// Record returns a TLV record that can be used to encode/decode a list of +// mcHop to/from a TLV stream. +func (h *mcHops) Record() tlv.Record { + recordSize := func() uint64 { + var ( + b bytes.Buffer + buf [8]byte + ) + if err := encodeMCHops(&b, h, &buf); err != nil { + panic(err) + } + + return uint64(len(b.Bytes())) + } + + return tlv.MakeDynamicRecord( + 0, h, recordSize, encodeMCHops, decodeMCHops, + ) +} + +func encodeMCHops(w io.Writer, val interface{}, buf *[8]byte) error { + if v, ok := val.(*mcHops); ok { + // Encode the number of hops as a var int. + if err := tlv.WriteVarInt(w, uint64(len(*v)), buf); err != nil { + return err + } + + // With that written out, we'll now encode the entries + // themselves as a sub-TLV record, which includes its _own_ + // inner length prefix. + for _, hop := range *v { + var hopBytes bytes.Buffer + if err := serializeHop(&hopBytes, hop); err != nil { + return err + } + + // We encode the record with a varint length followed by + // the _raw_ TLV bytes. + tlvLen := uint64(len(hopBytes.Bytes())) + if err := tlv.WriteVarInt(w, tlvLen, buf); err != nil { + return err + } + + if _, err := w.Write(hopBytes.Bytes()); err != nil { + return err + } + } + + return nil + } + + return tlv.NewTypeForEncodingErr(val, "routing.mcHops") +} + +func decodeMCHops(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { + if v, ok := val.(*mcHops); ok { + // First, we'll decode the varint that encodes how many hops + // are encoded in the stream. + numHops, err := tlv.ReadVarInt(r, buf) + if err != nil { + return err + } + + // Now that we know how many records we'll need to read, we can + // iterate and read them all out in series. + for i := uint64(0); i < numHops; i++ { + // Read out the varint that encodes the size of this + // inner TLV record. + hopSize, err := tlv.ReadVarInt(r, buf) + if err != nil { + return err + } + + // Using this information, we'll create a new limited + // reader that'll return an EOF once the end has been + // reached so the stream stops consuming bytes. + innerTlvReader := &io.LimitedReader{ + R: r, + N: int64(hopSize), + } + + hop, err := deserializeHop(innerTlvReader) + if err != nil { + return err + } + + *v = append(*v, hop) + } + + return nil + } + + return tlv.NewTypeForDecodingErr(val, "routing.mcHops", l, l) } // mcHop holds the bare minimum info about a payment attempt route hop that MC // requires. type mcHop struct { - channelID uint64 - pubKeyBytes route.Vertex - amtToFwd lnwire.MilliSatoshi - hasBlindingPoint bool - hasCustomRecords bool + channelID tlv.RecordT[tlv.TlvType0, uint64] + pubKeyBytes tlv.RecordT[tlv.TlvType1, route.Vertex] + amtToFwd tlv.RecordT[tlv.TlvType2, lnwire.MilliSatoshi] + hasBlindingPoint tlv.OptionalRecordT[tlv.TlvType3, lnwire.TrueBoolean] + hasCustomRecords tlv.OptionalRecordT[tlv.TlvType4, lnwire.TrueBoolean] } // failNode marks the node indicated by idx in the route as failed. It also @@ -604,7 +770,7 @@ type mcHop struct { // intentionally panics when the self node is failed. func (i *interpretedResult) failNode(rt *mcRoute, idx int) { // Mark the node as failing. - i.nodeFailure = &rt.hops[idx-1].pubKeyBytes + i.nodeFailure = &rt.hops.Val[idx-1].pubKeyBytes.Val // Mark the incoming connection as failed for the node. We intent to // penalize as much as we can for a node level failure, including future @@ -620,7 +786,7 @@ func (i *interpretedResult) failNode(rt *mcRoute, idx int) { // If not the ultimate node, mark the outgoing connection as failed for // the node. - if idx < len(rt.hops) { + if idx < len(rt.hops.Val) { outgoingChannelIdx := idx outPair, _ := getPair(rt, outgoingChannelIdx) i.pairResults[outPair] = failPairResult(0) @@ -667,18 +833,18 @@ func (i *interpretedResult) successPairRange(rt *mcRoute, fromIdx, toIdx int) { func getPair(rt *mcRoute, channelIdx int) (DirectedNodePair, lnwire.MilliSatoshi) { - nodeTo := rt.hops[channelIdx].pubKeyBytes + nodeTo := rt.hops.Val[channelIdx].pubKeyBytes.Val var ( nodeFrom route.Vertex amt lnwire.MilliSatoshi ) if channelIdx == 0 { - nodeFrom = rt.sourcePubKey - amt = rt.totalAmount + nodeFrom = rt.sourcePubKey.Val + amt = rt.totalAmount.Val } else { - nodeFrom = rt.hops[channelIdx-1].pubKeyBytes - amt = rt.hops[channelIdx-1].amtToFwd + nodeFrom = rt.hops.Val[channelIdx-1].pubKeyBytes.Val + amt = rt.hops.Val[channelIdx-1].amtToFwd.Val } pair := NewDirectedNodePair(nodeFrom, nodeTo) diff --git a/routing/result_interpretation_test.go b/routing/result_interpretation_test.go index 14506e3a14..b213eb1835 100644 --- a/routing/result_interpretation_test.go +++ b/routing/result_interpretation_test.go @@ -4,7 +4,9 @@ import ( "reflect" "testing" + "github.com/btcsuite/btcd/btcec/v2" "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" ) @@ -14,110 +16,170 @@ var ( {1, 0}, {1, 1}, {1, 2}, {1, 3}, {1, 4}, } - routeOneHop = mcRoute{ - sourcePubKey: hops[0], - totalAmount: 100, - hops: []*mcHop{ - {pubKeyBytes: hops[1], amtToFwd: 99}, + routeOneHop = extractMCRoute(&route.Route{ + SourcePubKey: hops[0], + TotalAmount: 100, + Hops: []*route.Hop{ + { + PubKeyBytes: hops[1], + AmtToForward: 99, + }, }, - } + }) - routeTwoHop = mcRoute{ - sourcePubKey: hops[0], - totalAmount: 100, - hops: []*mcHop{ - {pubKeyBytes: hops[1], amtToFwd: 99}, - {pubKeyBytes: hops[2], amtToFwd: 97}, + routeTwoHop = extractMCRoute(&route.Route{ + SourcePubKey: hops[0], + TotalAmount: 100, + Hops: []*route.Hop{ + { + PubKeyBytes: hops[1], + AmtToForward: 99, + }, + { + PubKeyBytes: hops[2], + AmtToForward: 97, + }, }, - } + }) - routeThreeHop = mcRoute{ - sourcePubKey: hops[0], - totalAmount: 100, - hops: []*mcHop{ - {pubKeyBytes: hops[1], amtToFwd: 99}, - {pubKeyBytes: hops[2], amtToFwd: 97}, - {pubKeyBytes: hops[3], amtToFwd: 94}, + routeThreeHop = extractMCRoute(&route.Route{ + SourcePubKey: hops[0], + TotalAmount: 100, + Hops: []*route.Hop{ + { + PubKeyBytes: hops[1], + AmtToForward: 99, + }, + { + PubKeyBytes: hops[2], + AmtToForward: 97, + }, + { + PubKeyBytes: hops[3], + AmtToForward: 94, + }, }, - } + }) - routeFourHop = mcRoute{ - sourcePubKey: hops[0], - totalAmount: 100, - hops: []*mcHop{ - {pubKeyBytes: hops[1], amtToFwd: 99}, - {pubKeyBytes: hops[2], amtToFwd: 97}, - {pubKeyBytes: hops[3], amtToFwd: 94}, - {pubKeyBytes: hops[4], amtToFwd: 90}, + routeFourHop = extractMCRoute(&route.Route{ + SourcePubKey: hops[0], + TotalAmount: 100, + Hops: []*route.Hop{ + { + PubKeyBytes: hops[1], + AmtToForward: 99, + }, + { + PubKeyBytes: hops[2], + AmtToForward: 97, + }, + { + PubKeyBytes: hops[3], + AmtToForward: 94, + }, + { + PubKeyBytes: hops[4], + AmtToForward: 90, + }, }, - } + }) // blindedMultiHop is a blinded path where there are cleartext hops // before the introduction node, and an intermediate blinded hop before // the recipient after it. - blindedMultiHop = mcRoute{ - sourcePubKey: hops[0], - totalAmount: 100, - hops: []*mcHop{ - {pubKeyBytes: hops[1], amtToFwd: 99}, + blindedMultiHop = extractMCRoute(&route.Route{ + SourcePubKey: hops[0], + TotalAmount: 100, + Hops: []*route.Hop{ { - pubKeyBytes: hops[2], - amtToFwd: 95, - hasBlindingPoint: true, + PubKeyBytes: hops[1], + AmtToForward: 99, + }, + { + PubKeyBytes: hops[2], + AmtToForward: 95, + BlindingPoint: genTestPubKey(), + }, + { + PubKeyBytes: hops[3], + AmtToForward: 88, + }, + { + PubKeyBytes: hops[4], + AmtToForward: 77, }, - {pubKeyBytes: hops[3], amtToFwd: 88}, - {pubKeyBytes: hops[4], amtToFwd: 77}, }, - } + }) // blindedSingleHop is a blinded path with a single blinded hop after // the introduction node. - blindedSingleHop = mcRoute{ - sourcePubKey: hops[0], - totalAmount: 100, - hops: []*mcHop{ - {pubKeyBytes: hops[1], amtToFwd: 99}, + blindedSingleHop = extractMCRoute(&route.Route{ + SourcePubKey: hops[0], + TotalAmount: 100, + Hops: []*route.Hop{ { - pubKeyBytes: hops[2], - amtToFwd: 95, - hasBlindingPoint: true, + PubKeyBytes: hops[1], + AmtToForward: 99, + }, + { + PubKeyBytes: hops[2], + AmtToForward: 95, + BlindingPoint: genTestPubKey(), + }, + { + PubKeyBytes: hops[3], + AmtToForward: 88, }, - {pubKeyBytes: hops[3], amtToFwd: 88}, }, - } + }) // blindedMultiToIntroduction is a blinded path which goes directly // to the introduction node, with multiple blinded hops after it. - blindedMultiToIntroduction = mcRoute{ - sourcePubKey: hops[0], - totalAmount: 100, - hops: []*mcHop{ + blindedMultiToIntroduction = extractMCRoute(&route.Route{ + SourcePubKey: hops[0], + TotalAmount: 100, + Hops: []*route.Hop{ + { + PubKeyBytes: hops[1], + AmtToForward: 90, + BlindingPoint: genTestPubKey(), + }, + { + PubKeyBytes: hops[2], + AmtToForward: 75, + }, { - pubKeyBytes: hops[1], - amtToFwd: 90, - hasBlindingPoint: true, + PubKeyBytes: hops[3], + AmtToForward: 58, }, - {pubKeyBytes: hops[2], amtToFwd: 75}, - {pubKeyBytes: hops[3], amtToFwd: 58}, }, - } + }) // blindedIntroReceiver is a blinded path where the introduction node // is the recipient. - blindedIntroReceiver = mcRoute{ - sourcePubKey: hops[0], - totalAmount: 100, - hops: []*mcHop{ - {pubKeyBytes: hops[1], amtToFwd: 95}, + blindedIntroReceiver = extractMCRoute(&route.Route{ + SourcePubKey: hops[0], + TotalAmount: 100, + Hops: []*route.Hop{ + { + PubKeyBytes: hops[1], + AmtToForward: 95, + }, { - pubKeyBytes: hops[2], - amtToFwd: 90, - hasBlindingPoint: true, + PubKeyBytes: hops[2], + AmtToForward: 90, + BlindingPoint: genTestPubKey(), }, }, - } + }) ) +func genTestPubKey() *btcec.PublicKey { + key, _ := btcec.NewPrivateKey() + + return key.PubKey() +} + func getTestPair(from, to int) DirectedNodePair { return NewDirectedNodePair(hops[from], hops[to]) } @@ -142,7 +204,7 @@ var resultTestCases = []resultTestCase{ // interpreted. { name: "fail", - route: &routeTwoHop, + route: routeTwoHop, failureSrcIdx: 1, failure: lnwire.NewTemporaryChannelFailure(nil), @@ -157,7 +219,7 @@ var resultTestCases = []resultTestCase{ // Tests that an expiry too soon failure result is properly interpreted. { name: "fail expiry too soon", - route: &routeFourHop, + route: routeFourHop, failureSrcIdx: 3, failure: lnwire.NewExpiryTooSoon(lnwire.ChannelUpdate1{}), @@ -177,7 +239,7 @@ var resultTestCases = []resultTestCase{ // failure, but mark all pairs along the route as successful. { name: "fail incorrect details", - route: &routeTwoHop, + route: routeTwoHop, failureSrcIdx: 2, failure: lnwire.NewFailIncorrectDetails(97, 0), @@ -193,7 +255,7 @@ var resultTestCases = []resultTestCase{ // Tests a successful direct payment. { name: "success direct", - route: &routeOneHop, + route: routeOneHop, success: true, expectedResult: &interpretedResult{ @@ -206,7 +268,7 @@ var resultTestCases = []resultTestCase{ // Tests a successful two hop payment. { name: "success", - route: &routeTwoHop, + route: routeTwoHop, success: true, expectedResult: &interpretedResult{ @@ -220,7 +282,7 @@ var resultTestCases = []resultTestCase{ // Tests a malformed htlc from a direct peer. { name: "fail malformed htlc from direct peer", - route: &routeTwoHop, + route: routeTwoHop, failureSrcIdx: 0, failure: lnwire.NewInvalidOnionKey(nil), @@ -239,7 +301,7 @@ var resultTestCases = []resultTestCase{ // destination. { name: "fail malformed htlc from direct final peer", - route: &routeOneHop, + route: routeOneHop, failureSrcIdx: 0, failure: lnwire.NewInvalidOnionKey(nil), @@ -259,7 +321,7 @@ var resultTestCases = []resultTestCase{ // in a policy failure for the outgoing hop. { name: "fail fee insufficient intermediate", - route: &routeFourHop, + route: routeFourHop, failureSrcIdx: 2, failure: lnwire.NewFeeInsufficient( 0, lnwire.ChannelUpdate1{}, @@ -282,7 +344,7 @@ var resultTestCases = []resultTestCase{ // failure is terminal since the receiver can't process our onion. { name: "fail invalid onion payload final hop four", - route: &routeFourHop, + route: routeFourHop, failureSrcIdx: 4, failure: lnwire.NewInvalidOnionPayload(0, 0), @@ -311,7 +373,7 @@ var resultTestCases = []resultTestCase{ // Tests an invalid onion payload from a final hop on a three hop route. { name: "fail invalid onion payload final hop three", - route: &routeThreeHop, + route: routeThreeHop, failureSrcIdx: 3, failure: lnwire.NewInvalidOnionPayload(0, 0), @@ -338,7 +400,7 @@ var resultTestCases = []resultTestCase{ // can still try other paths. { name: "fail invalid onion payload intermediate", - route: &routeFourHop, + route: routeFourHop, failureSrcIdx: 3, failure: lnwire.NewInvalidOnionPayload(0, 0), @@ -366,7 +428,7 @@ var resultTestCases = []resultTestCase{ // since the remote node can't process our onion. { name: "fail invalid onion payload direct", - route: &routeOneHop, + route: routeOneHop, failureSrcIdx: 1, failure: lnwire.NewInvalidOnionPayload(0, 0), @@ -385,7 +447,7 @@ var resultTestCases = []resultTestCase{ // penalize mpp timeouts. { name: "one hop mpp timeout", - route: &routeOneHop, + route: routeOneHop, failureSrcIdx: 1, failure: &lnwire.FailMPPTimeout{}, @@ -402,7 +464,7 @@ var resultTestCases = []resultTestCase{ // temporary measure while we decide how to penalize mpp timeouts. { name: "two hop mpp timeout", - route: &routeTwoHop, + route: routeTwoHop, failureSrcIdx: 2, failure: &lnwire.FailMPPTimeout{}, @@ -419,7 +481,7 @@ var resultTestCases = []resultTestCase{ // disabled channel should be penalized for any amount. { name: "two hop channel disabled", - route: &routeTwoHop, + route: routeTwoHop, failureSrcIdx: 1, failure: &lnwire.FailChannelDisabled{}, @@ -437,7 +499,7 @@ var resultTestCases = []resultTestCase{ // has not followed the specification properly. { name: "error after introduction", - route: &blindedMultiToIntroduction, + route: blindedMultiToIntroduction, failureSrcIdx: 2, // Note that the failure code doesn't matter in this case - // all we're worried about is errors originating after the @@ -460,7 +522,7 @@ var resultTestCases = []resultTestCase{ // hop when we expected the introduction node to convert. { name: "final failure expected intro", - route: &blindedMultiHop, + route: blindedMultiHop, failureSrcIdx: 4, failure: &lnwire.FailInvalidBlinding{}, @@ -482,7 +544,7 @@ var resultTestCases = []resultTestCase{ // introduction point. { name: "blinded multi-hop introduction", - route: &blindedMultiHop, + route: blindedMultiHop, failureSrcIdx: 2, failure: &lnwire.FailInvalidBlinding{}, @@ -498,7 +560,7 @@ var resultTestCases = []resultTestCase{ // introduction point, which is a direct peer. { name: "blinded multi-hop introduction peer", - route: &blindedMultiToIntroduction, + route: blindedMultiToIntroduction, failureSrcIdx: 1, failure: &lnwire.FailInvalidBlinding{}, @@ -513,7 +575,7 @@ var resultTestCases = []resultTestCase{ // connected to the introduction node. { name: "blinded single hop introduction failure", - route: &blindedSingleHop, + route: blindedSingleHop, failureSrcIdx: 2, failure: &lnwire.FailInvalidBlinding{}, @@ -529,7 +591,7 @@ var resultTestCases = []resultTestCase{ // blinding error and is penalized for returning the wrong error. { name: "error before introduction", - route: &blindedMultiHop, + route: blindedMultiHop, failureSrcIdx: 1, failure: &lnwire.FailInvalidBlinding{}, @@ -549,7 +611,7 @@ var resultTestCases = []resultTestCase{ // successful hop before the incorrect error. { name: "intermediate unexpected blinding", - route: &routeThreeHop, + route: routeThreeHop, failureSrcIdx: 2, failure: &lnwire.FailInvalidBlinding{}, @@ -570,7 +632,7 @@ var resultTestCases = []resultTestCase{ // hops before the erring incoming link (the erring node if our peer). { name: "peer unexpected blinding", - route: &routeThreeHop, + route: routeThreeHop, failureSrcIdx: 1, failure: &lnwire.FailInvalidBlinding{}, @@ -588,7 +650,7 @@ var resultTestCases = []resultTestCase{ // A node in a non-blinded route returns a blinding related error. { name: "final node unexpected blinding", - route: &routeThreeHop, + route: routeThreeHop, failureSrcIdx: 3, failure: &lnwire.FailInvalidBlinding{}, @@ -606,7 +668,7 @@ var resultTestCases = []resultTestCase{ // Introduction node returns invalid blinding erroneously. { name: "final node intro blinding", - route: &blindedIntroReceiver, + route: blindedIntroReceiver, failureSrcIdx: 2, failure: &lnwire.FailInvalidBlinding{}, @@ -629,10 +691,15 @@ func TestResultInterpretation(t *testing.T) { for _, testCase := range resultTestCases { t.Run(testCase.name, func(t *testing.T) { - i := interpretResult( - testCase.route, testCase.success, - &testCase.failureSrcIdx, testCase.failure, - ) + var failure fn.Option[paymentFailure] + if !testCase.success { + failure = fn.Some(*newPaymentFailure( + &testCase.failureSrcIdx, + testCase.failure, + )) + } + + i := interpretResult(testCase.route, failure) expected := testCase.expectedResult diff --git a/routing/route/route.go b/routing/route/route.go index 9aa28759bc..2d62af4653 100644 --- a/routing/route/route.go +++ b/routing/route/route.go @@ -94,6 +94,32 @@ func (v Vertex) String() string { return fmt.Sprintf("%x", v[:]) } +// Record returns a TLV record that can be used to encode/decode a Vertex +// to/from a TLV stream. +func (v *Vertex) Record() tlv.Record { + return tlv.MakeStaticRecord( + 0, v, VertexSize, encodeVertex, decodeVertex, + ) +} + +func encodeVertex(w io.Writer, val interface{}, _ *[8]byte) error { + if b, ok := val.(*Vertex); ok { + _, err := w.Write(b[:]) + return err + } + + return tlv.NewTypeForEncodingErr(val, "Vertex") +} + +func decodeVertex(r io.Reader, val interface{}, _ *[8]byte, l uint64) error { + if b, ok := val.(*Vertex); ok { + _, err := io.ReadFull(r, b[:]) + return err + } + + return tlv.NewTypeForDecodingErr(val, "Vertex", l, VertexSize) +} + // Hop represents an intermediate or final node of the route. This naming // is in line with the definition given in BOLT #4: Onion Routing Protocol. // The struct houses the channel along which this hop can be reached and