diff --git a/channeldb/mp_payment.go b/channeldb/mp_payment.go index cf5669a5057..b94bfd418d1 100644 --- a/channeldb/mp_payment.go +++ b/channeldb/mp_payment.go @@ -10,7 +10,10 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/wire" + "github.com/davecgh/go-spew/spew" + sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnutils" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" ) @@ -45,12 +48,19 @@ type HTLCAttemptInfo struct { // in which the payment's PaymentHash in the PaymentCreationInfo should // be used. Hash *lntypes.Hash + + // onionBlob is the cached value for onion blob created from the sphinx + // construction. + onionBlob [lnwire.OnionPacketSize]byte + + // circuit is the cached value for sphinx circuit. + circuit *sphinx.Circuit } // NewHtlcAttempt creates a htlc attempt. func NewHtlcAttempt(attemptID uint64, sessionKey *btcec.PrivateKey, route route.Route, attemptTime time.Time, - hash *lntypes.Hash) *HTLCAttempt { + hash *lntypes.Hash) (*HTLCAttempt, error) { var scratch [btcec.PrivKeyBytesLen]byte copy(scratch[:], sessionKey.Serialize()) @@ -64,7 +74,11 @@ func NewHtlcAttempt(attemptID uint64, sessionKey *btcec.PrivateKey, Hash: hash, } - return &HTLCAttempt{HTLCAttemptInfo: info} + if err := info.attachOnionBlobAndCircuit(); err != nil { + return nil, err + } + + return &HTLCAttempt{HTLCAttemptInfo: info}, nil } // SessionKey returns the ephemeral key used for a htlc attempt. This function @@ -79,6 +93,45 @@ func (h *HTLCAttemptInfo) SessionKey() *btcec.PrivateKey { return h.cachedSessionKey } +// OnionBlob returns the onion blob created from the sphinx construction. +func (h *HTLCAttemptInfo) OnionBlob() ([lnwire.OnionPacketSize]byte, error) { + var zeroBytes [lnwire.OnionPacketSize]byte + if h.onionBlob == zeroBytes { + if err := h.attachOnionBlobAndCircuit(); err != nil { + return zeroBytes, err + } + } + + return h.onionBlob, nil +} + +// Circuit returns the sphinx circuit for this attempt. +func (h *HTLCAttemptInfo) Circuit() (*sphinx.Circuit, error) { + if h.circuit == nil { + if err := h.attachOnionBlobAndCircuit(); err != nil { + return nil, err + } + } + + return h.circuit, nil +} + +// attachOnionBlobAndCircuit creates a sphinx packet and caches the onion blob +// and circuit for this attempt. +func (h *HTLCAttemptInfo) attachOnionBlobAndCircuit() error { + onionBlob, circuit, err := generateSphinxPacket( + &h.Route, h.Hash[:], h.SessionKey(), + ) + if err != nil { + return err + } + + copy(h.onionBlob[:], onionBlob) + h.circuit = circuit + + return nil +} + // HTLCAttempt contains information about a specific HTLC attempt for a given // payment. It contains the HTLCAttemptInfo used to send the HTLC, as well // as a timestamp and any known outcome of the attempt. @@ -629,3 +682,69 @@ func serializeTime(w io.Writer, t time.Time) error { _, err := w.Write(scratch[:]) return err } + +// generateSphinxPacket generates then encodes a sphinx packet which encodes +// the onion route specified by the passed layer 3 route. The blob returned +// from this function can immediately be included within an HTLC add packet to +// be sent to the first hop within the route. +func generateSphinxPacket(rt *route.Route, paymentHash []byte, + sessionKey *btcec.PrivateKey) ([]byte, *sphinx.Circuit, error) { + + // Now that we know we have an actual route, we'll map the route into a + // sphinx payment path which includes per-hop payloads for each hop + // that give each node within the route the necessary information + // (fees, CLTV value, etc.) to properly forward the payment. + sphinxPath, err := rt.ToSphinxPath() + if err != nil { + return nil, nil, err + } + + log.Tracef("Constructed per-hop payloads for payment_hash=%x: %v", + paymentHash, lnutils.NewLogClosure(func() string { + path := make( + []sphinx.OnionHop, sphinxPath.TrueRouteLength(), + ) + for i := range path { + hopCopy := sphinxPath[i] + path[i] = hopCopy + } + + return spew.Sdump(path) + }), + ) + + // Next generate the onion routing packet which allows us to perform + // privacy preserving source routing across the network. + sphinxPacket, err := sphinx.NewOnionPacket( + sphinxPath, sessionKey, paymentHash, + sphinx.DeterministicPacketFiller, + ) + if err != nil { + return nil, nil, err + } + + // Finally, encode Sphinx packet using its wire representation to be + // included within the HTLC add packet. + var onionBlob bytes.Buffer + if err := sphinxPacket.Encode(&onionBlob); err != nil { + return nil, nil, err + } + + log.Tracef("Generated sphinx packet: %v", + lnutils.NewLogClosure(func() string { + // We make a copy of the ephemeral key and unset the + // internal curve here in order to keep the logs from + // getting noisy. + key := *sphinxPacket.EphemeralKey + packetCopy := *sphinxPacket + packetCopy.EphemeralKey = &key + + return spew.Sdump(packetCopy) + }), + ) + + return onionBlob.Bytes(), &sphinx.Circuit{ + SessionKey: sessionKey, + PaymentPath: sphinxPath.NodeKeys(), + }, nil +} diff --git a/channeldb/mp_payment_test.go b/channeldb/mp_payment_test.go index 51eda72bb0a..13e39871a98 100644 --- a/channeldb/mp_payment_test.go +++ b/channeldb/mp_payment_test.go @@ -5,12 +5,22 @@ import ( "fmt" "testing" + "github.com/btcsuite/btcd/btcec/v2" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" "github.com/stretchr/testify/require" ) +var ( + testHash = [32]byte{ + 0xb7, 0x94, 0x38, 0x5f, 0x2d, 0x1e, 0xf7, 0xab, + 0x4d, 0x92, 0x73, 0xd1, 0x90, 0x63, 0x81, 0xb4, + 0x4f, 0x2f, 0x6f, 0x25, 0x88, 0xa3, 0xef, 0xb9, + 0x6a, 0x49, 0x18, 0x83, 0x31, 0x98, 0x47, 0x53, + } +) + // TestLazySessionKeyDeserialize tests that we can read htlc attempt session // keys that were previously serialized as a private key as raw bytes. func TestLazySessionKeyDeserialize(t *testing.T) { @@ -578,3 +588,15 @@ func makeAttemptInfo(total, amtForwarded int) HTLCAttemptInfo { }, } } + +// TestEmptyRoutesGenerateSphinxPacket tests that the generateSphinxPacket +// function is able to gracefully handle being passed a nil set of hops for the +// route by the caller. +func TestEmptyRoutesGenerateSphinxPacket(t *testing.T) { + t.Parallel() + + sessionKey, _ := btcec.NewPrivateKey() + emptyRoute := &route.Route{} + _, _, err := generateSphinxPacket(emptyRoute, testHash[:], sessionKey) + require.ErrorIs(t, err, route.ErrNoRouteHopsProvided) +} diff --git a/channeldb/payment_control_test.go b/channeldb/payment_control_test.go index fb965bb321c..9bc486a8311 100644 --- a/channeldb/payment_control_test.go +++ b/channeldb/payment_control_test.go @@ -28,7 +28,7 @@ func genPreimage() ([32]byte, error) { return preimage, nil } -func genInfo() (*PaymentCreationInfo, *HTLCAttemptInfo, +func genInfo(t *testing.T) (*PaymentCreationInfo, *HTLCAttemptInfo, lntypes.Preimage, error) { preimage, err := genPreimage() @@ -38,9 +38,14 @@ func genInfo() (*PaymentCreationInfo, *HTLCAttemptInfo, } rhash := sha256.Sum256(preimage[:]) - attempt := NewHtlcAttempt( - 0, priv, *testRoute.Copy(), time.Time{}, nil, + var hash lntypes.Hash + copy(hash[:], rhash[:]) + + attempt, err := NewHtlcAttempt( + 0, priv, *testRoute.Copy(), time.Time{}, &hash, ) + require.NoError(t, err) + return &PaymentCreationInfo{ PaymentIdentifier: rhash, Value: testRoute.ReceiverAmt(), @@ -60,7 +65,7 @@ func TestPaymentControlSwitchFail(t *testing.T) { pControl := NewPaymentControl(db) - info, attempt, preimg, err := genInfo() + info, attempt, preimg, err := genInfo(t) require.NoError(t, err, "unable to generate htlc message") // Sends base htlc message which initiate StatusInFlight. @@ -196,7 +201,7 @@ func TestPaymentControlSwitchDoubleSend(t *testing.T) { pControl := NewPaymentControl(db) - info, attempt, preimg, err := genInfo() + info, attempt, preimg, err := genInfo(t) require.NoError(t, err, "unable to generate htlc message") // Sends base htlc message which initiate base status and move it to @@ -266,7 +271,7 @@ func TestPaymentControlSuccessesWithoutInFlight(t *testing.T) { pControl := NewPaymentControl(db) - info, _, preimg, err := genInfo() + info, _, preimg, err := genInfo(t) require.NoError(t, err, "unable to generate htlc message") // Attempt to complete the payment should fail. @@ -291,7 +296,7 @@ func TestPaymentControlFailsWithoutInFlight(t *testing.T) { pControl := NewPaymentControl(db) - info, _, _, err := genInfo() + info, _, _, err := genInfo(t) require.NoError(t, err, "unable to generate htlc message") // Calling Fail should return an error. @@ -346,7 +351,7 @@ func TestPaymentControlDeleteNonInFlight(t *testing.T) { var numSuccess, numInflight int for _, p := range payments { - info, attempt, preimg, err := genInfo() + info, attempt, preimg, err := genInfo(t) if err != nil { t.Fatalf("unable to generate htlc message: %v", err) } @@ -684,7 +689,7 @@ func TestPaymentControlMultiShard(t *testing.T) { pControl := NewPaymentControl(db) - info, attempt, preimg, err := genInfo() + info, attempt, preimg, err := genInfo(t) if err != nil { t.Fatalf("unable to generate htlc message: %v", err) } @@ -948,7 +953,7 @@ func TestPaymentControlMPPRecordValidation(t *testing.T) { pControl := NewPaymentControl(db) - info, attempt, _, err := genInfo() + info, attempt, _, err := genInfo(t) require.NoError(t, err, "unable to generate htlc message") // Init the payment. @@ -997,7 +1002,7 @@ func TestPaymentControlMPPRecordValidation(t *testing.T) { // Create and init a new payment. This time we'll check that we cannot // register an MPP attempt if we already registered a non-MPP one. - info, attempt, _, err = genInfo() + info, attempt, _, err = genInfo(t) require.NoError(t, err, "unable to generate htlc message") err = pControl.InitPayment(info.PaymentIdentifier, info) @@ -1271,7 +1276,7 @@ func createTestPayments(t *testing.T, p *PaymentControl, payments []*payment) { attemptID := uint64(0) for i := 0; i < len(payments); i++ { - info, attempt, preimg, err := genInfo() + info, attempt, preimg, err := genInfo(t) require.NoError(t, err, "unable to generate htlc message") // Set the payment id accordingly in the payments slice. diff --git a/channeldb/payments_test.go b/channeldb/payments_test.go index 0c3753e6626..b2a0292a490 100644 --- a/channeldb/payments_test.go +++ b/channeldb/payments_test.go @@ -64,7 +64,6 @@ var ( TotalAmount: 1234567, SourcePubKey: vertex, Hops: []*route.Hop{ - testHop3, testHop2, testHop1, }, @@ -98,7 +97,7 @@ var ( } ) -func makeFakeInfo() (*PaymentCreationInfo, *HTLCAttemptInfo) { +func makeFakeInfo(t *testing.T) (*PaymentCreationInfo, *HTLCAttemptInfo) { var preimg lntypes.Preimage copy(preimg[:], rev[:]) @@ -113,9 +112,10 @@ func makeFakeInfo() (*PaymentCreationInfo, *HTLCAttemptInfo) { PaymentRequest: []byte("test"), } - a := NewHtlcAttempt( + a, err := NewHtlcAttempt( 44, priv, testRoute, time.Unix(100, 0), &hash, ) + require.NoError(t, err) return c, &a.HTLCAttemptInfo } @@ -123,7 +123,7 @@ func makeFakeInfo() (*PaymentCreationInfo, *HTLCAttemptInfo) { func TestSentPaymentSerialization(t *testing.T) { t.Parallel() - c, s := makeFakeInfo() + c, s := makeFakeInfo(t) var b bytes.Buffer require.NoError(t, serializePaymentCreationInfo(&b, c), "serialize") @@ -174,6 +174,9 @@ func TestSentPaymentSerialization(t *testing.T) { require.NoError(t, err, "deserialize") require.Equal(t, s.Route, newWireInfo.Route) + err = newWireInfo.attachOnionBlobAndCircuit() + require.NoError(t, err) + // Clear routes to allow DeepEqual to compare the remaining fields. newWireInfo.Route = route.Route{} s.Route = route.Route{} @@ -517,7 +520,7 @@ func TestQueryPayments(t *testing.T) { for i := 0; i < nonDuplicatePayments; i++ { // Generate a test payment. - info, _, preimg, err := genInfo() + info, _, preimg, err := genInfo(t) if err != nil { t.Fatalf("unable to create test "+ "payment: %v", err) @@ -618,7 +621,7 @@ func TestFetchPaymentWithSequenceNumber(t *testing.T) { pControl := NewPaymentControl(db) // Generate a test payment which does not have duplicates. - noDuplicates, _, _, err := genInfo() + noDuplicates, _, _, err := genInfo(t) require.NoError(t, err) // Create a new payment entry in the database. @@ -632,7 +635,7 @@ func TestFetchPaymentWithSequenceNumber(t *testing.T) { require.NoError(t, err) // Generate a test payment which we will add duplicates to. - hasDuplicates, _, preimg, err := genInfo() + hasDuplicates, _, preimg, err := genInfo(t) require.NoError(t, err) // Create a new payment entry in the database. @@ -783,7 +786,7 @@ func putDuplicatePayment(t *testing.T, duplicateBucket kvdb.RwBucket, require.NoError(t, err) // Generate fake information for the duplicate payment. - info, _, _, err := genInfo() + info, _, _, err := genInfo(t) require.NoError(t, err) // Write the payment info to disk under the creation info key. This code diff --git a/docs/release-notes/release-notes-0.19.0.md b/docs/release-notes/release-notes-0.19.0.md index 44df0bd3c3b..9f09ceade6e 100644 --- a/docs/release-notes/release-notes-0.19.0.md +++ b/docs/release-notes/release-notes-0.19.0.md @@ -68,6 +68,11 @@ fail to persist (and hence, propagate) node announcements containing address types (such as a DNS hostname) unknown to LND. +* [Fixed an edge case](https://github.com/lightningnetwork/lnd/pull/9150) where + the payment may become stuck if the invoice times out while the node + restarts, for details check [this + issue](https://github.com/lightningnetwork/lnd/issues/8975#issuecomment-2270528222). + # New Features * [Support](https://github.com/lightningnetwork/lnd/pull/8390) for diff --git a/itest/lnd_multi-hop_force_close_test.go b/itest/lnd_multi-hop_force_close_test.go index a7db9ea2888..da760a2895b 100644 --- a/itest/lnd_multi-hop_force_close_test.go +++ b/itest/lnd_multi-hop_force_close_test.go @@ -357,8 +357,11 @@ func runLocalClaimOutgoingHTLC(ht *lntest.HarnessTest, // We'll create two random payment hashes unknown to carol, then send // each of them by manually specifying the HTLC details. carolPubKey := carol.PubKey[:] - dustPayHash := ht.Random32Bytes() - payHash := ht.Random32Bytes() + + preimageDust := ht.RandomPreimage() + preimage := ht.RandomPreimage() + dustPayHash := preimageDust.Hash() + payHash := preimage.Hash() // If this is a taproot channel, then we'll need to make some manual // route hints so Alice can actually find a route. @@ -370,7 +373,7 @@ func runLocalClaimOutgoingHTLC(ht *lntest.HarnessTest, req := &routerrpc.SendPaymentRequest{ Dest: carolPubKey, Amt: int64(dustHtlcAmt), - PaymentHash: dustPayHash, + PaymentHash: dustPayHash[:], FinalCltvDelta: finalCltvDelta, FeeLimitMsat: noFeeLimitMsat, RouteHints: routeHints, @@ -380,7 +383,7 @@ func runLocalClaimOutgoingHTLC(ht *lntest.HarnessTest, req = &routerrpc.SendPaymentRequest{ Dest: carolPubKey, Amt: int64(htlcAmt), - PaymentHash: payHash, + PaymentHash: payHash[:], FinalCltvDelta: finalCltvDelta, FeeLimitMsat: noFeeLimitMsat, RouteHints: routeHints, @@ -530,6 +533,25 @@ func runLocalClaimOutgoingHTLC(ht *lntest.HarnessTest, // Once this transaction has been confirmed, Bob should detect that he // no longer has any pending channels. ht.AssertNumPendingForceClose(bob, 0) + + // Now that Bob has claimed his HTLCs, Alice should mark the two + // payments as failed. + // + // Alice will mark this payment as failed with no route as the only + // route she has is Alice->Bob->Carol. This won't be the case if she + // has a second route, as another attempt will be tried. + // + // TODO(yy): we should instead mark this payment as timed out if she has + // a second route to try this payment, which is the timeout set by Alice + // when sending the payment. + expectedReason := lnrpc.PaymentFailureReason_FAILURE_REASON_NO_ROUTE + p := ht.AssertPaymentFailureReason(alice, preimage, expectedReason) + require.Equal(ht, lnrpc.Failure_PERMANENT_CHANNEL_FAILURE, + p.Htlcs[0].Failure.Code) + + p = ht.AssertPaymentFailureReason(alice, preimageDust, expectedReason) + require.Equal(ht, lnrpc.Failure_PERMANENT_CHANNEL_FAILURE, + p.Htlcs[0].Failure.Code) } // testMultiHopReceiverPreimageClaimAnchor tests diff --git a/lntest/harness_assertion.go b/lntest/harness_assertion.go index 8130ac1679e..9af81a7cd2a 100644 --- a/lntest/harness_assertion.go +++ b/lntest/harness_assertion.go @@ -1600,8 +1600,11 @@ func (h *HarnessTest) AssertPaymentStatus(hn *node.HarnessNode, // AssertPaymentFailureReason asserts that the given node lists a payment with // the given preimage which has the expected failure reason. -func (h *HarnessTest) AssertPaymentFailureReason(hn *node.HarnessNode, - preimage lntypes.Preimage, reason lnrpc.PaymentFailureReason) { +func (h *HarnessTest) AssertPaymentFailureReason( + hn *node.HarnessNode, preimage lntypes.Preimage, + reason lnrpc.PaymentFailureReason) *lnrpc.Payment { + + var payment *lnrpc.Payment payHash := preimage.Hash() err := wait.NoError(func() error { @@ -1610,14 +1613,19 @@ func (h *HarnessTest) AssertPaymentFailureReason(hn *node.HarnessNode, return err } + payment = p + if reason == p.FailureReason { return nil } return fmt.Errorf("payment: %v failure reason not match, "+ - "want %s got %s", payHash, reason, p.Status) + "want %s(%d) got %s(%d)", payHash, reason, reason, + p.FailureReason, p.FailureReason) }, DefaultTimeout) require.NoError(h, err, "timeout checking payment failure reason") + + return payment } // AssertActiveNodesSynced asserts all active nodes have synced to the chain. diff --git a/routing/control_tower_test.go b/routing/control_tower_test.go index 5fb271afa61..532e639d6aa 100644 --- a/routing/control_tower_test.go +++ b/routing/control_tower_test.go @@ -535,15 +535,23 @@ func genInfo() (*channeldb.PaymentCreationInfo, *channeldb.HTLCAttemptInfo, } rhash := sha256.Sum256(preimage[:]) + var hash lntypes.Hash + copy(hash[:], rhash[:]) + + attempt, err := channeldb.NewHtlcAttempt( + 1, priv, testRoute, time.Time{}, &hash, + ) + if err != nil { + return nil, nil, lntypes.Preimage{}, err + } + return &channeldb.PaymentCreationInfo{ PaymentIdentifier: rhash, Value: testRoute.ReceiverAmt(), CreationTime: time.Unix(time.Now().Unix(), 0), PaymentRequest: []byte("hola"), }, - &channeldb.NewHtlcAttempt( - 1, priv, testRoute, time.Time{}, nil, - ).HTLCAttemptInfo, preimage, nil + &attempt.HTLCAttemptInfo, preimage, nil } func genPreimage() ([32]byte, error) { diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index 180d38a631b..9a741407930 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -24,6 +24,17 @@ import ( // the payment lifecycle is exiting . var ErrPaymentLifecycleExiting = errors.New("payment lifecycle exiting") +// switchResult is the result sent back from the switch after processing the +// HTLC. +type switchResult struct { + // attempt is the HTLC sent to the switch. + attempt *channeldb.HTLCAttempt + + // result is sent from the switch which contains either a preimage if + // ths HTLC is settled or an error if it's failed. + result *htlcswitch.PaymentResult +} + // paymentLifecycle holds all information about the current state of a payment // needed to resume if from any point. type paymentLifecycle struct { @@ -39,11 +50,9 @@ type paymentLifecycle struct { // to stop. quit chan struct{} - // resultCollected is used to signal that the result of an attempt has - // been collected. A nil error means the attempt is either successful - // or failed with temporary error. Otherwise, we should exit the - // lifecycle loop as a terminal error has occurred. - resultCollected chan error + // resultCollected is used to send the result returned from the switch + // for a given HTLC attempt. + resultCollected chan *switchResult // resultCollector is a function that is used to collect the result of // an HTLC attempt, which is always mounted to `p.collectResultAsync` @@ -66,7 +75,7 @@ func newPaymentLifecycle(r *ChannelRouter, feeLimit lnwire.MilliSatoshi, shardTracker: shardTracker, currentHeight: currentHeight, quit: make(chan struct{}), - resultCollected: make(chan error, 1), + resultCollected: make(chan *switchResult, 1), firstHopCustomRecords: firstHopCustomRecords, } @@ -112,6 +121,13 @@ const ( ) // decideNextStep is used to determine the next step in the payment lifecycle. +// It first checks whether the current state of the payment allows more HTLC +// attempts to be made. If allowed, it will return so the lifecycle can continue +// making new attempts. Otherwise, it checks whether we need to wait for the +// results of already sent attempts. If needed, it will block until one of the +// results is sent back. then process its result here. When there's no need to +// wait for results, the method will exit with `stepExit` such that the payment +// lifecycle loop will terminate. func (p *paymentLifecycle) decideNextStep( payment DBMPPayment) (stateStep, error) { @@ -121,46 +137,53 @@ func (p *paymentLifecycle) decideNextStep( return stepExit, err } - if !allow { - // Check whether we need to wait for results. - wait, err := payment.NeedWaitAttempts() - if err != nil { - return stepExit, err - } + // Exit early we need to make more attempts. + if allow { + return stepProceed, nil + } - // If we are not allowed to make new HTLC attempts and there's - // no need to wait, the lifecycle is done and we can exit. - if !wait { - return stepExit, nil - } + // We cannot make more attempts, we now check whether we need to wait + // for results. + wait, err := payment.NeedWaitAttempts() + if err != nil { + return stepExit, err + } - log.Tracef("Waiting for attempt results for payment %v", - p.identifier) + // If we are not allowed to make new HTLC attempts and there's no need + // to wait, the lifecycle is done and we can exit. + if !wait { + return stepExit, nil + } - // Otherwise we wait for one HTLC attempt then continue - // the lifecycle. - // - // NOTE: we don't check `p.quit` since `decideNextStep` is - // running in the same goroutine as `resumePayment`. - select { - case err := <-p.resultCollected: - // If an error is returned, exit with it. - if err != nil { - return stepExit, err - } + log.Tracef("Waiting for attempt results for payment %v", p.identifier) - log.Tracef("Received attempt result for payment %v", - p.identifier) + // Otherwise we wait for the result for one HTLC attempt then continue + // the lifecycle. + select { + case r := <-p.resultCollected: + log.Tracef("Received attempt result for payment %v", + p.identifier) - case <-p.router.quit: - return stepExit, ErrRouterShuttingDown + // Handle the result here. If there's no error, we will return + // stepSkip and move to the next lifecycle iteration, which will + // refresh the payment and wait for the next attempt result, if + // any. + _, err := p.handleAttemptResult(r.attempt, r.result) + + // We would only get a DB-related error here, which will cause + // us to abort the payment flow. + if err != nil { + return stepExit, err } - return stepSkip, nil + case <-p.quit: + return stepExit, ErrPaymentLifecycleExiting + + case <-p.router.quit: + return stepExit, ErrRouterShuttingDown } - // Otherwise we need to make more attempts. - return stepProceed, nil + return stepSkip, nil } // resumePayment resumes the paymentLifecycle from the current state. @@ -175,20 +198,11 @@ func (p *paymentLifecycle) resumePayment(ctx context.Context) ([32]byte, // If we had any existing attempts outstanding, we'll start by spinning // up goroutines that'll collect their results and deliver them to the // lifecycle loop below. - payment, err := p.router.cfg.Control.FetchPayment(p.identifier) + payment, err := p.reloadInflightAttempts() if err != nil { return [32]byte{}, nil, err } - for _, a := range payment.InFlightHTLCs() { - a := a - - log.Infof("Resuming HTLC attempt %v for payment %v", - a.AttemptID, p.identifier) - - p.resultCollector(&a) - } - // Get the payment status. status := payment.GetStatus() @@ -211,23 +225,18 @@ func (p *paymentLifecycle) resumePayment(ctx context.Context) ([32]byte, // critical error during path finding. lifecycle: for { - // We update the payment state on every iteration. Since the - // payment state is affected by multiple goroutines (ie, - // collectResultAsync), it is NOT guaranteed that we always - // have the latest state here. This is fine as long as the - // state is consistent as a whole. - payment, err = p.router.cfg.Control.FetchPayment(p.identifier) + // We update the payment state on every iteration. + currentPayment, ps, err := p.reloadPayment() if err != nil { return exitWithErr(err) } - ps := payment.GetState() - remainingFees := p.calcFeeBudget(ps.FeesPaid) + // Reassign status so it can be read in `exitWithErr`. + status = currentPayment.GetStatus() - status = payment.GetStatus() - log.Debugf("Payment %v: status=%v, active_shards=%v, "+ - "rem_value=%v, fee_limit=%v", p.identifier, status, - ps.NumAttemptsInFlight, ps.RemainingAmt, remainingFees) + // Reassign payment such that when the lifecycle exits, the + // latest payment can be read when we access its terminal info. + payment = currentPayment // We now proceed our lifecycle with the following tasks in // order, @@ -288,13 +297,6 @@ lifecycle: log.Tracef("Found route: %s", spew.Sdump(rt.Hops)) - // Allow the traffic shaper to add custom records to the - // outgoing HTLC and also adjust the amount if needed. - err = p.amendFirstHopData(rt) - if err != nil { - return exitWithErr(err) - } - // We found a route to try, create a new HTLC attempt to try. attempt, err := p.registerAttempt(rt, ps.RemainingAmt) if err != nil { @@ -391,6 +393,13 @@ func (p *paymentLifecycle) requestRoute( // Exit early if there's no error. if err == nil { + // Allow the traffic shaper to add custom records to the + // outgoing HTLC and also adjust the amount if needed. + err = p.amendFirstHopData(rt) + if err != nil { + return nil, err + } + return rt, nil } @@ -444,62 +453,60 @@ type attemptResult struct { } // collectResultAsync launches a goroutine that will wait for the result of the -// given HTLC attempt to be available then handle its result. Once received, it -// will send a nil error to channel `resultCollected` to indicate there's a -// result. +// given HTLC attempt to be available then save its result in a map. Once +// received, it will send the result returned from the switch to channel +// `resultCollected`. func (p *paymentLifecycle) collectResultAsync(attempt *channeldb.HTLCAttempt) { log.Debugf("Collecting result for attempt %v in payment %v", attempt.AttemptID, p.identifier) go func() { - // Block until the result is available. - _, err := p.collectResult(attempt) + result, err := p.collectResult(attempt) if err != nil { - log.Errorf("Error collecting result for attempt %v "+ - "in payment %v: %v", attempt.AttemptID, + log.Errorf("Error collecting result for attempt %v in "+ + "payment %v: %v", attempt.AttemptID, p.identifier, err) + + return } log.Debugf("Result collected for attempt %v in payment %v", attempt.AttemptID, p.identifier) - // Once the result is collected, we signal it by writing the - // error to `resultCollected`. + // Create a switch result and send it to the resultCollected + // chan, which gets processed when the lifecycle is waiting for + // a result to be received in decideNextStep. + r := &switchResult{ + attempt: attempt, + result: result, + } + + // Signal that a result has been collected. select { - // Send the signal or quit. - case p.resultCollected <- err: + // Send the result so decideNextStep can proceed. + case p.resultCollected <- r: case <-p.quit: log.Debugf("Lifecycle exiting while collecting "+ "result for payment %v", p.identifier) case <-p.router.quit: - return } }() } -// collectResult waits for the result for the given attempt to be available -// from the Switch, then records the attempt outcome with the control tower. -// An attemptResult is returned, indicating the final outcome of this HTLC -// attempt. -func (p *paymentLifecycle) collectResult(attempt *channeldb.HTLCAttempt) ( - *attemptResult, error) { +// collectResult waits for the result of the given HTLC attempt to be sent by +// the switch and returns it. +func (p *paymentLifecycle) collectResult( + attempt *channeldb.HTLCAttempt) (*htlcswitch.PaymentResult, error) { log.Tracef("Collecting result for attempt %v", spew.Sdump(attempt)) - // We'll retrieve the hash specific to this shard from the - // shardTracker, since it will be needed to regenerate the circuit - // below. - hash, err := p.shardTracker.GetHash(attempt.AttemptID) - if err != nil { - return p.failAttempt(attempt.AttemptID, err) - } + result := &htlcswitch.PaymentResult{} // Regenerate the circuit for this attempt. - _, circuit, err := generateSphinxPacket( - &attempt.Route, hash[:], attempt.SessionKey(), - ) + circuit, err := attempt.Circuit() + // TODO(yy): We generate this circuit to create the error decryptor, // which is then used in htlcswitch as the deobfuscator to decode the // error from `UpdateFailHTLC`. However, suppose it's an @@ -512,8 +519,7 @@ func (p *paymentLifecycle) collectResult(attempt *channeldb.HTLCAttempt) ( if err != nil { log.Debugf("Unable to generate circuit for attempt %v: %v", attempt.AttemptID, err) - - return p.failAttempt(attempt.AttemptID, err) + return nil, err } // Using the created circuit, initialize the error decrypter, so we can @@ -539,22 +545,21 @@ func (p *paymentLifecycle) collectResult(attempt *channeldb.HTLCAttempt) ( log.Errorf("Failed getting result for attemptID %d "+ "from switch: %v", attempt.AttemptID, err) - return p.handleSwitchErr(attempt, err) + result.Error = err + + return result, nil } // The switch knows about this payment, we'll wait for a result to be // available. - var ( - result *htlcswitch.PaymentResult - ok bool - ) - select { - case result, ok = <-resultChan: + case r, ok := <-resultChan: if !ok { return nil, htlcswitch.ErrSwitchExiting } + result = r + case <-p.quit: return nil, ErrPaymentLifecycleExiting @@ -562,46 +567,7 @@ func (p *paymentLifecycle) collectResult(attempt *channeldb.HTLCAttempt) ( return nil, ErrRouterShuttingDown } - // In case of a payment failure, fail the attempt with the control - // tower and return. - if result.Error != nil { - return p.handleSwitchErr(attempt, result.Error) - } - - // We successfully got a payment result back from the switch. - log.Debugf("Payment %v succeeded with pid=%v", - p.identifier, attempt.AttemptID) - - // Report success to mission control. - err = p.router.cfg.MissionControl.ReportPaymentSuccess( - attempt.AttemptID, &attempt.Route, - ) - if err != nil { - log.Errorf("Error reporting payment success to mc: %v", err) - } - - // In case of success we atomically store settle result to the DB move - // the shard to the settled state. - htlcAttempt, err := p.router.cfg.Control.SettleAttempt( - p.identifier, attempt.AttemptID, - &channeldb.HTLCSettleInfo{ - Preimage: result.Preimage, - SettleTime: p.router.cfg.Clock.Now(), - }, - ) - if err != nil { - log.Errorf("Error settling attempt %v for payment %v with "+ - "preimage %v: %v", attempt.AttemptID, p.identifier, - result.Preimage, err) - - // We won't mark the attempt as failed since we already have - // the preimage. - return nil, err - } - - return &attemptResult{ - attempt: htlcAttempt, - }, nil + return result, nil } // registerAttempt is responsible for creating and saving an HTLC attempt in db @@ -675,11 +641,9 @@ func (p *paymentLifecycle) createNewPaymentAttempt(rt *route.Route, // We now have all the information needed to populate the current // attempt information. - attempt := channeldb.NewHtlcAttempt( + return channeldb.NewHtlcAttempt( attemptID, sessionKey, *rt, p.router.cfg.Clock.Now(), &hash, ) - - return attempt, nil } // sendAttempt attempts to send the current attempt to the switch to complete @@ -711,9 +675,7 @@ func (p *paymentLifecycle) sendAttempt( // Generate the raw encoded sphinx packet to be included along // with the htlcAdd message that we send directly to the // switch. - onionBlob, _, err := generateSphinxPacket( - &rt, attempt.Hash[:], attempt.SessionKey(), - ) + onionBlob, err := attempt.OnionBlob() if err != nil { log.Errorf("Failed to create onion blob: attempt=%d in "+ "payment=%v, err:%v", attempt.AttemptID, @@ -722,7 +684,7 @@ func (p *paymentLifecycle) sendAttempt( return p.failAttempt(attempt.AttemptID, err) } - copy(htlcAdd.OnionBlob[:], onionBlob) + htlcAdd.OnionBlob = onionBlob // Send it to the Switch. When this method returns we assume // the Switch successfully has persisted the payment attempt, @@ -885,8 +847,8 @@ func (p *paymentLifecycle) handleSwitchErr(attempt *channeldb.HTLCAttempt, // case we can safely send a new payment attempt, and wait for its // result to be available. if errors.Is(sendErr, htlcswitch.ErrPaymentIDNotFound) { - log.Debugf("Attempt ID %v for payment %v not found in the "+ - "Switch, retrying.", attempt.AttemptID, p.identifier) + log.Warnf("Failing attempt=%v for payment=%v as it's not "+ + "found in the Switch", attempt.AttemptID, p.identifier) return p.failAttempt(attemptID, sendErr) } @@ -1097,3 +1059,106 @@ func marshallError(sendError error, time time.Time) *channeldb.HTLCFailInfo { return response } + +// reloadInflightAttempts is called when the payment lifecycle is resumed after +// a restart. It reloads all inflight attempts from the control tower and +// collects the results of the attempts that have been sent before. +func (p *paymentLifecycle) reloadInflightAttempts() (DBMPPayment, error) { + payment, err := p.router.cfg.Control.FetchPayment(p.identifier) + if err != nil { + return nil, err + } + + for _, a := range payment.InFlightHTLCs() { + a := a + + log.Infof("Resuming HTLC attempt %v for payment %v", + a.AttemptID, p.identifier) + + p.resultCollector(&a) + } + + return payment, nil +} + +// reloadPayment returns the latest payment found in the db (control tower). +func (p *paymentLifecycle) reloadPayment() (DBMPPayment, + *channeldb.MPPaymentState, error) { + + // Read the db to get the latest state of the payment. + payment, err := p.router.cfg.Control.FetchPayment(p.identifier) + if err != nil { + return nil, nil, err + } + + ps := payment.GetState() + remainingFees := p.calcFeeBudget(ps.FeesPaid) + + log.Debugf("Payment %v: status=%v, active_shards=%v, rem_value=%v, "+ + "fee_limit=%v", p.identifier, payment.GetStatus(), + ps.NumAttemptsInFlight, ps.RemainingAmt, remainingFees) + + return payment, ps, nil +} + +// handleAttemptResult processes the result of an HTLC attempt returned from +// the htlcswitch. +func (p *paymentLifecycle) handleAttemptResult(attempt *channeldb.HTLCAttempt, + result *htlcswitch.PaymentResult) (*attemptResult, error) { + + // If the result has an error, we need to further process it by failing + // the attempt and maybe fail the payment. + if result.Error != nil { + return p.handleSwitchErr(attempt, result.Error) + } + + // We got an attempt settled result back from the switch. + log.Debugf("Payment(%v): attempt(%v) succeeded", p.identifier, + attempt.AttemptID) + + // Report success to mission control. + err := p.router.cfg.MissionControl.ReportPaymentSuccess( + attempt.AttemptID, &attempt.Route, + ) + if err != nil { + log.Errorf("Error reporting payment success to mc: %v", err) + } + + // In case of success we atomically store settle result to the DB and + // move the shard to the settled state. + htlcAttempt, err := p.router.cfg.Control.SettleAttempt( + p.identifier, attempt.AttemptID, + &channeldb.HTLCSettleInfo{ + Preimage: result.Preimage, + SettleTime: p.router.cfg.Clock.Now(), + }, + ) + if err != nil { + log.Errorf("Error settling attempt %v for payment %v with "+ + "preimage %v: %v", attempt.AttemptID, p.identifier, + result.Preimage, err) + + // We won't mark the attempt as failed since we already have + // the preimage. + return nil, err + } + + return &attemptResult{ + attempt: htlcAttempt, + }, nil +} + +// collectAndHandleResult waits for the result for the given attempt to be +// available from the Switch, then records the attempt outcome with the control +// tower. An attemptResult is returned, indicating the final outcome of this +// HTLC attempt. +func (p *paymentLifecycle) collectAndHandleResult( + attempt *channeldb.HTLCAttempt) (*attemptResult, error) { + + result, err := p.collectResult(attempt) + if err != nil { + return nil, err + } + + return p.handleAttemptResult(attempt, result) +} diff --git a/routing/payment_lifecycle_test.go b/routing/payment_lifecycle_test.go index 72aa6314194..b04b04d4be3 100644 --- a/routing/payment_lifecycle_test.go +++ b/routing/payment_lifecycle_test.go @@ -260,10 +260,15 @@ func createDummyRoute(t *testing.T, amt lnwire.MilliSatoshi) *route.Route { func makeSettledAttempt(t *testing.T, total int, preimage lntypes.Preimage) *channeldb.HTLCAttempt { - return &channeldb.HTLCAttempt{ + a := &channeldb.HTLCAttempt{ HTLCAttemptInfo: makeAttemptInfo(t, total), Settle: &channeldb.HTLCSettleInfo{Preimage: preimage}, } + + hash := preimage.Hash() + a.Hash = &hash + + return a } func makeFailedAttempt(t *testing.T, total int) *channeldb.HTLCAttempt { @@ -279,6 +284,7 @@ func makeAttemptInfo(t *testing.T, amt int) channeldb.HTLCAttemptInfo { rt := createDummyRoute(t, lnwire.MilliSatoshi(amt)) return channeldb.HTLCAttemptInfo{ Route: *rt, + Hash: &lntypes.Hash{1, 2, 3}, } } @@ -516,7 +522,8 @@ func TestRequestRouteFailPaymentError(t *testing.T) { ct.AssertExpectations(t) } -// TestDecideNextStep checks the method `decideNextStep` behaves as expected. +// TestDecideNextStep checks the method `decideNextStep` behaves as expected +// given the returned values from `AllowMoreAttempts` and `NeedWaitAttempts`. func TestDecideNextStep(t *testing.T) { t.Parallel() @@ -531,15 +538,8 @@ func TestDecideNextStep(t *testing.T) { name string allowMoreAttempts *mockReturn needWaitAttempts *mockReturn - - // When the attemptResultChan has returned. - closeResultChan bool - - // Whether the router has quit. - routerQuit bool - - expectedStep stateStep - expectedErr error + expectedStep stateStep + expectedErr error }{ { name: "allow more attempts", @@ -548,52 +548,36 @@ func TestDecideNextStep(t *testing.T) { expectedErr: nil, }, { - name: "error on allow more attempts", + name: "error checking allow more attempts", allowMoreAttempts: &mockReturn{false, errDummy}, expectedStep: stepExit, expectedErr: errDummy, }, { - name: "no wait and exit", + name: "no need to wait attempts", allowMoreAttempts: &mockReturn{false, nil}, needWaitAttempts: &mockReturn{false, nil}, expectedStep: stepExit, expectedErr: nil, }, { - name: "wait returns an error", + name: "error checking wait attempts", allowMoreAttempts: &mockReturn{false, nil}, needWaitAttempts: &mockReturn{false, errDummy}, expectedStep: stepExit, expectedErr: errDummy, }, - - { - name: "wait and exit on result chan", - allowMoreAttempts: &mockReturn{false, nil}, - needWaitAttempts: &mockReturn{true, nil}, - closeResultChan: true, - expectedStep: stepSkip, - expectedErr: nil, - }, - { - name: "wait and exit on router quit", - allowMoreAttempts: &mockReturn{false, nil}, - needWaitAttempts: &mockReturn{true, nil}, - routerQuit: true, - expectedStep: stepExit, - expectedErr: ErrRouterShuttingDown, - }, } for _, tc := range testCases { tc := tc // Create a test paymentLifecycle. - p := createTestPaymentLifecycle() + p, _ := newTestPaymentLifecycle(t) // Make a mock payment. payment := &mockMPPayment{} + defer payment.AssertExpectations(t) // Mock the method AllowMoreAttempts. payment.On("AllowMoreAttempts").Return( @@ -609,27 +593,188 @@ func TestDecideNextStep(t *testing.T) { ).Once() } - // Send a nil error to the attemptResultChan if requested. - if tc.closeResultChan { - p.resultCollected = make(chan error, 1) - p.resultCollected <- nil - } - - // Quit the router if requested. - if tc.routerQuit { - close(p.router.quit) - } - // Once the setup is finished, run the test cases. t.Run(tc.name, func(t *testing.T) { step, err := p.decideNextStep(payment) require.Equal(t, tc.expectedStep, step) require.ErrorIs(t, tc.expectedErr, err) }) + } +} + +// TestDecideNextStepOnRouterQuit checks the method `decideNextStep` behaves as +// expected when the router is quit. +func TestDecideNextStepOnRouterQuit(t *testing.T) { + t.Parallel() + + // Create a test paymentLifecycle. + p, _ := newTestPaymentLifecycle(t) + + // Make a mock payment. + payment := &mockMPPayment{} + defer payment.AssertExpectations(t) + + // Mock the method AllowMoreAttempts to return false. + payment.On("AllowMoreAttempts").Return(false, nil).Once() + + // Mock the method NeedWaitAttempts to wait for results. + payment.On("NeedWaitAttempts").Return(true, nil).Once() + + // Quit the router. + close(p.router.quit) + + // Call the method under test. + step, err := p.decideNextStep(payment) + + // We expect stepExit and an error to be returned. + require.Equal(t, stepExit, step) + require.ErrorIs(t, err, ErrRouterShuttingDown) +} + +// TestDecideNextStepOnLifecycleQuit checks the method `decideNextStep` behaves +// as expected when the lifecycle is quit. +func TestDecideNextStepOnLifecycleQuit(t *testing.T) { + t.Parallel() + + // Create a test paymentLifecycle. + p, _ := newTestPaymentLifecycle(t) + + // Make a mock payment. + payment := &mockMPPayment{} + defer payment.AssertExpectations(t) + + // Mock the method AllowMoreAttempts to return false. + payment.On("AllowMoreAttempts").Return(false, nil).Once() + + // Mock the method NeedWaitAttempts to wait for results. + payment.On("NeedWaitAttempts").Return(true, nil).Once() + + // Quit the paymentLifecycle. + close(p.quit) - // Check the payment's methods are called as expected. - payment.AssertExpectations(t) + // Call the method under test. + step, err := p.decideNextStep(payment) + + // We expect stepExit and an error to be returned. + require.Equal(t, stepExit, step) + require.ErrorIs(t, err, ErrPaymentLifecycleExiting) +} + +// TestDecideNextStepHandleAttemptResultSucceed checks the method +// `decideNextStep` behaves as expected when successfully handled the attempt +// result. +func TestDecideNextStepHandleAttemptResultSucceed(t *testing.T) { + t.Parallel() + + // Create a test paymentLifecycle. + p, m := newTestPaymentLifecycle(t) + + // Mock the clock to return a current time. + m.clock.On("Now").Return(time.Now()) + + // Make a mock payment. + payment := &mockMPPayment{} + defer payment.AssertExpectations(t) + + // Mock the method AllowMoreAttempts to return false. + payment.On("AllowMoreAttempts").Return(false, nil).Once() + + // Mock the method NeedWaitAttempts to wait for results. + payment.On("NeedWaitAttempts").Return(true, nil).Once() + + paymentAmt := 10_000 + preimage := lntypes.Preimage{1} + attempt := makeSettledAttempt(t, paymentAmt, preimage) + + // Create a result that contains a preimage. + result := &htlcswitch.PaymentResult{ + Preimage: preimage, } + + // Create a switch result and send it to the `resultCollected`` chan. + r := &switchResult{ + attempt: attempt, + result: result, + } + p.resultCollected <- r + + // We now mock the behavior of `handleAttemptResult` - we are not + // testing this method's behavior here, so we simply mock it to return + // no error. + // + // Since the result doesn't contain an error, `ReportPaymentSuccess` + // should be called. + m.missionControl.On("ReportPaymentSuccess", mock.Anything, + mock.Anything).Return(nil).Once() + + // The settled htlc should be returned from `SettleAttempt`. + m.control.On("SettleAttempt", mock.Anything, mock.Anything, + mock.Anything).Return(attempt, nil).Once() + + // Call the method under test. + step, err := p.decideNextStep(payment) + + // We expect stepSkip and no error to be returned. + require.Equal(t, stepSkip, step) + require.NoError(t, err) +} + +// TestDecideNextStepHandleAttemptResultFail checks the method `decideNextStep` +// behaves as expected when it fails to handle the attempt result. +func TestDecideNextStepHandleAttemptResultFail(t *testing.T) { + t.Parallel() + + // Create a test paymentLifecycle. + p, m := newTestPaymentLifecycle(t) + + // Mock the clock to return a current time. + m.clock.On("Now").Return(time.Now()) + + // Make a mock payment. + payment := &mockMPPayment{} + defer payment.AssertExpectations(t) + + // Mock the method AllowMoreAttempts to return false. + payment.On("AllowMoreAttempts").Return(false, nil).Once() + + // Mock the method NeedWaitAttempts to wait for results. + payment.On("NeedWaitAttempts").Return(true, nil).Once() + + paymentAmt := 10_000 + preimage := lntypes.Preimage{1} + attempt := makeSettledAttempt(t, paymentAmt, preimage) + + // Create a result that contains a preimage. + result := &htlcswitch.PaymentResult{ + Preimage: preimage, + } + + // Create a switch result and send it to the `resultCollected`` chan. + r := &switchResult{ + attempt: attempt, + result: result, + } + p.resultCollected <- r + + // We now mock the behavior of `handleAttemptResult` - we are not + // testing this method's behavior here, so we simply mock it to return + // an error. + // + // Since the result doesn't contain an error, `ReportPaymentSuccess` + // should be called. + m.missionControl.On("ReportPaymentSuccess", + mock.Anything, mock.Anything).Return(nil).Once() + + // Mock SettleAttempt to return an error. + m.control.On("SettleAttempt", mock.Anything, mock.Anything, + mock.Anything).Return(attempt, errDummy).Once() + + // Call the method under test. + step, err := p.decideNextStep(payment) + + // We expect stepExit and the above error to be returned. + require.Equal(t, stepExit, step) + require.ErrorIs(t, err, errDummy) } // TestResumePaymentFailOnFetchPayment checks when we fail to fetch the @@ -1303,11 +1448,6 @@ func TestCollectResultExitOnErr(t *testing.T) { paymentAmt := 10_000 attempt := makeFailedAttempt(t, paymentAmt) - // Mock shardTracker to return the payment hash. - m.shardTracker.On("GetHash", - attempt.AttemptID, - ).Return(p.identifier, nil).Once() - // Mock the htlcswitch to return a dummy error. m.payer.On("GetAttemptResult", attempt.AttemptID, p.identifier, mock.Anything, @@ -1332,7 +1472,7 @@ func TestCollectResultExitOnErr(t *testing.T) { m.clock.On("Now").Return(time.Now()) // Now call the method under test. - result, err := p.collectResult(attempt) + result, err := p.collectAndHandleResult(attempt) require.ErrorIs(t, err, errDummy, "expected dummy error") require.Nil(t, result, "expected nil attempt") } @@ -1348,11 +1488,6 @@ func TestCollectResultExitOnResultErr(t *testing.T) { paymentAmt := 10_000 attempt := makeFailedAttempt(t, paymentAmt) - // Mock shardTracker to return the payment hash. - m.shardTracker.On("GetHash", - attempt.AttemptID, - ).Return(p.identifier, nil).Once() - // Mock the htlcswitch to return a the result chan. resultChan := make(chan *htlcswitch.PaymentResult, 1) m.payer.On("GetAttemptResult", @@ -1383,7 +1518,7 @@ func TestCollectResultExitOnResultErr(t *testing.T) { m.clock.On("Now").Return(time.Now()) // Now call the method under test. - result, err := p.collectResult(attempt) + result, err := p.collectAndHandleResult(attempt) require.ErrorIs(t, err, errDummy, "expected dummy error") require.Nil(t, result, "expected nil attempt") } @@ -1399,11 +1534,6 @@ func TestCollectResultExitOnSwitchQuit(t *testing.T) { paymentAmt := 10_000 attempt := makeFailedAttempt(t, paymentAmt) - // Mock shardTracker to return the payment hash. - m.shardTracker.On("GetHash", - attempt.AttemptID, - ).Return(p.identifier, nil).Once() - // Mock the htlcswitch to return a the result chan. resultChan := make(chan *htlcswitch.PaymentResult, 1) m.payer.On("GetAttemptResult", @@ -1414,7 +1544,7 @@ func TestCollectResultExitOnSwitchQuit(t *testing.T) { }) // Now call the method under test. - result, err := p.collectResult(attempt) + result, err := p.collectAndHandleResult(attempt) require.ErrorIs(t, err, htlcswitch.ErrSwitchExiting, "expected switch exit") require.Nil(t, result, "expected nil attempt") @@ -1431,11 +1561,6 @@ func TestCollectResultExitOnRouterQuit(t *testing.T) { paymentAmt := 10_000 attempt := makeFailedAttempt(t, paymentAmt) - // Mock shardTracker to return the payment hash. - m.shardTracker.On("GetHash", - attempt.AttemptID, - ).Return(p.identifier, nil).Once() - // Mock the htlcswitch to return a the result chan. resultChan := make(chan *htlcswitch.PaymentResult, 1) m.payer.On("GetAttemptResult", @@ -1446,7 +1571,7 @@ func TestCollectResultExitOnRouterQuit(t *testing.T) { }) // Now call the method under test. - result, err := p.collectResult(attempt) + result, err := p.collectAndHandleResult(attempt) require.ErrorIs(t, err, ErrRouterShuttingDown, "expected router exit") require.Nil(t, result, "expected nil attempt") } @@ -1462,11 +1587,6 @@ func TestCollectResultExitOnLifecycleQuit(t *testing.T) { paymentAmt := 10_000 attempt := makeFailedAttempt(t, paymentAmt) - // Mock shardTracker to return the payment hash. - m.shardTracker.On("GetHash", - attempt.AttemptID, - ).Return(p.identifier, nil).Once() - // Mock the htlcswitch to return a the result chan. resultChan := make(chan *htlcswitch.PaymentResult, 1) m.payer.On("GetAttemptResult", @@ -1477,7 +1597,7 @@ func TestCollectResultExitOnLifecycleQuit(t *testing.T) { }) // Now call the method under test. - result, err := p.collectResult(attempt) + result, err := p.collectAndHandleResult(attempt) require.ErrorIs(t, err, ErrPaymentLifecycleExiting, "expected lifecycle exit") require.Nil(t, result, "expected nil attempt") @@ -1495,11 +1615,6 @@ func TestCollectResultExitOnSettleErr(t *testing.T) { preimage := lntypes.Preimage{1} attempt := makeSettledAttempt(t, paymentAmt, preimage) - // Mock shardTracker to return the payment hash. - m.shardTracker.On("GetHash", - attempt.AttemptID, - ).Return(p.identifier, nil).Once() - // Mock the htlcswitch to return a the result chan. resultChan := make(chan *htlcswitch.PaymentResult, 1) m.payer.On("GetAttemptResult", @@ -1526,7 +1641,7 @@ func TestCollectResultExitOnSettleErr(t *testing.T) { m.clock.On("Now").Return(time.Now()) // Now call the method under test. - result, err := p.collectResult(attempt) + result, err := p.collectAndHandleResult(attempt) require.ErrorIs(t, err, errDummy, "expected settle error") require.Nil(t, result, "expected nil attempt") } @@ -1542,11 +1657,6 @@ func TestCollectResultSuccess(t *testing.T) { preimage := lntypes.Preimage{1} attempt := makeSettledAttempt(t, paymentAmt, preimage) - // Mock shardTracker to return the payment hash. - m.shardTracker.On("GetHash", - attempt.AttemptID, - ).Return(p.identifier, nil).Once() - // Mock the htlcswitch to return a the result chan. resultChan := make(chan *htlcswitch.PaymentResult, 1) m.payer.On("GetAttemptResult", @@ -1573,7 +1683,7 @@ func TestCollectResultSuccess(t *testing.T) { m.clock.On("Now").Return(time.Now()) // Now call the method under test. - result, err := p.collectResult(attempt) + result, err := p.collectAndHandleResult(attempt) require.NoError(t, err, "expected no error") require.Equal(t, preimage, result.attempt.Settle.Preimage, "preimage mismatch") @@ -1590,10 +1700,10 @@ func TestCollectResultAsyncSuccess(t *testing.T) { preimage := lntypes.Preimage{1} attempt := makeSettledAttempt(t, paymentAmt, preimage) - // Mock shardTracker to return the payment hash. - m.shardTracker.On("GetHash", - attempt.AttemptID, - ).Return(p.identifier, nil).Once() + // Create a mock result returned from the switch. + result := &htlcswitch.PaymentResult{ + Preimage: preimage, + } // Mock the htlcswitch to return a the result chan. resultChan := make(chan *htlcswitch.PaymentResult, 1) @@ -1601,13 +1711,86 @@ func TestCollectResultAsyncSuccess(t *testing.T) { attempt.AttemptID, p.identifier, mock.Anything, ).Return(resultChan, nil).Once().Run(func(args mock.Arguments) { // Send the preimage to the result chan. - resultChan <- &htlcswitch.PaymentResult{ - Preimage: preimage, - } + resultChan <- result }) - // Once the result is received, `ReportPaymentSuccess` should be - // called. + // Now call the method under test. + p.collectResultAsync(attempt) + + var r *switchResult + + // Assert the result is returned within 5 seconds. + waitErr := wait.NoError(func() error { + r = <-p.resultCollected + return nil + }, testTimeout) + require.NoError(t, waitErr, "timeout waiting for result") + + // Assert the result is received as expected. + require.Equal(t, attempt, r.attempt) + require.Equal(t, result, r.result) +} + +// TestHandleAttemptResultWithError checks that when the `Error` field in the +// result is not nil, it's properly handled by `handleAttemptResult`. +func TestHandleAttemptResultWithError(t *testing.T) { + t.Parallel() + + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := newTestPaymentLifecycle(t) + + paymentAmt := 10_000 + preimage := lntypes.Preimage{1} + attempt := makeSettledAttempt(t, paymentAmt, preimage) + + // Create a result that contains an error. + // + // NOTE: The error is chosen so we can quickly exit `handleSwitchErr` + // since we are not testing its behavior here. + result := &htlcswitch.PaymentResult{ + Error: htlcswitch.ErrPaymentIDNotFound, + } + + // The above error will end up being handled by `handleSwitchErr`, in + // which we'd cancel the shard and fail the attempt. + // + // `CancelShard` should be called with the attemptID. + m.shardTracker.On("CancelShard", attempt.AttemptID).Return(nil).Once() + + // Mock `FailAttempt` to return a dummy error. + m.control.On("FailAttempt", + p.identifier, attempt.AttemptID, mock.Anything, + ).Return(nil, errDummy).Once() + + // Mock the clock to return a current time. + m.clock.On("Now").Return(time.Now()) + + // Call the method under test and expect the dummy error to be + // returned. + attemptResult, err := p.handleAttemptResult(attempt, result) + require.ErrorIs(t, err, errDummy, "expected fail error") + require.Nil(t, attemptResult, "expected nil attempt result") +} + +// TestHandleAttemptResultSuccess checks that when the result contains no error +// but a preimage, it's handled correctly by `handleAttemptResult`. +func TestHandleAttemptResultSuccess(t *testing.T) { + t.Parallel() + + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := newTestPaymentLifecycle(t) + + paymentAmt := 10_000 + preimage := lntypes.Preimage{1} + attempt := makeSettledAttempt(t, paymentAmt, preimage) + + // Create a result that contains a preimage. + result := &htlcswitch.PaymentResult{ + Preimage: preimage, + } + + // Since the result doesn't contain an error, `ReportPaymentSuccess` + // should be called. m.missionControl.On("ReportPaymentSuccess", attempt.AttemptID, &attempt.Route, ).Return(nil).Once() @@ -1620,17 +1803,9 @@ func TestCollectResultAsyncSuccess(t *testing.T) { // Mock the clock to return a current time. m.clock.On("Now").Return(time.Now()) - // Now call the method under test. - p.collectResultAsync(attempt) - - // Assert the result is returned within 5 seconds. - var err error - waitErr := wait.NoError(func() error { - err = <-p.resultCollected - return nil - }, testTimeout) - require.NoError(t, waitErr, "timeout waiting for result") - - // Assert that a nil error is received. + // Call the method under test and expect the dummy error to be + // returned. + attemptResult, err := p.handleAttemptResult(attempt, result) require.NoError(t, err, "expected no error") + require.Equal(t, attempt, attemptResult.attempt) } diff --git a/routing/router.go b/routing/router.go index 468510a6c7b..323fe761671 100644 --- a/routing/router.go +++ b/routing/router.go @@ -1,7 +1,6 @@ package routing import ( - "bytes" "context" "fmt" "math" @@ -15,7 +14,6 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/davecgh/go-spew/spew" "github.com/go-errors/errors" - sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/amp" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/clock" @@ -722,71 +720,6 @@ func generateNewSessionKey() (*btcec.PrivateKey, error) { return btcec.NewPrivateKey() } -// generateSphinxPacket generates then encodes a sphinx packet which encodes -// the onion route specified by the passed layer 3 route. The blob returned -// from this function can immediately be included within an HTLC add packet to -// be sent to the first hop within the route. -func generateSphinxPacket(rt *route.Route, paymentHash []byte, - sessionKey *btcec.PrivateKey) ([]byte, *sphinx.Circuit, error) { - - // Now that we know we have an actual route, we'll map the route into a - // sphinx payment path which includes per-hop payloads for each hop - // that give each node within the route the necessary information - // (fees, CLTV value, etc.) to properly forward the payment. - sphinxPath, err := rt.ToSphinxPath() - if err != nil { - return nil, nil, err - } - - log.Tracef("Constructed per-hop payloads for payment_hash=%x: %v", - paymentHash, lnutils.NewLogClosure(func() string { - path := make( - []sphinx.OnionHop, sphinxPath.TrueRouteLength(), - ) - for i := range path { - hopCopy := sphinxPath[i] - path[i] = hopCopy - } - - return spew.Sdump(path) - }), - ) - - // Next generate the onion routing packet which allows us to perform - // privacy preserving source routing across the network. - sphinxPacket, err := sphinx.NewOnionPacket( - sphinxPath, sessionKey, paymentHash, - sphinx.DeterministicPacketFiller, - ) - if err != nil { - return nil, nil, err - } - - // Finally, encode Sphinx packet using its wire representation to be - // included within the HTLC add packet. - var onionBlob bytes.Buffer - if err := sphinxPacket.Encode(&onionBlob); err != nil { - return nil, nil, err - } - - log.Tracef("Generated sphinx packet: %v", - lnutils.NewLogClosure(func() string { - // We make a copy of the ephemeral key and unset the - // internal curve here in order to keep the logs from - // getting noisy. - key := *sphinxPacket.EphemeralKey - packetCopy := *sphinxPacket - packetCopy.EphemeralKey = &key - return spew.Sdump(packetCopy) - }), - ) - - return onionBlob.Bytes(), &sphinx.Circuit{ - SessionKey: sessionKey, - PaymentPath: sphinxPath.NodeKeys(), - }, nil -} - // LightningPayment describes a payment to be sent through the network to the // final destination. type LightningPayment struct { @@ -1253,7 +1186,7 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route, // The attempt was successfully sent, wait for the result to be // available. - result, err = p.collectResult(attempt) + result, err = p.collectAndHandleResult(attempt) if err != nil { return nil, err } diff --git a/routing/router_test.go b/routing/router_test.go index 22c9d14e507..543f3b00056 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -48,13 +48,6 @@ var ( testFeatures = lnwire.NewFeatureVector(nil, lnwire.Features) - testHash = [32]byte{ - 0xb7, 0x94, 0x38, 0x5f, 0x2d, 0x1e, 0xf7, 0xab, - 0x4d, 0x92, 0x73, 0xd1, 0x90, 0x63, 0x81, 0xb4, - 0x4f, 0x2f, 0x6f, 0x25, 0x88, 0xa3, 0xef, 0xb9, - 0x6a, 0x49, 0x18, 0x83, 0x31, 0x98, 0x47, 0x53, - } - testTime = time.Date(2018, time.January, 9, 14, 00, 00, 0, time.UTC) priv1, _ = btcec.NewPrivateKey() @@ -1235,18 +1228,6 @@ func TestFindPathFeeWeighting(t *testing.T) { require.Equal(t, ctx.aliases["luoji"], path[0].policy.ToNodePubKey()) } -// TestEmptyRoutesGenerateSphinxPacket tests that the generateSphinxPacket -// function is able to gracefully handle being passed a nil set of hops for the -// route by the caller. -func TestEmptyRoutesGenerateSphinxPacket(t *testing.T) { - t.Parallel() - - sessionKey, _ := btcec.NewPrivateKey() - emptyRoute := &route.Route{} - _, _, err := generateSphinxPacket(emptyRoute, testHash[:], sessionKey) - require.ErrorIs(t, err, route.ErrNoRouteHopsProvided) -} - // TestUnknownErrorSource tests that if the source of an error is unknown, all // edges along the route will be pruned. func TestUnknownErrorSource(t *testing.T) {