diff --git a/chainntnfs/bitcoindnotify/bitcoind.go b/chainntnfs/bitcoindnotify/bitcoind.go index d268c8374a..ab1db1d846 100644 --- a/chainntnfs/bitcoindnotify/bitcoind.go +++ b/chainntnfs/bitcoindnotify/bitcoind.go @@ -15,6 +15,7 @@ import ( "github.com/btcsuite/btcwallet/chain" "github.com/lightningnetwork/lnd/blockcache" "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/queue" ) @@ -1070,3 +1071,26 @@ func (b *BitcoindNotifier) CancelMempoolSpendEvent( b.memNotifier.UnsubscribeEvent(sub) } + +// LookupInputMempoolSpend takes an outpoint and queries the mempool to find +// its spending tx. Returns the tx if found, otherwise fn.None. +// +// NOTE: part of the MempoolWatcher interface. +func (b *BitcoindNotifier) LookupInputMempoolSpend( + op wire.OutPoint) fn.Option[wire.MsgTx] { + + // Find the spending txid. + txid, found := b.chainConn.LookupInputMempoolSpend(op) + if !found { + return fn.None[wire.MsgTx]() + } + + // Query the spending tx using the id. + tx, err := b.chainConn.GetRawTransaction(&txid) + if err != nil { + // TODO(yy): enable logging errors in this package. + return fn.None[wire.MsgTx]() + } + + return fn.Some(*tx.MsgTx().Copy()) +} diff --git a/chainntnfs/btcdnotify/btcd.go b/chainntnfs/btcdnotify/btcd.go index 430a106614..f7642f4a77 100644 --- a/chainntnfs/btcdnotify/btcd.go +++ b/chainntnfs/btcdnotify/btcd.go @@ -14,8 +14,10 @@ import ( "github.com/btcsuite/btcd/rpcclient" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btcwallet/chain" "github.com/lightningnetwork/lnd/blockcache" "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/queue" ) @@ -58,7 +60,7 @@ type BtcdNotifier struct { active int32 // To be used atomically. stopped int32 // To be used atomically. - chainConn *rpcclient.Client + chainConn *chain.RPCClient chainParams *chaincfg.Params notificationCancels chan interface{} @@ -127,21 +129,30 @@ func New(config *rpcclient.ConnConfig, chainParams *chaincfg.Params, quit: make(chan struct{}), } + // Disable connecting to btcd within the rpcclient.New method. We + // defer establishing the connection to our .Start() method. + config.DisableConnectOnNew = true + config.DisableAutoReconnect = false + ntfnCallbacks := &rpcclient.NotificationHandlers{ OnBlockConnected: notifier.onBlockConnected, OnBlockDisconnected: notifier.onBlockDisconnected, OnRedeemingTx: notifier.onRedeemingTx, } - // Disable connecting to btcd within the rpcclient.New method. We - // defer establishing the connection to our .Start() method. - config.DisableConnectOnNew = true - config.DisableAutoReconnect = false - chainConn, err := rpcclient.New(config, ntfnCallbacks) + rpcCfg := &chain.RPCClientConfig{ + ReconnectAttempts: 20, + Conn: config, + Chain: chainParams, + NotificationHandlers: ntfnCallbacks, + } + + chainRPC, err := chain.NewRPCClientWithConfig(rpcCfg) if err != nil { return nil, err } - notifier.chainConn = chainConn + + notifier.chainConn = chainRPC return notifier, nil } @@ -1127,3 +1138,26 @@ func (b *BtcdNotifier) CancelMempoolSpendEvent( b.memNotifier.UnsubscribeEvent(sub) } + +// LookupInputMempoolSpend takes an outpoint and queries the mempool to find +// its spending tx. Returns the tx if found, otherwise fn.None. +// +// NOTE: part of the MempoolWatcher interface. +func (b *BtcdNotifier) LookupInputMempoolSpend( + op wire.OutPoint) fn.Option[wire.MsgTx] { + + // Find the spending txid. + txid, found := b.chainConn.LookupInputMempoolSpend(op) + if !found { + return fn.None[wire.MsgTx]() + } + + // Query the spending tx using the id. + tx, err := b.chainConn.GetRawTransaction(&txid) + if err != nil { + // TODO(yy): enable logging errors in this package. + return fn.None[wire.MsgTx]() + } + + return fn.Some(*tx.MsgTx().Copy()) +} diff --git a/chainntnfs/interface.go b/chainntnfs/interface.go index 0f2fe27e45..0a7d539982 100644 --- a/chainntnfs/interface.go +++ b/chainntnfs/interface.go @@ -13,6 +13,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/fn" ) var ( @@ -848,4 +849,9 @@ type MempoolWatcher interface { // CancelMempoolSpendEvent allows the caller to cancel a subscription to // watch for a spend of an outpoint in the mempool. CancelMempoolSpendEvent(sub *MempoolSpendEvent) + + // LookupInputMempoolSpend looks up the mempool to find a spending tx + // which spends the given outpoint. A fn.None is returned if it's not + // found. + LookupInputMempoolSpend(op wire.OutPoint) fn.Option[wire.MsgTx] } diff --git a/chainntnfs/mocks.go b/chainntnfs/mocks.go new file mode 100644 index 0000000000..31b75d46f2 --- /dev/null +++ b/chainntnfs/mocks.go @@ -0,0 +1,52 @@ +package chainntnfs + +import ( + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/fn" + "github.com/stretchr/testify/mock" +) + +// MockMempoolWatcher is a mock implementation of the MempoolWatcher interface. +// This is used by other subsystems to mock the behavior of the mempool +// watcher. +type MockMempoolWatcher struct { + mock.Mock +} + +// NewMockMempoolWatcher returns a new instance of a mock mempool watcher. +func NewMockMempoolWatcher() *MockMempoolWatcher { + return &MockMempoolWatcher{} +} + +// Compile-time check to ensure MockMempoolWatcher implements MempoolWatcher. +var _ MempoolWatcher = (*MockMempoolWatcher)(nil) + +// SubscribeMempoolSpent implements the MempoolWatcher interface. +func (m *MockMempoolWatcher) SubscribeMempoolSpent( + op wire.OutPoint) (*MempoolSpendEvent, error) { + + args := m.Called(op) + + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).(*MempoolSpendEvent), args.Error(1) +} + +// CancelMempoolSpendEvent implements the MempoolWatcher interface. +func (m *MockMempoolWatcher) CancelMempoolSpendEvent( + sub *MempoolSpendEvent) { + + m.Called(sub) +} + +// LookupInputMempoolSpend looks up the mempool to find a spending tx which +// spends the given outpoint. +func (m *MockMempoolWatcher) LookupInputMempoolSpend( + op wire.OutPoint) fn.Option[wire.MsgTx] { + + args := m.Called(op) + + return args.Get(0).(fn.Option[wire.MsgTx]) +} diff --git a/cmd/lncli/walletrpc_types.go b/cmd/lncli/walletrpc_types.go index 09b3ec69a9..b6680a6ede 100644 --- a/cmd/lncli/walletrpc_types.go +++ b/cmd/lncli/walletrpc_types.go @@ -5,15 +5,16 @@ import "github.com/lightningnetwork/lnd/lnrpc/walletrpc" // PendingSweep is a CLI-friendly type of the walletrpc.PendingSweep proto. We // use this to show more useful string versions of byte slices and enums. type PendingSweep struct { - OutPoint OutPoint `json:"outpoint"` - WitnessType string `json:"witness_type"` - AmountSat uint32 `json:"amount_sat"` - SatPerVByte uint32 `json:"sat_per_vbyte"` - BroadcastAttempts uint32 `json:"broadcast_attempts"` - NextBroadcastHeight uint32 `json:"next_broadcast_height"` - RequestedSatPerVByte uint32 `json:"requested_sat_per_vbyte"` - RequestedConfTarget uint32 `json:"requested_conf_target"` - Force bool `json:"force"` + OutPoint OutPoint `json:"outpoint"` + WitnessType string `json:"witness_type"` + AmountSat uint32 `json:"amount_sat"` + SatPerVByte uint32 `json:"sat_per_vbyte"` + BroadcastAttempts uint32 `json:"broadcast_attempts"` + // TODO(yy): deprecate. + NextBroadcastHeight uint32 `json:"next_broadcast_height"` + RequestedSatPerVByte uint32 `json:"requested_sat_per_vbyte"` + RequestedConfTarget uint32 `json:"requested_conf_target"` + Force bool `json:"force"` } // NewPendingSweepFromProto converts the walletrpc.PendingSweep proto type into diff --git a/contractcourt/anchor_resolver.go b/contractcourt/anchor_resolver.go index 6aa25cd228..80e665dd72 100644 --- a/contractcourt/anchor_resolver.go +++ b/contractcourt/anchor_resolver.go @@ -145,17 +145,6 @@ func (c *anchorResolver) Resolve() (ContractResolver, error) { c.log.Warnf("our anchor spent by someone else") outcome = channeldb.ResolverOutcomeUnclaimed - // The sweeper gave up on sweeping the anchor. This happens - // after the maximum number of sweep attempts has been reached. - // See sweep.DefaultMaxSweepAttempts. Sweep attempts are - // interspaced with random delays picked from a range that - // increases exponentially. - // - // We consider the anchor as being lost. - case sweep.ErrTooManyAttempts: - c.log.Warnf("anchor sweep abandoned") - outcome = channeldb.ResolverOutcomeUnclaimed - // An unexpected error occurred. default: c.log.Errorf("unable to sweep anchor: %v", sweepRes.Err) diff --git a/docs/release-notes/release-notes-0.18.0.md b/docs/release-notes/release-notes-0.18.0.md index b40b4e249c..5d9f809fec 100644 --- a/docs/release-notes/release-notes-0.18.0.md +++ b/docs/release-notes/release-notes-0.18.0.md @@ -343,6 +343,10 @@ bitcoin peers' feefilter values into account](https://github.com/lightningnetwor * Bump sqlite version to [fix a data race](https://github.com/lightningnetwork/lnd/pull/8567). +* The pending inputs in the sweeper is now + [stateful](https://github.com/lightningnetwork/lnd/pull/8423) to better + manage the lifecycle of the inputs. + ## Breaking Changes ## Performance Improvements diff --git a/itest/lnd_open_channel_test.go b/itest/lnd_open_channel_test.go index 919e6ae97f..d28a211dbd 100644 --- a/itest/lnd_open_channel_test.go +++ b/itest/lnd_open_channel_test.go @@ -829,12 +829,19 @@ func testSimpleTaprootChannelActivation(ht *lntest.HarnessTest) { // up as locked balance in the WalletBalance response. func testOpenChannelLockedBalance(ht *lntest.HarnessTest) { var ( - alice = ht.Alice - bob = ht.Bob - req *lnrpc.ChannelAcceptRequest - err error + bob = ht.Bob + req *lnrpc.ChannelAcceptRequest + err error ) + // Create a new node so we can assert exactly how much fund has been + // locked later. + alice := ht.NewNode("alice", nil) + ht.FundCoins(btcutil.SatoshiPerBitcoin, alice) + + // Connect the nodes. + ht.EnsureConnected(alice, bob) + // We first make sure Alice has no locked wallet balance. balance := alice.RPC.WalletBalance() require.EqualValues(ht, 0, balance.LockedBalance) @@ -851,6 +858,7 @@ func testOpenChannelLockedBalance(ht *lntest.HarnessTest) { openChannelReq := &lnrpc.OpenChannelRequest{ NodePubkey: bob.PubKey[:], LocalFundingAmount: int64(funding.MaxBtcFundingAmount), + TargetConf: 6, } _ = alice.RPC.OpenChannel(openChannelReq) @@ -862,8 +870,7 @@ func testOpenChannelLockedBalance(ht *lntest.HarnessTest) { }, defaultTimeout) require.NoError(ht, err) - balance = alice.RPC.WalletBalance() - require.NotEqualValues(ht, 0, balance.LockedBalance) + ht.AssertWalletLockedBalance(alice, btcutil.SatoshiPerBitcoin) // Next, we let Bob deny the request. resp := &lnrpc.ChannelAcceptResponse{ @@ -876,6 +883,5 @@ func testOpenChannelLockedBalance(ht *lntest.HarnessTest) { require.NoError(ht, err) // Finally, we check to make sure the balance is unlocked again. - balance = alice.RPC.WalletBalance() - require.EqualValues(ht, 0, balance.LockedBalance) + ht.AssertWalletLockedBalance(alice, 0) } diff --git a/lnrpc/walletrpc/walletkit_server.go b/lnrpc/walletrpc/walletkit_server.go index ac65974bb9..abd19cd499 100644 --- a/lnrpc/walletrpc/walletkit_server.go +++ b/lnrpc/walletrpc/walletkit_server.go @@ -875,7 +875,6 @@ func (w *WalletKit) PendingSweeps(ctx context.Context, amountSat := uint32(pendingInput.Amount) satPerVbyte := uint64(pendingInput.LastFeeRate.FeePerVByte()) broadcastAttempts := uint32(pendingInput.BroadcastAttempts) - nextBroadcastHeight := uint32(pendingInput.NextBroadcastHeight) feePref := pendingInput.Params.Fee requestedFee, ok := feePref.(sweep.FeeEstimateInfo) @@ -892,7 +891,6 @@ func (w *WalletKit) PendingSweeps(ctx context.Context, AmountSat: amountSat, SatPerVbyte: satPerVbyte, BroadcastAttempts: broadcastAttempts, - NextBroadcastHeight: nextBroadcastHeight, RequestedSatPerVbyte: requestedFeeRate, RequestedConfTarget: requestedFee.ConfTarget, Force: pendingInput.Params.Force, diff --git a/lntest/harness_assertion.go b/lntest/harness_assertion.go index 2a19cd85df..2f4ebeb7f8 100644 --- a/lntest/harness_assertion.go +++ b/lntest/harness_assertion.go @@ -2572,3 +2572,22 @@ func (h *HarnessTest) MineClosingTx(cp *lnrpc.ChannelPoint, return closeTx } + +// AssertWalletLockedBalance asserts the expected amount has been marked as +// locked in the node's WalletBalance response. +func (h *HarnessTest) AssertWalletLockedBalance(hn *node.HarnessNode, + balance int64) { + + err := wait.NoError(func() error { + balanceResp := hn.RPC.WalletBalance() + got := balanceResp.LockedBalance + + if got != balance { + return fmt.Errorf("want %d, got %d", balance, got) + } + + return nil + }, wait.DefaultTimeout) + require.NoError(h, err, "%s: timeout checking locked balance", + hn.Name()) +} diff --git a/server.go b/server.go index f3d2cda98c..10ad4a0e2f 100644 --- a/server.go +++ b/server.go @@ -1068,18 +1068,18 @@ func newServer(cfg *Config, listenAddrs []net.Addr, ) s.sweeper = sweep.New(&sweep.UtxoSweeperConfig{ - FeeEstimator: cc.FeeEstimator, - GenSweepScript: newSweepPkScriptGen(cc.Wallet), - Signer: cc.Wallet.Cfg.Signer, - Wallet: newSweeperWallet(cc.Wallet), - TickerDuration: cfg.Sweeper.BatchWindowDuration, - Notifier: cc.ChainNotifier, - Store: sweeperStore, - MaxInputsPerTx: sweep.DefaultMaxInputsPerTx, - MaxSweepAttempts: sweep.DefaultMaxSweepAttempts, - NextAttemptDeltaFunc: sweep.DefaultNextAttemptDeltaFunc, - MaxFeeRate: cfg.Sweeper.MaxFeeRate, - Aggregator: aggregator, + FeeEstimator: cc.FeeEstimator, + GenSweepScript: newSweepPkScriptGen(cc.Wallet), + Signer: cc.Wallet.Cfg.Signer, + Wallet: newSweeperWallet(cc.Wallet), + TickerDuration: cfg.Sweeper.BatchWindowDuration, + Mempool: cc.MempoolNotifier, + Notifier: cc.ChainNotifier, + Store: sweeperStore, + MaxInputsPerTx: sweep.DefaultMaxInputsPerTx, + MaxSweepAttempts: sweep.DefaultMaxSweepAttempts, + MaxFeeRate: cfg.Sweeper.MaxFeeRate, + Aggregator: aggregator, }) s.utxoNursery = contractcourt.NewUtxoNursery(&contractcourt.NurseryConfig{ diff --git a/sweep/sweeper.go b/sweep/sweeper.go index 69f3403e23..c3ce504ec0 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -14,6 +14,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/labels" "github.com/lightningnetwork/lnd/lnwallet" @@ -25,10 +26,6 @@ var ( // confirmed in a tx of the remote party. ErrRemoteSpend = errors.New("remote party swept utxo") - // ErrTooManyAttempts is returned in case sweeping an output has failed - // for the configured max number of attempts. - ErrTooManyAttempts = errors.New("sweep failed after max attempts") - // ErrFeePreferenceTooLow is returned when the fee preference gives a // fee rate that's below the relay fee rate. ErrFeePreferenceTooLow = errors.New("fee preference too low") @@ -88,12 +85,97 @@ func (p Params) String() string { p.Fee, p.Force) } +// SweepState represents the current state of a pending input. +// +//nolint:revive +type SweepState uint8 + +const ( + // StateInit is the initial state of a pending input. This is set when + // a new sweeping request for a given input is made. + StateInit SweepState = iota + + // StatePendingPublish specifies an input's state where it's already + // been included in a sweeping tx but the tx is not published yet. + // Inputs in this state should not be used for grouping again. + StatePendingPublish + + // StatePublished is the state where the input's sweeping tx has + // successfully been published. Inputs in this state can only be + // updated via RBF. + StatePublished + + // StatePublishFailed is the state when an error is returned from + // publishing the sweeping tx. Inputs in this state can be re-grouped + // in to a new sweeping tx. + StatePublishFailed + + // StateSwept is the final state of a pending input. This is set when + // the input has been successfully swept. + StateSwept + + // StateExcluded is the state of a pending input that has been excluded + // and can no longer be swept. For instance, when one of the three + // anchor sweeping transactions confirmed, the remaining two will be + // excluded. + StateExcluded + + // StateFailed is the state when a pending input has too many failed + // publish atttempts or unknown broadcast error is returned. + StateFailed +) + +// String gives a human readable text for the sweep states. +func (s SweepState) String() string { + switch s { + case StateInit: + return "Init" + + case StatePendingPublish: + return "PendingPublish" + + case StatePublished: + return "Published" + + case StatePublishFailed: + return "PublishFailed" + + case StateSwept: + return "Swept" + + case StateExcluded: + return "Excluded" + + case StateFailed: + return "Failed" + + default: + return "Unknown" + } +} + +// RBFInfo stores the information required to perform a RBF bump on a pending +// sweeping tx. +type RBFInfo struct { + // Txid is the txid of the sweeping tx. + Txid chainhash.Hash + + // FeeRate is the fee rate of the sweeping tx. + FeeRate chainfee.SatPerKWeight + + // Fee is the total fee of the sweeping tx. + Fee btcutil.Amount +} + // pendingInput is created when an input reaches the main loop for the first // time. It wraps the input and tracks all relevant state that is needed for // sweeping. type pendingInput struct { input.Input + // state tracks the current state of the input. + state SweepState + // listeners is a list of channels over which the final outcome of the // sweep needs to be broadcasted. listeners []chan Result @@ -102,10 +184,6 @@ type pendingInput struct { // notifier spend registration. ntfnRegCancel func() - // minPublishHeight indicates the minimum block height at which this - // input may be (re)published. - minPublishHeight int32 - // publishAttempts records the number of attempts that have already been // made to sweep this tx. publishAttempts int @@ -116,6 +194,9 @@ type pendingInput struct { // lastFeeRate is the most recent fee rate used for this input within a // transaction broadcast to the network. lastFeeRate chainfee.SatPerKWeight + + // rbf records the RBF constraints. + rbf fn.Option[RBFInfo] } // parameters returns the sweep parameters for this input. @@ -125,6 +206,21 @@ func (p *pendingInput) parameters() Params { return p.params } +// terminated returns a boolean indicating whether the input has reached a +// final state. +func (p *pendingInput) terminated() bool { + switch p.state { + // If the input has reached a final state, that it's either + // been swept, or failed, or excluded, we will remove it from + // our sweeper. + case StateFailed, StateSwept, StateExcluded: + return true + + default: + return false + } +} + // pendingInputs is a type alias for a set of pending inputs. type pendingInputs = map[wire.OutPoint]*pendingInput @@ -164,10 +260,6 @@ type PendingInput struct { // input. BroadcastAttempts int - // NextBroadcastHeight is the next height of the chain at which we'll - // attempt to broadcast a transaction sweeping the input. - NextBroadcastHeight uint32 - // Params contains the sweep parameters for this pending request. Params Params } @@ -210,8 +302,6 @@ type UtxoSweeper struct { // requested to sweep. pendingInputs pendingInputs - testSpendChan chan wire.OutPoint - currentOutputScript []byte relayFeeRate chainfee.SatPerKWeight @@ -248,6 +338,10 @@ type UtxoSweeperConfig struct { // certain on-chain events. Notifier chainntnfs.ChainNotifier + // Mempool is the mempool watcher that will be used to query whether a + // given input is already being spent by a transaction in the mempool. + Mempool chainntnfs.MempoolWatcher + // Store stores the published sweeper txes. Store SweeperStore @@ -265,12 +359,7 @@ type UtxoSweeperConfig struct { // to the caller. MaxSweepAttempts int - // NextAttemptDeltaFunc returns given the number of already attempted - // sweeps, how many blocks to wait before retrying to sweep. - NextAttemptDeltaFunc func(int) int32 - - // MaxFeeRate is the maximum fee rate allowed within the - // UtxoSweeper. + // MaxFeeRate is the maximum fee rate allowed within the UtxoSweeper. MaxFeeRate chainfee.SatPerVByte // Aggregator is used to group inputs into clusters based on its @@ -403,6 +492,8 @@ func (s *UtxoSweeper) Stop() error { // NOTE: Extreme care needs to be taken that input isn't changed externally. // Because it is an interface and we don't know what is exactly behind it, we // cannot make a local copy in sweeper. +// +// TODO(yy): make sure the caller is using the Result chan. func (s *UtxoSweeper) SweepInput(input input.Input, params Params) (chan Result, error) { @@ -547,6 +638,12 @@ func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) { log.Debugf("Sweep ticker started") for { + // Clean inputs, which will remove inputs that are swept, + // failed, or excluded from the sweeper and return inputs that + // are either new or has been published but failed back, which + // will be retried again here. + inputs := s.updateSweeperInputs() + select { // A new inputs is offered to the sweeper. We check to see if // we are already trying to sweep this input and if not, set up @@ -575,10 +672,17 @@ func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) { // The timer expires and we are going to (re)sweep. case <-ticker.C: - log.Debugf("Sweep ticker ticks, attempt sweeping...") - s.handleSweep() + log.Debugf("Sweep ticker ticks, attempt sweeping %d "+ + "inputs", len(inputs)) + + // Sweep the remaining pending inputs. + s.sweepPendingInputs(inputs) // A new block comes in, update the bestHeight. + // + // TODO(yy): this is where we check our published transactions + // and perform RBF if needed. We'd also like to consult our fee + // bumper to get an updated fee rate. case epoch, ok := <-blockEpochs: if !ok { return @@ -614,11 +718,22 @@ func (s *UtxoSweeper) removeExclusiveGroup(group uint64) { continue } + // Skip inputs that are already terminated. + if input.terminated() { + log.Tracef("Skipped sending error result for "+ + "input %v, state=%v", outpoint, input.state) + + continue + } + // Signal result channels. - s.signalAndRemove(&outpoint, Result{ + s.signalResult(input, Result{ Err: ErrExclusiveGroupSpend, }) + // Update the input's state as it can no longer be swept. + input.state = StateExcluded + // Remove all unconfirmed transactions from the wallet which // spend the passed outpoint of the same exclusive group. outpoints := map[wire.OutPoint]struct{}{ @@ -669,7 +784,8 @@ func (s *UtxoSweeper) sweepCluster(cluster inputCluster) error { // causing the failure and retry the rest of the // inputs. if errAllSets != nil { - log.Errorf("sweep all inputs: %w", err) + log.Errorf("Sweep all inputs got error: %v", + errAllSets) break } } @@ -680,10 +796,10 @@ func (s *UtxoSweeper) sweepCluster(cluster inputCluster) error { return nil } - // We'd end up there if there's no retried inputs. In this - // case, we'd sweep the new input sets. If there's an error - // when sweeping a given set, we'd log the error and sweep the - // next set. + // We'd end up there if there's no retried inputs or the above + // sweeping tx failed. In this case, we'd sweep the new input + // sets. If there's an error when sweeping a given set, we'd + // log the error and sweep the next set. for _, inputs := range newSets { err := s.sweep(inputs, cluster.sweepFeeRate) if err != nil { @@ -695,21 +811,19 @@ func (s *UtxoSweeper) sweepCluster(cluster inputCluster) error { }) } -// signalAndRemove notifies the listeners of the final result of the input -// sweep. It cancels any pending spend notification and removes the input from -// the list of pending inputs. When this function returns, the sweeper has -// completely forgotten about the input. -func (s *UtxoSweeper) signalAndRemove(outpoint *wire.OutPoint, result Result) { - pendInput := s.pendingInputs[*outpoint] - listeners := pendInput.listeners +// signalResult notifies the listeners of the final result of the input sweep. +// It also cancels any pending spend notification. +func (s *UtxoSweeper) signalResult(pi *pendingInput, result Result) { + op := pi.OutPoint() + listeners := pi.listeners if result.Err == nil { log.Debugf("Dispatching sweep success for %v to %v listeners", - outpoint, len(listeners), + op, len(listeners), ) } else { log.Debugf("Dispatching sweep error for %v to %v listeners: %v", - outpoint, len(listeners), result.Err, + op, len(listeners), result.Err, ) } @@ -721,14 +835,11 @@ func (s *UtxoSweeper) signalAndRemove(outpoint *wire.OutPoint, result Result) { // Cancel spend notification with chain notifier. This is not necessary // in case of a success, except for that a reorg could still happen. - if pendInput.ntfnRegCancel != nil { - log.Debugf("Canceling spend ntfn for %v", outpoint) + if pi.ntfnRegCancel != nil { + log.Debugf("Canceling spend ntfn for %v", op) - pendInput.ntfnRegCancel() + pi.ntfnRegCancel() } - - // Inputs are no longer pending after result has been sent. - delete(s.pendingInputs, *outpoint) } // getInputLists goes through the given inputs and constructs multiple distinct @@ -759,12 +870,6 @@ func (s *UtxoSweeper) getInputLists( // sweeper to avoid this. var newInputs, retryInputs []txInput for _, input := range cluster.inputs { - // Skip inputs that have a minimum publish height that is not - // yet reached. - if input.minPublishHeight > s.currentHeight { - continue - } - // Add input to the either one of the lists. if input.publishAttempts == 0 { newInputs = append(newInputs, input) @@ -836,20 +941,13 @@ func (s *UtxoSweeper) sweep(inputs inputSet, Fee: uint64(fee), } - // Add tx before publication, so that we will always know that a spend - // by this tx is ours. Otherwise if the publish doesn't return, but did - // publish, we loose track of this tx. Even republication on startup - // doesn't prevent this, because that call returns a double spend error - // then and would also not add the hash to the store. - err = s.cfg.Store.StoreTx(tr) - if err != nil { - return fmt.Errorf("store tx: %w", err) - } - // Reschedule the inputs that we just tried to sweep. This is done in // case the following publish fails, we'd like to update the inputs' // publish attempts and rescue them in the next sweep. - s.rescheduleInputs(tx.TxIn) + err = s.markInputsPendingPublish(tr, tx.TxIn) + if err != nil { + return err + } log.Debugf("Publishing sweep tx %v, num_inputs=%v, height=%v", tx.TxHash(), len(tx.TxIn), s.currentHeight) @@ -859,17 +957,16 @@ func (s *UtxoSweeper) sweep(inputs inputSet, tx, labels.MakeLabel(labels.LabelTypeSweepTransaction, nil), ) if err != nil { + // TODO(yy): find out which input is causing the failure. + s.markInputsPublishFailed(tx.TxIn) + return err } - // Mark this tx in db once successfully published. - // - // NOTE: this will behave as an overwrite, which is fine as the record - // is small. - tr.Published = true - err = s.cfg.Store.StoreTx(tr) + // Inputs have been successfully published so we update their states. + err = s.markInputsPublished(tr, tx.TxIn) if err != nil { - return fmt.Errorf("store tx: %w", err) + return err } // If there's no error, remove the output script. Otherwise keep it so @@ -880,52 +977,139 @@ func (s *UtxoSweeper) sweep(inputs inputSet, return nil } -// rescheduleInputs updates the pending inputs with the given tx inputs. It -// increments the `publishAttempts` and calculates the next broadcast height -// for each input. When the publishAttempts exceeds MaxSweepAttemps(10), this -// input will be removed. -func (s *UtxoSweeper) rescheduleInputs(inputs []*wire.TxIn) { +// markInputsPendingPublish saves the sweeping tx to db and updates the pending +// inputs with the given tx inputs. It also increments the `publishAttempts`. +func (s *UtxoSweeper) markInputsPendingPublish(tr *TxRecord, + inputs []*wire.TxIn) error { + + // Add tx to db before publication, so that we will always know that a + // spend by this tx is ours. Otherwise if the publish doesn't return, + // but did publish, we'd lose track of this tx. Even republication on + // startup doesn't prevent this, because that call returns a double + // spend error then and would also not add the hash to the store. + err := s.cfg.Store.StoreTx(tr) + if err != nil { + return fmt.Errorf("store tx: %w", err) + } + // Reschedule sweep. for _, input := range inputs { pi, ok := s.pendingInputs[input.PreviousOutPoint] if !ok { - // It can be that the input has been removed because it - // exceed the maximum number of attempts in a previous - // input set. It could also be that this input is an - // additional wallet input that was attached. In that - // case there also isn't a pending input to update. + // It could be that this input is an additional wallet + // input that was attached. In that case there also + // isn't a pending input to update. + log.Debugf("Skipped marking input as pending "+ + "published: %v not found in pending inputs", + input.PreviousOutPoint) + continue } + // If this input has already terminated, there's clearly + // something wrong as it would have been removed. In this case + // we log an error and skip marking this input as pending + // publish. + if pi.terminated() { + log.Errorf("Expect input %v to not have terminated "+ + "state, instead it has %v", + input.PreviousOutPoint, pi.state) + + continue + } + + // Update the input's state. + pi.state = StatePendingPublish + + // Record the fees and fee rate of this tx to prepare possible + // RBF. + pi.rbf = fn.Some(RBFInfo{ + Txid: tr.Txid, + FeeRate: chainfee.SatPerKWeight(tr.FeeRate), + Fee: btcutil.Amount(tr.Fee), + }) + // Record another publish attempt. pi.publishAttempts++ + } - // We don't care what the result of the publish call was. Even - // if it is published successfully, it can still be that it - // needs to be retried. Call NextAttemptDeltaFunc to calculate - // when to resweep this input. - nextAttemptDelta := s.cfg.NextAttemptDeltaFunc( - pi.publishAttempts, - ) + return nil +} + +// markInputsPublished updates the sweeping tx in db and marks the list of +// inputs as published. +func (s *UtxoSweeper) markInputsPublished(tr *TxRecord, + inputs []*wire.TxIn) error { + + // Mark this tx in db once successfully published. + // + // NOTE: this will behave as an overwrite, which is fine as the record + // is small. + tr.Published = true + err := s.cfg.Store.StoreTx(tr) + if err != nil { + return fmt.Errorf("store tx: %w", err) + } + + // Reschedule sweep. + for _, input := range inputs { + pi, ok := s.pendingInputs[input.PreviousOutPoint] + if !ok { + // It could be that this input is an additional wallet + // input that was attached. In that case there also + // isn't a pending input to update. + log.Debugf("Skipped marking input as published: %v "+ + "not found in pending inputs", + input.PreviousOutPoint) + + continue + } - pi.minPublishHeight = s.currentHeight + nextAttemptDelta + // Valdiate that the input is in an expected state. + if pi.state != StatePendingPublish { + log.Errorf("Expect input %v to have %v, instead it "+ + "has %v", input.PreviousOutPoint, + StatePendingPublish, pi.state) + + continue + } + + // Update the input's state. + pi.state = StatePublished + } - log.Debugf("Rescheduling input %v after %v attempts at "+ - "height %v (delta %v)", input.PreviousOutPoint, - pi.publishAttempts, pi.minPublishHeight, - nextAttemptDelta) + return nil +} + +// markInputsPublishFailed marks the list of inputs as failed to be published. +func (s *UtxoSweeper) markInputsPublishFailed(inputs []*wire.TxIn) { + // Reschedule sweep. + for _, input := range inputs { + pi, ok := s.pendingInputs[input.PreviousOutPoint] + if !ok { + // It could be that this input is an additional wallet + // input that was attached. In that case there also + // isn't a pending input to update. + log.Debugf("Skipped marking input as publish failed: "+ + "%v not found in pending inputs", + input.PreviousOutPoint) + + continue + } - if pi.publishAttempts >= s.cfg.MaxSweepAttempts { - log.Warnf("input %v: publishAttempts(%v) exceeds "+ - "MaxSweepAttempts(%v), removed", - input.PreviousOutPoint, pi.publishAttempts, - s.cfg.MaxSweepAttempts) + // Valdiate that the input is in an expected state. + if pi.state != StatePendingPublish { + log.Errorf("Expect input %v to have %v, instead it "+ + "has %v", input.PreviousOutPoint, + StatePendingPublish, pi.state) - // Signal result channels sweep result. - s.signalAndRemove(&input.PreviousOutPoint, Result{ - Err: ErrTooManyAttempts, - }) + continue } + + log.Warnf("Failed to publish input %v", input.PreviousOutPoint) + + // Update the input's state. + pi.state = StatePublishFailed } } @@ -956,8 +1140,8 @@ func (s *UtxoSweeper) monitorSpend(outpoint wire.OutPoint, return } - log.Debugf("Delivering spend ntfn for %v", - outpoint) + log.Debugf("Delivering spend ntfn for %v", outpoint) + select { case s.spendChan <- spend: log.Debugf("Delivered spend ntfn for %v", @@ -1012,10 +1196,9 @@ func (s *UtxoSweeper) handlePendingSweepsReq( Amount: btcutil.Amount( pendingInput.SignDesc().Output.Value, ), - LastFeeRate: pendingInput.lastFeeRate, - BroadcastAttempts: pendingInput.publishAttempts, - NextBroadcastHeight: uint32(pendingInput.minPublishHeight), - Params: pendingInput.params, + LastFeeRate: pendingInput.lastFeeRate, + BroadcastAttempts: pendingInput.publishAttempts, + Params: pendingInput.params, } } @@ -1091,22 +1274,16 @@ func (s *UtxoSweeper) handleUpdateReq(req *updateReq) ( newParams.Fee = req.params.Fee newParams.Force = req.params.Force - log.Debugf("Updating sweep parameters for %v from %v to %v", req.input, - pendingInput.params, newParams) + log.Debugf("Updating parameters for %v(state=%v) from (%v) to (%v)", + req.input, pendingInput.state, pendingInput.params, newParams) pendingInput.params = newParams - // We'll reset the input's publish height to the current so that a new - // transaction can be created that replaces the transaction currently - // spending the input. We only do this for inputs that have been - // broadcast at least once to ensure we don't spend an input before its - // maturity height. + // We need to reset the state so this input will be attempted again by + // our sweeper. // - // NOTE: The UtxoSweeper is not yet offered time-locked inputs, so the - // check for broadcast attempts is redundant at the moment. - if pendingInput.publishAttempts > 0 { - pendingInput.minPublishHeight = s.currentHeight - } + // TODO(yy): a dedicated state? + pendingInput.state = StateInit resultChan := make(chan Result, 1) pendingInput.listeners = append(pendingInput.listeners, resultChan) @@ -1153,43 +1330,60 @@ func (s *UtxoSweeper) CreateSweepTx(inputs []input.Input, return tx, err } -// DefaultNextAttemptDeltaFunc is the default calculation for next sweep attempt -// scheduling. It implements exponential back-off with some randomness. This is -// to prevent a stuck tx (for example because fee is too low and can't be bumped -// in btcd) from blocking all other retried inputs in the same tx. -func DefaultNextAttemptDeltaFunc(attempts int) int32 { - return 1 + rand.Int31n(1< inputClusters[j].sweepFeeRate diff --git a/sweep/sweeper_test.go b/sweep/sweeper_test.go index c12b04aae5..0168d9f08f 100644 --- a/sweep/sweeper_test.go +++ b/sweep/sweeper_test.go @@ -1,6 +1,7 @@ package sweep import ( + "errors" "os" "runtime/pprof" "testing" @@ -12,10 +13,12 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/build" + "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" - "github.com/lightningnetwork/lnd/lntest/mock" + lnmock "github.com/lightningnetwork/lnd/lntest/mock" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/stretchr/testify/require" @@ -135,7 +138,7 @@ func createSweeperTestContext(t *testing.T) *sweeperTestContext { Wallet: backend, TickerDuration: 100 * time.Millisecond, Store: store, - Signer: &mock.DummySigner{}, + Signer: &lnmock.DummySigner{}, GenSweepScript: func() ([]byte, error) { script := make([]byte, input.P2WPKHSize) script[0] = 0 @@ -145,12 +148,8 @@ func createSweeperTestContext(t *testing.T) *sweeperTestContext { FeeEstimator: estimator, MaxInputsPerTx: testMaxInputsPerTx, MaxSweepAttempts: testMaxSweepAttempts, - NextAttemptDeltaFunc: func(attempts int) int32 { - // Use delta func without random factor. - return 1 << uint(attempts-1) - }, - MaxFeeRate: DefaultMaxFeeRate, - Aggregator: aggregator, + MaxFeeRate: DefaultMaxFeeRate, + Aggregator: aggregator, }) ctx.sweeper.Start() @@ -877,41 +876,6 @@ func TestRetry(t *testing.T) { ctx.finish(1) } -// TestGiveUp asserts that the sweeper gives up on an input if it can't be swept -// after a configured number of attempts.a -func TestGiveUp(t *testing.T) { - ctx := createSweeperTestContext(t) - - resultChan0, err := ctx.sweeper.SweepInput( - spendableInputs[0], defaultFeePref, - ) - if err != nil { - t.Fatal(err) - } - - // We expect a sweep to be published at height 100 (mockChainIOHeight). - ctx.receiveTx() - - // Because of MaxSweepAttemps, two more sweeps will be attempted. We - // configured exponential back-off without randomness for the test. The - // second attempt, we expect to happen at 101. The third attempt at 103. - // At that point, the input is expected to be failed. - - // Second attempt - ctx.notifier.NotifyEpoch(101) - ctx.receiveTx() - - // Third attempt - ctx.notifier.NotifyEpoch(103) - ctx.receiveTx() - - ctx.expectResult(resultChan0, ErrTooManyAttempts) - - ctx.backend.mine() - - ctx.finish(1) -} - // TestDifferentFeePreferences ensures that the sweeper can have different // transactions for different fee preferences. These transactions should be // broadcast from highest to lowest fee rate. @@ -1026,24 +990,14 @@ func TestPendingInputs(t *testing.T) { // We should expect to see all inputs pending. ctx.assertPendingInputs(input1, input2, input3) - // We should expect to see both sweep transactions broadcast. The higher - // fee rate sweep should be broadcast first. We'll remove the lower fee - // rate sweep to ensure we can detect pending inputs after a sweep. - // Once the higher fee rate sweep confirms, we should no longer see - // those inputs pending. + // We should expect to see both sweep transactions broadcast - one for + // the higher feerate, the other for the lower. ctx.receiveTx() - lowFeeRateTx := ctx.receiveTx() - ctx.backend.deleteUnconfirmed(lowFeeRateTx.TxHash()) - ctx.backend.mine() - ctx.expectResult(resultChan1, nil) - ctx.assertPendingInputs(input3) - - // We'll then trigger a new block to rebroadcast the lower fee rate - // sweep. Once again we'll ensure those inputs are no longer pending - // once the sweep transaction confirms. - ctx.backend.notifier.NotifyEpoch(101) ctx.receiveTx() + + // Mine these txns, and we should expect to see the results delivered. ctx.backend.mine() + ctx.expectResult(resultChan1, nil) ctx.expectResult(resultChan3, nil) ctx.assertPendingInputs() @@ -2025,3 +1979,527 @@ func TestGetInputLists(t *testing.T) { }) } } + +// TestMarkInputsPendingPublish checks that given a list of inputs with +// different states, only the non-terminal state will be marked as `Published`. +func TestMarkInputsPendingPublish(t *testing.T) { + t.Parallel() + + require := require.New(t) + + // Create a mock sweeper store. + mockStore := NewMockSweeperStore() + + // Create a test TxRecord and a dummy error. + dummyTR := &TxRecord{} + dummyErr := errors.New("dummy error") + + // Create a test sweeper. + s := New(&UtxoSweeperConfig{ + Store: mockStore, + }) + + // Create three testing inputs. + // + // inputNotExist specifies an input that's not found in the sweeper's + // `pendingInputs` map. + inputNotExist := &wire.TxIn{ + PreviousOutPoint: wire.OutPoint{Index: 1}, + } + + // inputInit specifies a newly created input. + inputInit := &wire.TxIn{ + PreviousOutPoint: wire.OutPoint{Index: 2}, + } + s.pendingInputs[inputInit.PreviousOutPoint] = &pendingInput{ + state: StateInit, + } + + // inputPendingPublish specifies an input that's about to be published. + inputPendingPublish := &wire.TxIn{ + PreviousOutPoint: wire.OutPoint{Index: 3}, + } + s.pendingInputs[inputPendingPublish.PreviousOutPoint] = &pendingInput{ + state: StatePendingPublish, + } + + // inputTerminated specifies an input that's terminated. + inputTerminated := &wire.TxIn{ + PreviousOutPoint: wire.OutPoint{Index: 4}, + } + s.pendingInputs[inputTerminated.PreviousOutPoint] = &pendingInput{ + state: StateExcluded, + } + + // First, check that when an error is returned from db, it's properly + // returned here. + mockStore.On("StoreTx", dummyTR).Return(dummyErr).Once() + err := s.markInputsPendingPublish(dummyTR, nil) + require.ErrorIs(err, dummyErr) + + // Then, check that the target input has will be correctly marked as + // published. + // + // Mock the store to return nil + mockStore.On("StoreTx", dummyTR).Return(nil).Once() + + // Mark the test inputs. We expect the non-exist input and the + // inputTerminated to be skipped, and the rest to be marked as pending + // publish. + err = s.markInputsPendingPublish(dummyTR, []*wire.TxIn{ + inputNotExist, inputInit, inputPendingPublish, inputTerminated, + }) + require.NoError(err) + + // We expect unchanged number of pending inputs. + require.Len(s.pendingInputs, 3) + + // We expect the init input's state to become pending publish. + require.Equal(StatePendingPublish, + s.pendingInputs[inputInit.PreviousOutPoint].state) + + // We expect the pending-publish to stay unchanged. + require.Equal(StatePendingPublish, + s.pendingInputs[inputPendingPublish.PreviousOutPoint].state) + + // We expect the terminated to stay unchanged. + require.Equal(StateExcluded, + s.pendingInputs[inputTerminated.PreviousOutPoint].state) + + // Assert mocked statements are executed as expected. + mockStore.AssertExpectations(t) +} + +// TestMarkInputsPublished checks that given a list of inputs with different +// states, only the state `StatePendingPublish` will be marked as `Published`. +func TestMarkInputsPublished(t *testing.T) { + t.Parallel() + + require := require.New(t) + + // Create a mock sweeper store. + mockStore := NewMockSweeperStore() + + // Create a test TxRecord and a dummy error. + dummyTR := &TxRecord{} + dummyErr := errors.New("dummy error") + + // Create a test sweeper. + s := New(&UtxoSweeperConfig{ + Store: mockStore, + }) + + // Create three testing inputs. + // + // inputNotExist specifies an input that's not found in the sweeper's + // `pendingInputs` map. + inputNotExist := &wire.TxIn{ + PreviousOutPoint: wire.OutPoint{Index: 1}, + } + + // inputInit specifies a newly created input. When marking this as + // published, we should see an error log as this input hasn't been + // published yet. + inputInit := &wire.TxIn{ + PreviousOutPoint: wire.OutPoint{Index: 2}, + } + s.pendingInputs[inputInit.PreviousOutPoint] = &pendingInput{ + state: StateInit, + } + + // inputPendingPublish specifies an input that's about to be published. + inputPendingPublish := &wire.TxIn{ + PreviousOutPoint: wire.OutPoint{Index: 3}, + } + s.pendingInputs[inputPendingPublish.PreviousOutPoint] = &pendingInput{ + state: StatePendingPublish, + } + + // First, check that when an error is returned from db, it's properly + // returned here. + mockStore.On("StoreTx", dummyTR).Return(dummyErr).Once() + err := s.markInputsPublished(dummyTR, nil) + require.ErrorIs(err, dummyErr) + + // We also expect the record has been marked as published. + require.True(dummyTR.Published) + + // Then, check that the target input has will be correctly marked as + // published. + // + // Mock the store to return nil + mockStore.On("StoreTx", dummyTR).Return(nil).Once() + + // Mark the test inputs. We expect the non-exist input and the + // inputInit to be skipped, and the final input to be marked as + // published. + err = s.markInputsPublished(dummyTR, []*wire.TxIn{ + inputNotExist, inputInit, inputPendingPublish, + }) + require.NoError(err) + + // We expect unchanged number of pending inputs. + require.Len(s.pendingInputs, 2) + + // We expect the init input's state to stay unchanged. + require.Equal(StateInit, + s.pendingInputs[inputInit.PreviousOutPoint].state) + + // We expect the pending-publish input's is now marked as published. + require.Equal(StatePublished, + s.pendingInputs[inputPendingPublish.PreviousOutPoint].state) + + // Assert mocked statements are executed as expected. + mockStore.AssertExpectations(t) +} + +// TestMarkInputsPublishFailed checks that given a list of inputs with +// different states, only the state `StatePendingPublish` will be marked as +// `PublishFailed`. +func TestMarkInputsPublishFailed(t *testing.T) { + t.Parallel() + + require := require.New(t) + + // Create a mock sweeper store. + mockStore := NewMockSweeperStore() + + // Create a test sweeper. + s := New(&UtxoSweeperConfig{ + Store: mockStore, + }) + + // Create three testing inputs. + // + // inputNotExist specifies an input that's not found in the sweeper's + // `pendingInputs` map. + inputNotExist := &wire.TxIn{ + PreviousOutPoint: wire.OutPoint{Index: 1}, + } + + // inputInit specifies a newly created input. When marking this as + // published, we should see an error log as this input hasn't been + // published yet. + inputInit := &wire.TxIn{ + PreviousOutPoint: wire.OutPoint{Index: 2}, + } + s.pendingInputs[inputInit.PreviousOutPoint] = &pendingInput{ + state: StateInit, + } + + // inputPendingPublish specifies an input that's about to be published. + inputPendingPublish := &wire.TxIn{ + PreviousOutPoint: wire.OutPoint{Index: 3}, + } + s.pendingInputs[inputPendingPublish.PreviousOutPoint] = &pendingInput{ + state: StatePendingPublish, + } + + // Mark the test inputs. We expect the non-exist input and the + // inputInit to be skipped, and the final input to be marked as + // published. + s.markInputsPublishFailed([]*wire.TxIn{ + inputNotExist, inputInit, inputPendingPublish, + }) + + // We expect unchanged number of pending inputs. + require.Len(s.pendingInputs, 2) + + // We expect the init input's state to stay unchanged. + require.Equal(StateInit, + s.pendingInputs[inputInit.PreviousOutPoint].state) + + // We expect the pending-publish input's is now marked as publish + // failed. + require.Equal(StatePublishFailed, + s.pendingInputs[inputPendingPublish.PreviousOutPoint].state) + + // Assert mocked statements are executed as expected. + mockStore.AssertExpectations(t) +} + +// TestMarkInputsSwept checks that given a list of inputs with different +// states, only the non-terminal state will be marked as `StateSwept`. +func TestMarkInputsSwept(t *testing.T) { + t.Parallel() + + require := require.New(t) + + // Create a mock input. + mockInput := &input.MockInput{} + defer mockInput.AssertExpectations(t) + + // Mock the `OutPoint` to return a dummy outpoint. + mockInput.On("OutPoint").Return(&wire.OutPoint{Hash: chainhash.Hash{1}}) + + // Create a test sweeper. + s := New(&UtxoSweeperConfig{}) + + // Create three testing inputs. + // + // inputNotExist specifies an input that's not found in the sweeper's + // `pendingInputs` map. + inputNotExist := &wire.TxIn{ + PreviousOutPoint: wire.OutPoint{Index: 1}, + } + + // inputInit specifies a newly created input. + inputInit := &wire.TxIn{ + PreviousOutPoint: wire.OutPoint{Index: 2}, + } + s.pendingInputs[inputInit.PreviousOutPoint] = &pendingInput{ + state: StateInit, + Input: mockInput, + } + + // inputPendingPublish specifies an input that's about to be published. + inputPendingPublish := &wire.TxIn{ + PreviousOutPoint: wire.OutPoint{Index: 3}, + } + s.pendingInputs[inputPendingPublish.PreviousOutPoint] = &pendingInput{ + state: StatePendingPublish, + Input: mockInput, + } + + // inputTerminated specifies an input that's terminated. + inputTerminated := &wire.TxIn{ + PreviousOutPoint: wire.OutPoint{Index: 4}, + } + s.pendingInputs[inputTerminated.PreviousOutPoint] = &pendingInput{ + state: StateExcluded, + Input: mockInput, + } + + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + inputNotExist, inputInit, + inputPendingPublish, inputTerminated, + }, + } + + // Mark the test inputs. We expect the inputTerminated to be skipped, + // and the rest to be marked as swept. + s.markInputsSwept(tx, true) + + // We expect unchanged number of pending inputs. + require.Len(s.pendingInputs, 3) + + // We expect the init input's state to become swept. + require.Equal(StateSwept, + s.pendingInputs[inputInit.PreviousOutPoint].state) + + // We expect the pending-publish becomes swept. + require.Equal(StateSwept, + s.pendingInputs[inputPendingPublish.PreviousOutPoint].state) + + // We expect the terminated to stay unchanged. + require.Equal(StateExcluded, + s.pendingInputs[inputTerminated.PreviousOutPoint].state) +} + +// TestMempoolLookup checks that the method `mempoolLookup` works as expected. +func TestMempoolLookup(t *testing.T) { + t.Parallel() + + require := require.New(t) + + // Create a test outpoint. + op := wire.OutPoint{Index: 1} + + // Create a mock mempool watcher. + mockMempool := chainntnfs.NewMockMempoolWatcher() + defer mockMempool.AssertExpectations(t) + + // Create a test sweeper without a mempool. + s := New(&UtxoSweeperConfig{}) + + // Since we don't have a mempool, we expect the call to return a + // fn.None indicating it's not found. + tx := s.mempoolLookup(op) + require.True(tx.IsNone()) + + // Re-create the sweeper with the mocked mempool watcher. + s = New(&UtxoSweeperConfig{ + Mempool: mockMempool, + }) + + // Mock the mempool watcher to return not found. + mockMempool.On("LookupInputMempoolSpend", op).Return( + fn.None[wire.MsgTx]()).Once() + + // We expect a fn.None tx to be returned. + tx = s.mempoolLookup(op) + require.True(tx.IsNone()) + + // Mock the mempool to return a spending tx. + dummyTx := wire.MsgTx{} + mockMempool.On("LookupInputMempoolSpend", op).Return( + fn.Some(dummyTx)).Once() + + // Calling the loopup again, we expect the dummyTx to be returned. + tx = s.mempoolLookup(op) + require.False(tx.IsNone()) + require.Equal(dummyTx, tx.UnsafeFromSome()) +} + +// TestUpdateSweeperInputs checks that the method `updateSweeperInputs` will +// properly update the inputs based on their states. +func TestUpdateSweeperInputs(t *testing.T) { + t.Parallel() + + require := require.New(t) + + // Create a test sweeper. + s := New(nil) + + // Create a list of inputs using all the states. + input0 := &pendingInput{state: StateInit} + input1 := &pendingInput{state: StatePendingPublish} + input2 := &pendingInput{state: StatePublished} + input3 := &pendingInput{state: StatePublishFailed} + input4 := &pendingInput{state: StateSwept} + input5 := &pendingInput{state: StateExcluded} + input6 := &pendingInput{state: StateFailed} + + // Add the inputs to the sweeper. After the update, we should see the + // terminated inputs being removed. + s.pendingInputs = map[wire.OutPoint]*pendingInput{ + {Index: 0}: input0, + {Index: 1}: input1, + {Index: 2}: input2, + {Index: 3}: input3, + {Index: 4}: input4, + {Index: 5}: input5, + {Index: 6}: input6, + } + + // We expect the inputs with `StateSwept`, `StateExcluded`, and + // `StateFailed` to be removed. + expectedInputs := map[wire.OutPoint]*pendingInput{ + {Index: 0}: input0, + {Index: 1}: input1, + {Index: 2}: input2, + {Index: 3}: input3, + } + + // We expect only the inputs with `StateInit` and `StatePublishFailed` + // to be returned. + expectedReturn := map[wire.OutPoint]*pendingInput{ + {Index: 0}: input0, + {Index: 3}: input3, + } + + // Update the sweeper inputs. + inputs := s.updateSweeperInputs() + + // Assert the returned inputs are as expected. + require.Equal(expectedReturn, inputs) + + // Assert the sweeper inputs are as expected. + require.Equal(expectedInputs, s.pendingInputs) +} + +// TestDecideStateAndRBFInfo checks that the expected state and RBFInfo are +// returned based on whether this input can be found both in mempool and the +// sweeper store. +func TestDecideStateAndRBFInfo(t *testing.T) { + t.Parallel() + + require := require.New(t) + + // Create a test outpoint. + op := wire.OutPoint{Index: 1} + + // Create a mock mempool watcher and a mock sweeper store. + mockMempool := chainntnfs.NewMockMempoolWatcher() + defer mockMempool.AssertExpectations(t) + mockStore := NewMockSweeperStore() + defer mockStore.AssertExpectations(t) + + // Create a test sweeper. + s := New(&UtxoSweeperConfig{ + Store: mockStore, + Mempool: mockMempool, + }) + + // First, mock the mempool to return false. + mockMempool.On("LookupInputMempoolSpend", op).Return( + fn.None[wire.MsgTx]()).Once() + + // Since the mempool lookup failed, we exepect state Init and no + // RBFInfo. + state, rbf := s.decideStateAndRBFInfo(op) + require.True(rbf.IsNone()) + require.Equal(StateInit, state) + + // Mock the mempool lookup to return a tx three times as we are calling + // attachAvailableRBFInfo three times. + tx := wire.MsgTx{} + mockMempool.On("LookupInputMempoolSpend", op).Return( + fn.Some(tx)).Times(3) + + // Mock the store to return an error saying the tx cannot be found. + mockStore.On("GetTx", tx.TxHash()).Return(nil, ErrTxNotFound).Once() + + // Although the db lookup failed, we expect the state to be Published. + state, rbf = s.decideStateAndRBFInfo(op) + require.True(rbf.IsNone()) + require.Equal(StatePublished, state) + + // Mock the store to return a db error. + dummyErr := errors.New("dummy error") + mockStore.On("GetTx", tx.TxHash()).Return(nil, dummyErr).Once() + + // Although the db lookup failed, we expect the state to be Published. + state, rbf = s.decideStateAndRBFInfo(op) + require.True(rbf.IsNone()) + require.Equal(StatePublished, state) + + // Mock the store to return a record. + tr := &TxRecord{ + Fee: 100, + FeeRate: 100, + } + mockStore.On("GetTx", tx.TxHash()).Return(tr, nil).Once() + + // Call the method again. + state, rbf = s.decideStateAndRBFInfo(op) + + // Assert that the RBF info is returned. + rbfInfo := fn.Some(RBFInfo{ + Txid: tx.TxHash(), + Fee: btcutil.Amount(tr.Fee), + FeeRate: chainfee.SatPerKWeight(tr.FeeRate), + }) + require.Equal(rbfInfo, rbf) + + // Assert the state is updated. + require.Equal(StatePublished, state) +} + +// TestMarkInputFailed checks that the input is marked as failed as expected. +func TestMarkInputFailed(t *testing.T) { + t.Parallel() + + // Create a mock input. + mockInput := &input.MockInput{} + defer mockInput.AssertExpectations(t) + + // Mock the `OutPoint` to return a dummy outpoint. + mockInput.On("OutPoint").Return(&wire.OutPoint{Hash: chainhash.Hash{1}}) + + // Create a test sweeper. + s := New(&UtxoSweeperConfig{}) + + // Create a testing pending input. + pi := &pendingInput{ + state: StateInit, + Input: mockInput, + } + + // Call the method under test. + s.markInputFailed(pi, errors.New("dummy error")) + + // Assert the state is updated. + require.Equal(t, StateFailed, pi.state) +} diff --git a/sweep/test_utils.go b/sweep/test_utils.go index 86dfd6d2b8..e36b56a6b8 100644 --- a/sweep/test_utils.go +++ b/sweep/test_utils.go @@ -99,6 +99,8 @@ func (m *MockNotifier) sendSpend(channel chan *chainntnfs.SpendDetail, outpoint *wire.OutPoint, spendingTx *wire.MsgTx) { + log.Debugf("Notifying spend of outpoint %v", outpoint) + spenderTxHash := spendingTx.TxHash() channel <- &chainntnfs.SpendDetail{ SpenderTxHash: &spenderTxHash, @@ -188,6 +190,8 @@ func (m *MockNotifier) Stop() error { func (m *MockNotifier) RegisterSpendNtfn(outpoint *wire.OutPoint, _ []byte, heightHint uint32) (*chainntnfs.SpendEvent, error) { + log.Debugf("RegisterSpendNtfn for outpoint %v", outpoint) + // Add channel to global spend ntfn map. m.mutex.Lock()