From 3f8da351e04885774322789264925ffb3c058abc Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Wed, 17 Jan 2024 00:25:41 +0800 Subject: [PATCH 01/19] sweep: expand `InputSet` with more interface methods This commit adds more interface methods to `InputSet` to prepare the addition of budget-based aggregator. --- sweep/mock_test.go | 16 ++++++++++++++++ sweep/tx_input_set.go | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/sweep/mock_test.go b/sweep/mock_test.go index fc7ff9c34d..f908cf6dbf 100644 --- a/sweep/mock_test.go +++ b/sweep/mock_test.go @@ -5,8 +5,10 @@ import ( "testing" "time" + "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" @@ -446,3 +448,17 @@ func (m *MockInputSet) NeedWalletInput() bool { return args.Bool(0) } + +// DeadlineHeight returns the deadline height for the set. +func (m *MockInputSet) DeadlineHeight() fn.Option[int32] { + args := m.Called() + + return args.Get(0).(fn.Option[int32]) +} + +// Budget givens the total amount that can be used as fees by this input set. +func (m *MockInputSet) Budget() btcutil.Amount { + args := m.Called() + + return args.Get(0).(btcutil.Amount) +} diff --git a/sweep/tx_input_set.go b/sweep/tx_input_set.go index b80ea2db09..789bb277b3 100644 --- a/sweep/tx_input_set.go +++ b/sweep/tx_input_set.go @@ -8,6 +8,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" @@ -53,6 +54,24 @@ type InputSet interface { // NeedWalletInput returns true if the input set needs more wallet // inputs. NeedWalletInput() bool + + // DeadlineHeight returns an optional absolute block height to express + // the time-sensitivity of the input set. The outputs from a force + // close tx have different time preferences: + // - to_local: no time pressure as it can only be swept by us. + // - first level outgoing HTLC: must be swept before its corresponding + // incoming HTLC's CLTV is reached. + // - first level incoming HTLC: must be swept before its CLTV is + // reached. + // - second level HTLCs: no time pressure. + // - anchor: for CPFP-purpose anchor, it must be swept before any of + // the above CLTVs is reached. For non-CPFP purpose anchor, there's + // no time pressure. + DeadlineHeight() fn.Option[int32] + + // Budget givens the total amount that can be used as fees by this + // input set. + Budget() btcutil.Amount } type txInputSetState struct { @@ -167,6 +186,20 @@ func (t *txInputSet) Inputs() []input.Input { return t.inputs } +// Budget gives the total amount that can be used as fees by this input set. +// +// NOTE: this field is only used for `BudgetInputSet`. +func (t *txInputSet) Budget() btcutil.Amount { + return t.totalOutput() +} + +// DeadlineHeight gives the block height that this set must be confirmed by. +// +// NOTE: this field is only used for `BudgetInputSet`. +func (t *txInputSet) DeadlineHeight() fn.Option[int32] { + return fn.None[int32]() +} + // FeeRate returns the fee rate that should be used for the tx. func (t *txInputSet) FeeRate() chainfee.SatPerKWeight { return t.feeRate From d00c104e56421a299ff965c40d5f9c281eb30418 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Wed, 17 Jan 2024 00:20:24 +0800 Subject: [PATCH 02/19] sweep: change `markInputsPublishFailed` to take outpoints This way it's easier to pass values to this method in various callsites. --- sweep/sweeper.go | 21 ++++++++++++--------- sweep/sweeper_test.go | 6 ++++-- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/sweep/sweeper.go b/sweep/sweeper.go index bd266aaa0b..12b3efc24f 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -807,8 +807,13 @@ func (s *UtxoSweeper) sweep(set InputSet) error { tx, labels.MakeLabel(labels.LabelTypeSweepTransaction, nil), ) if err != nil { + outpoints := make([]wire.OutPoint, len(set.Inputs())) + for i, inp := range set.Inputs() { + outpoints[i] = *inp.OutPoint() + } + // TODO(yy): find out which input is causing the failure. - s.markInputsPublishFailed(tx.TxIn) + s.markInputsPublishFailed(outpoints) return err } @@ -932,17 +937,16 @@ func (s *UtxoSweeper) markInputsPublished(tr *TxRecord, } // markInputsPublishFailed marks the list of inputs as failed to be published. -func (s *UtxoSweeper) markInputsPublishFailed(inputs []*wire.TxIn) { +func (s *UtxoSweeper) markInputsPublishFailed(outpoints []wire.OutPoint) { // Reschedule sweep. - for _, input := range inputs { - pi, ok := s.pendingInputs[input.PreviousOutPoint] + for _, op := range outpoints { + pi, ok := s.pendingInputs[op] if !ok { // It could be that this input is an additional wallet // input that was attached. In that case there also // isn't a pending input to update. log.Debugf("Skipped marking input as publish failed: "+ - "%v not found in pending inputs", - input.PreviousOutPoint) + "%v not found in pending inputs", op) continue } @@ -950,13 +954,12 @@ func (s *UtxoSweeper) markInputsPublishFailed(inputs []*wire.TxIn) { // Valdiate that the input is in an expected state. if pi.state != StatePendingPublish { log.Errorf("Expect input %v to have %v, instead it "+ - "has %v", input.PreviousOutPoint, - StatePendingPublish, pi.state) + "has %v", op, StatePendingPublish, pi.state) continue } - log.Warnf("Failed to publish input %v", input.PreviousOutPoint) + log.Warnf("Failed to publish input %v", op) // Update the input's state. pi.state = StatePublishFailed diff --git a/sweep/sweeper_test.go b/sweep/sweeper_test.go index 519bbdbb2a..a870c2789a 100644 --- a/sweep/sweeper_test.go +++ b/sweep/sweeper_test.go @@ -2108,8 +2108,10 @@ func TestMarkInputsPublishFailed(t *testing.T) { // Mark the test inputs. We expect the non-exist input and the // inputInit to be skipped, and the final input to be marked as // published. - s.markInputsPublishFailed([]*wire.TxIn{ - inputNotExist, inputInit, inputPendingPublish, + s.markInputsPublishFailed([]wire.OutPoint{ + inputNotExist.PreviousOutPoint, + inputInit.PreviousOutPoint, + inputPendingPublish.PreviousOutPoint, }) // We expect unchanged number of pending inputs. From bd9256ecac5372d5bccb6a691bd99362a21a3d1d Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Wed, 21 Feb 2024 16:00:05 +0800 Subject: [PATCH 03/19] sweep: refactor `markInputsPendingPublish` to take `InputSet` This commit changes `markInputsPendingPublish` to take `InputSet` only. This is needed for the following commits as we won't be able to know the tx being created beforehand, yet we still want to make sure these inputs won't be grouped to another input set as it complicates our RBF process. --- sweep/sweeper.go | 41 +++++------------------ sweep/sweeper_test.go | 77 ++++++++++++++++++------------------------- 2 files changed, 40 insertions(+), 78 deletions(-) diff --git a/sweep/sweeper.go b/sweep/sweeper.go index 12b3efc24f..92c6646d2d 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -794,10 +794,7 @@ func (s *UtxoSweeper) sweep(set InputSet) error { // Reschedule the inputs that we just tried to sweep. This is done in // case the following publish fails, we'd like to update the inputs' // publish attempts and rescue them in the next sweep. - err = s.markInputsPendingPublish(tr, tx.TxIn) - if err != nil { - return err - } + s.markInputsPendingPublish(set) log.Debugf("Publishing sweep tx %v, num_inputs=%v, height=%v", tx.TxHash(), len(tx.TxIn), s.currentHeight) @@ -832,31 +829,19 @@ func (s *UtxoSweeper) sweep(set InputSet) error { return nil } -// markInputsPendingPublish saves the sweeping tx to db and updates the pending -// inputs with the given tx inputs. It also increments the `publishAttempts`. -func (s *UtxoSweeper) markInputsPendingPublish(tr *TxRecord, - inputs []*wire.TxIn) error { - - // Add tx to db before publication, so that we will always know that a - // spend by this tx is ours. Otherwise if the publish doesn't return, - // but did publish, we'd lose track of this tx. Even republication on - // startup doesn't prevent this, because that call returns a double - // spend error then and would also not add the hash to the store. - err := s.cfg.Store.StoreTx(tr) - if err != nil { - return fmt.Errorf("store tx: %w", err) - } - +// markInputsPendingPublish updates the pending inputs with the given tx +// inputs. It also increments the `publishAttempts`. +func (s *UtxoSweeper) markInputsPendingPublish(set InputSet) { // Reschedule sweep. - for _, input := range inputs { - pi, ok := s.pendingInputs[input.PreviousOutPoint] + for _, input := range set.Inputs() { + pi, ok := s.pendingInputs[*input.OutPoint()] if !ok { // It could be that this input is an additional wallet // input that was attached. In that case there also // isn't a pending input to update. log.Debugf("Skipped marking input as pending "+ "published: %v not found in pending inputs", - input.PreviousOutPoint) + input.OutPoint()) continue } @@ -868,7 +853,7 @@ func (s *UtxoSweeper) markInputsPendingPublish(tr *TxRecord, if pi.terminated() { log.Errorf("Expect input %v to not have terminated "+ "state, instead it has %v", - input.PreviousOutPoint, pi.state) + input.OutPoint, pi.state) continue } @@ -876,19 +861,9 @@ func (s *UtxoSweeper) markInputsPendingPublish(tr *TxRecord, // Update the input's state. pi.state = StatePendingPublish - // Record the fees and fee rate of this tx to prepare possible - // RBF. - pi.rbf = fn.Some(RBFInfo{ - Txid: tr.Txid, - FeeRate: chainfee.SatPerKWeight(tr.FeeRate), - Fee: btcutil.Amount(tr.Fee), - }) - // Record another publish attempt. pi.publishAttempts++ } - - return nil } // markInputsPublished updates the sweeping tx in db and marks the list of diff --git a/sweep/sweeper_test.go b/sweep/sweeper_test.go index a870c2789a..f6891db92a 100644 --- a/sweep/sweeper_test.go +++ b/sweep/sweeper_test.go @@ -1897,87 +1897,74 @@ func TestMarkInputsPendingPublish(t *testing.T) { require := require.New(t) - // Create a mock sweeper store. - mockStore := NewMockSweeperStore() - - // Create a test TxRecord and a dummy error. - dummyTR := &TxRecord{} - dummyErr := errors.New("dummy error") - // Create a test sweeper. - s := New(&UtxoSweeperConfig{ - Store: mockStore, - }) + s := New(&UtxoSweeperConfig{}) + + // Create a mock input set. + set := &MockInputSet{} + defer set.AssertExpectations(t) // Create three testing inputs. // // inputNotExist specifies an input that's not found in the sweeper's // `pendingInputs` map. - inputNotExist := &wire.TxIn{ - PreviousOutPoint: wire.OutPoint{Index: 1}, - } + inputNotExist := &input.MockInput{} + defer inputNotExist.AssertExpectations(t) + + inputNotExist.On("OutPoint").Return(&wire.OutPoint{Index: 0}) // inputInit specifies a newly created input. - inputInit := &wire.TxIn{ - PreviousOutPoint: wire.OutPoint{Index: 2}, - } - s.pendingInputs[inputInit.PreviousOutPoint] = &pendingInput{ + inputInit := &input.MockInput{} + defer inputInit.AssertExpectations(t) + + inputInit.On("OutPoint").Return(&wire.OutPoint{Index: 1}) + + s.pendingInputs[*inputInit.OutPoint()] = &pendingInput{ state: StateInit, } // inputPendingPublish specifies an input that's about to be published. - inputPendingPublish := &wire.TxIn{ - PreviousOutPoint: wire.OutPoint{Index: 3}, - } - s.pendingInputs[inputPendingPublish.PreviousOutPoint] = &pendingInput{ + inputPendingPublish := &input.MockInput{} + defer inputPendingPublish.AssertExpectations(t) + + inputPendingPublish.On("OutPoint").Return(&wire.OutPoint{Index: 2}) + + s.pendingInputs[*inputPendingPublish.OutPoint()] = &pendingInput{ state: StatePendingPublish, } // inputTerminated specifies an input that's terminated. - inputTerminated := &wire.TxIn{ - PreviousOutPoint: wire.OutPoint{Index: 4}, - } - s.pendingInputs[inputTerminated.PreviousOutPoint] = &pendingInput{ - state: StateExcluded, - } + inputTerminated := &input.MockInput{} + defer inputTerminated.AssertExpectations(t) - // First, check that when an error is returned from db, it's properly - // returned here. - mockStore.On("StoreTx", dummyTR).Return(dummyErr).Once() - err := s.markInputsPendingPublish(dummyTR, nil) - require.ErrorIs(err, dummyErr) + inputTerminated.On("OutPoint").Return(&wire.OutPoint{Index: 3}) - // Then, check that the target input has will be correctly marked as - // published. - // - // Mock the store to return nil - mockStore.On("StoreTx", dummyTR).Return(nil).Once() + s.pendingInputs[*inputTerminated.OutPoint()] = &pendingInput{ + state: StateExcluded, + } // Mark the test inputs. We expect the non-exist input and the // inputTerminated to be skipped, and the rest to be marked as pending // publish. - err = s.markInputsPendingPublish(dummyTR, []*wire.TxIn{ + set.On("Inputs").Return([]input.Input{ inputNotExist, inputInit, inputPendingPublish, inputTerminated, }) - require.NoError(err) + s.markInputsPendingPublish(set) // We expect unchanged number of pending inputs. require.Len(s.pendingInputs, 3) // We expect the init input's state to become pending publish. require.Equal(StatePendingPublish, - s.pendingInputs[inputInit.PreviousOutPoint].state) + s.pendingInputs[*inputInit.OutPoint()].state) // We expect the pending-publish to stay unchanged. require.Equal(StatePendingPublish, - s.pendingInputs[inputPendingPublish.PreviousOutPoint].state) + s.pendingInputs[*inputPendingPublish.OutPoint()].state) // We expect the terminated to stay unchanged. require.Equal(StateExcluded, - s.pendingInputs[inputTerminated.PreviousOutPoint].state) - - // Assert mocked statements are executed as expected. - mockStore.AssertExpectations(t) + s.pendingInputs[*inputTerminated.OutPoint()].state) } // TestMarkInputsPublished checks that given a list of inputs with different From 0496d70ac464bd5699ce5334c4af6b3a9db2f944 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Tue, 27 Feb 2024 17:52:47 +0800 Subject: [PATCH 04/19] sweep: introduce `BudgetInputSet` to manage budget-based inputs This commit adds `BudgetInputSet` which implements `InputSet`. It handles the pending inputs based on the supplied budgets and will be used in the following commit. --- sweep/sweeper.go | 16 ++ sweep/tx_input_set.go | 274 +++++++++++++++++++++++ sweep/tx_input_set_test.go | 433 +++++++++++++++++++++++++++++++++++++ 3 files changed, 723 insertions(+) diff --git a/sweep/sweeper.go b/sweep/sweeper.go index 92c6646d2d..659c1d3408 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -52,11 +52,22 @@ type Params struct { // Force indicates whether the input should be swept regardless of // whether it is economical to do so. + // + // TODO(yy): Remove this param once deadline based sweeping is in place. Force bool // ExclusiveGroup is an identifier that, if set, prevents other inputs // with the same identifier from being batched together. ExclusiveGroup *uint64 + + // DeadlineHeight specifies an absolute block height that this input + // should be confirmed by. This value is used by the fee bumper to + // decide its urgency and adjust its feerate used. + DeadlineHeight fn.Option[int32] + + // Budget specifies the maximum amount of satoshis that can be spent on + // fees for this sweep. + Budget btcutil.Amount } // ParamsUpdate contains a new set of parameters to update a pending sweep with. @@ -196,6 +207,11 @@ type pendingInput struct { rbf fn.Option[RBFInfo] } +// String returns a human readable interpretation of the pending input. +func (p *pendingInput) String() string { + return fmt.Sprintf("%v (%v)", p.Input.OutPoint(), p.Input.WitnessType()) +} + // parameters returns the sweep parameters for this input. // // NOTE: Part of the txInput interface. diff --git a/sweep/tx_input_set.go b/sweep/tx_input_set.go index 789bb277b3..0354c6c1d4 100644 --- a/sweep/tx_input_set.go +++ b/sweep/tx_input_set.go @@ -35,6 +35,14 @@ var ( // ErrNotEnoughInputs is returned when there are not enough wallet // inputs to construct a non-dust change output for an input set. ErrNotEnoughInputs = fmt.Errorf("not enough inputs") + + // ErrDeadlinesMismatch is returned when the deadlines of the input + // sets do not match. + ErrDeadlinesMismatch = fmt.Errorf("deadlines mismatch") + + // ErrDustOutput is returned when the output value is below the dust + // limit. + ErrDustOutput = fmt.Errorf("dust output") ) // InputSet defines an interface that's responsible for filtering a set of @@ -542,3 +550,269 @@ func createWalletTxInput(utxo *lnwallet.Utxo) (input.Input, error) { &utxo.OutPoint, witnessType, signDesc, heightHint, ), nil } + +// BudgetInputSet implements the interface `InputSet`. It takes a list of +// pending inputs which share the same deadline height and groups them into a +// set conditionally based on their economical values. +type BudgetInputSet struct { + // inputs is the set of inputs that have been added to the set after + // considering their economical contribution. + inputs []*pendingInput + + // deadlineHeight is the height which the inputs in this set must be + // confirmed by. + deadlineHeight fn.Option[int32] +} + +// Compile-time constraint to ensure budgetInputSet implements InputSet. +var _ InputSet = (*BudgetInputSet)(nil) + +// validateInputs is used when creating new BudgetInputSet to ensure there are +// no duplicate inputs and they all share the same deadline heights, if set. +func validateInputs(inputs []pendingInput) error { + // Sanity check the input slice to ensure it's non-empty. + if len(inputs) == 0 { + return fmt.Errorf("inputs slice is empty") + } + + // dedupInputs is a map used to track unique outpoints of the inputs. + dedupInputs := make(map[*wire.OutPoint]struct{}) + + // deadlineSet stores unique deadline heights. + deadlineSet := make(map[fn.Option[int32]]struct{}) + + for _, input := range inputs { + input.params.DeadlineHeight.WhenSome(func(h int32) { + deadlineSet[input.params.DeadlineHeight] = struct{}{} + }) + + dedupInputs[input.OutPoint()] = struct{}{} + } + + // Make sure the inputs share the same deadline height when there is + // one. + if len(deadlineSet) > 1 { + return fmt.Errorf("inputs have different deadline heights") + } + + // Provide a defensive check to ensure that we don't have any duplicate + // inputs within the set. + if len(dedupInputs) != len(inputs) { + return fmt.Errorf("duplicate inputs") + } + + return nil +} + +// NewBudgetInputSet creates a new BudgetInputSet. +func NewBudgetInputSet(inputs []pendingInput) (*BudgetInputSet, error) { + // Validate the supplied inputs. + if err := validateInputs(inputs); err != nil { + return nil, err + } + + // TODO(yy): all the inputs share the same deadline height, which means + // there exists an opportunity to refactor the deadline height to be + // tracked on the set-level, not per input. This would allow us to + // avoid the overhead of tracking the same height for each input in the + // set. + deadlineHeight := inputs[0].params.DeadlineHeight + bi := &BudgetInputSet{ + deadlineHeight: deadlineHeight, + inputs: make([]*pendingInput, 0, len(inputs)), + } + + for _, input := range inputs { + bi.addInput(input) + } + + log.Tracef("Created %v", bi.String()) + + return bi, nil +} + +// String returns a human-readable description of the input set. +func (b *BudgetInputSet) String() string { + deadlineDesc := "none" + b.deadlineHeight.WhenSome(func(h int32) { + deadlineDesc = fmt.Sprintf("%d", h) + }) + + inputsDesc := "" + for _, input := range b.inputs { + inputsDesc += fmt.Sprintf("\n%v", input) + } + + return fmt.Sprintf("BudgetInputSet(budget=%v, deadline=%v, "+ + "inputs=[%v])", b.Budget(), deadlineDesc, inputsDesc) +} + +// addInput adds an input to the input set. +func (b *BudgetInputSet) addInput(input pendingInput) { + b.inputs = append(b.inputs, &input) +} + +// NeedWalletInput returns true if the input set needs more wallet inputs. +// +// A set may need wallet inputs when it has a required output or its total +// value cannot cover its total budget. +func (b *BudgetInputSet) NeedWalletInput() bool { + var ( + // budgetNeeded is the amount that needs to be covered from + // other inputs. + budgetNeeded btcutil.Amount + + // budgetBorrowable is the amount that can be borrowed from + // other inputs. + budgetBorrowable btcutil.Amount + ) + + for _, inp := range b.inputs { + // If this input has a required output, we can assume it's a + // second-level htlc txns input. Although this input must have + // a value that can cover its budget, it cannot be used to pay + // fees. Instead, we need to borrow budget from other inputs to + // make the sweep happen. Once swept, the input value will be + // credited to the wallet. + if inp.RequiredTxOut() != nil { + budgetNeeded += inp.params.Budget + continue + } + + // Get the amount left after covering the input's own budget. + // This amount can then be lent to the above input. + budget := inp.params.Budget + output := btcutil.Amount(inp.SignDesc().Output.Value) + budgetBorrowable += output - budget + + // If the input's budget is not even covered by itself, we need + // to borrow outputs from other inputs. + if budgetBorrowable < 0 { + log.Debugf("Input %v specified a budget that exceeds "+ + "its output value: %v > %v", inp, budget, + output) + } + } + + log.Tracef("NeedWalletInput: budgetNeeded=%v, budgetBorrowable=%v", + budgetNeeded, budgetBorrowable) + + // If we don't have enough extra budget to borrow, we need wallet + // inputs. + return budgetBorrowable < budgetNeeded +} + +// copyInputs returns a copy of the slice of the inputs in the set. +func (b *BudgetInputSet) copyInputs() []*pendingInput { + inputs := make([]*pendingInput, len(b.inputs)) + copy(inputs, b.inputs) + return inputs +} + +// AddWalletInputs adds wallet inputs to the set until the specified budget is +// met. When sweeping inputs with required outputs, although there's budget +// specified, it cannot be directly spent from these required outputs. Instead, +// we need to borrow budget from other inputs to make the sweep happen. +// There are two sources to borrow from: 1) other inputs, 2) wallet utxos. If +// we are calling this method, it means other inputs cannot cover the specified +// budget, so we need to borrow from wallet utxos. +// +// Return an error if there are not enough wallet inputs, and the budget set is +// set to its initial state by removing any wallet inputs added. +// +// NOTE: must be called with the wallet lock held via `WithCoinSelectLock`. +func (b *BudgetInputSet) AddWalletInputs(wallet Wallet) error { + // Retrieve wallet utxos. Only consider confirmed utxos to prevent + // problems around RBF rules for unconfirmed inputs. This currently + // ignores the configured coin selection strategy. + utxos, err := wallet.ListUnspentWitnessFromDefaultAccount( + 1, math.MaxInt32, + ) + if err != nil { + return fmt.Errorf("list unspent witness: %w", err) + } + + // Sort the UTXOs by putting smaller values at the start of the slice + // to avoid locking large UTXO for sweeping. + // + // TODO(yy): add more choices to CoinSelectionStrategy and use the + // configured value here. + sort.Slice(utxos, func(i, j int) bool { + return utxos[i].Value < utxos[j].Value + }) + + // Make a copy of the current inputs. If the wallet doesn't have enough + // utxos to cover the budget, we will revert the current set to its + // original state by removing the added wallet inputs. + originalInputs := b.copyInputs() + + // Add wallet inputs to the set until the specified budget is covered. + for _, utxo := range utxos { + input, err := createWalletTxInput(utxo) + if err != nil { + return err + } + + pi := pendingInput{ + Input: input, + params: Params{ + // Inherit the deadline height from the input + // set. + DeadlineHeight: b.deadlineHeight, + }, + } + + b.addInput(pi) + + // Return if we've reached the minimum output amount. + if !b.NeedWalletInput() { + return nil + } + } + + // The wallet doesn't have enough utxos to cover the budget. Revert the + // input set to its original state. + b.inputs = originalInputs + + return ErrNotEnoughInputs +} + +// Budget returns the total budget of the set. +// +// NOTE: part of the InputSet interface. +func (b *BudgetInputSet) Budget() btcutil.Amount { + budget := btcutil.Amount(0) + for _, input := range b.inputs { + budget += input.params.Budget + } + + return budget +} + +// DeadlineHeight returns the deadline height of the set. +// +// NOTE: part of the InputSet interface. +func (b *BudgetInputSet) DeadlineHeight() fn.Option[int32] { + return b.deadlineHeight +} + +// Inputs returns the inputs that should be used to create a tx. +// +// NOTE: part of the InputSet interface. +func (b *BudgetInputSet) Inputs() []input.Input { + inputs := make([]input.Input, 0, len(b.inputs)) + for _, inp := range b.inputs { + inputs = append(inputs, inp.Input) + } + + return inputs +} + +// FeeRate returns the fee rate that should be used for the tx. +// +// NOTE: part of the InputSet interface. +// +// TODO(yy): will be removed once fee bumper is implemented. +func (b *BudgetInputSet) FeeRate() chainfee.SatPerKWeight { + return 0 +} diff --git a/sweep/tx_input_set_test.go b/sweep/tx_input_set_test.go index 51afff7b77..32a08fba4d 100644 --- a/sweep/tx_input_set_test.go +++ b/sweep/tx_input_set_test.go @@ -1,10 +1,14 @@ package sweep import ( + "errors" + "math" "testing" "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet" "github.com/stretchr/testify/require" @@ -237,3 +241,432 @@ func TestTxInputSetRequiredOutput(t *testing.T) { } require.True(t, set.enoughInput()) } + +// TestNewBudgetInputSet checks `NewBudgetInputSet` correctly validates the +// supplied inputs and returns the error. +func TestNewBudgetInputSet(t *testing.T) { + t.Parallel() + + rt := require.New(t) + + // Pass an empty slice and expect an error. + set, err := NewBudgetInputSet([]pendingInput{}) + rt.ErrorContains(err, "inputs slice is empty") + rt.Nil(set) + + // Create two inputs with different deadline heights. + inp0 := createP2WKHInput(1000) + inp1 := createP2WKHInput(1000) + inp2 := createP2WKHInput(1000) + input0 := pendingInput{ + Input: inp0, + params: Params{ + Budget: 100, + DeadlineHeight: fn.None[int32](), + }, + } + input1 := pendingInput{ + Input: inp1, + params: Params{ + Budget: 100, + DeadlineHeight: fn.Some(int32(1)), + }, + } + input2 := pendingInput{ + Input: inp2, + params: Params{ + Budget: 100, + DeadlineHeight: fn.Some(int32(2)), + }, + } + + // Pass a slice of inputs with different deadline heights. + set, err = NewBudgetInputSet([]pendingInput{input1, input2}) + rt.ErrorContains(err, "inputs have different deadline heights") + rt.Nil(set) + + // Pass a slice of inputs that only one input has the deadline height. + set, err = NewBudgetInputSet([]pendingInput{input0, input2}) + rt.NoError(err) + rt.NotNil(set) + + // Pass a slice of inputs that are duplicates. + set, err = NewBudgetInputSet([]pendingInput{input1, input1}) + rt.ErrorContains(err, "duplicate inputs") + rt.Nil(set) +} + +// TestBudgetInputSetAddInput checks that `addInput` correctly updates the +// budget of the input set. +func TestBudgetInputSetAddInput(t *testing.T) { + t.Parallel() + + // Create a testing input with a budget of 100 satoshis. + input := createP2WKHInput(1000) + pi := &pendingInput{ + Input: input, + params: Params{ + Budget: 100, + }, + } + + // Initialize an input set, which adds the above input. + set, err := NewBudgetInputSet([]pendingInput{*pi}) + require.NoError(t, err) + + // Add the input to the set again. + set.addInput(*pi) + + // The set should now have two inputs. + require.Len(t, set.inputs, 2) + require.Equal(t, pi, set.inputs[0]) + require.Equal(t, pi, set.inputs[1]) + + // The set should have a budget of 200 satoshis. + require.Equal(t, btcutil.Amount(200), set.Budget()) +} + +// TestNeedWalletInput checks that NeedWalletInput correctly determines if a +// wallet input is needed. +func TestNeedWalletInput(t *testing.T) { + t.Parallel() + + // Create a mock input that doesn't have required outputs. + mockInput := &input.MockInput{} + mockInput.On("RequiredTxOut").Return(nil) + defer mockInput.AssertExpectations(t) + + // Create a mock input that has required outputs. + mockInputRequireOutput := &input.MockInput{} + mockInputRequireOutput.On("RequiredTxOut").Return(&wire.TxOut{}) + defer mockInputRequireOutput.AssertExpectations(t) + + // We now create two pending inputs each has a budget of 100 satoshis. + const budget = 100 + + // Create the pending input that doesn't have a required output. + piBudget := &pendingInput{ + Input: mockInput, + params: Params{Budget: budget}, + } + + // Create the pending input that has a required output. + piRequireOutput := &pendingInput{ + Input: mockInputRequireOutput, + params: Params{Budget: budget}, + } + + testCases := []struct { + name string + setupInputs func() []*pendingInput + need bool + }{ + { + // When there are no pending inputs, we won't need a + // wallet input. Technically this should be an invalid + // state. + name: "no inputs", + setupInputs: func() []*pendingInput { + return nil + }, + need: false, + }, + { + // When there's no required output, we don't need a + // wallet input. + name: "no required outputs", + setupInputs: func() []*pendingInput { + // Create a sign descriptor to be used in the + // pending input when calculating budgets can + // be borrowed. + sd := &input.SignDescriptor{ + Output: &wire.TxOut{ + Value: budget, + }, + } + mockInput.On("SignDesc").Return(sd).Once() + + return []*pendingInput{piBudget} + }, + need: false, + }, + { + // When the output value cannot cover the budget, we + // need a wallet input. + name: "output value cannot cover budget", + setupInputs: func() []*pendingInput { + // Create a sign descriptor to be used in the + // pending input when calculating budgets can + // be borrowed. + sd := &input.SignDescriptor{ + Output: &wire.TxOut{ + Value: budget - 1, + }, + } + mockInput.On("SignDesc").Return(sd).Once() + + // These two methods are only invoked when the + // unit test is running with a logger. + mockInput.On("OutPoint").Return( + &wire.OutPoint{Hash: chainhash.Hash{1}}, + ).Maybe() + mockInput.On("WitnessType").Return( + input.CommitmentAnchor, + ).Maybe() + + return []*pendingInput{piBudget} + }, + need: true, + }, + { + // When there's only inputs that require outputs, we + // need wallet inputs. + name: "only required outputs", + setupInputs: func() []*pendingInput { + return []*pendingInput{piRequireOutput} + }, + need: true, + }, + { + // When there's a mix of inputs, but the borrowable + // budget cannot cover the required, we need a wallet + // input. + name: "not enough budget to be borrowed", + setupInputs: func() []*pendingInput { + // Create a sign descriptor to be used in the + // pending input when calculating budgets can + // be borrowed. + // + // NOTE: the value is exactly the same as the + // budget so we can't borrow any more. + sd := &input.SignDescriptor{ + Output: &wire.TxOut{ + Value: budget, + }, + } + mockInput.On("SignDesc").Return(sd).Once() + + return []*pendingInput{ + piBudget, piRequireOutput, + } + }, + need: true, + }, + { + // When there's a mix of inputs, and the budget can be + // borrowed covers the required, we don't need wallet + // inputs. + name: "enough budget to be borrowed", + setupInputs: func() []*pendingInput { + // Create a sign descriptor to be used in the + // pending input when calculating budgets can + // be borrowed. + // + // NOTE: the value is exactly the same as the + // budget so we can't borrow any more. + sd := &input.SignDescriptor{ + Output: &wire.TxOut{ + Value: budget * 2, + }, + } + mockInput.On("SignDesc").Return(sd).Once() + piBudget.Input = mockInput + + return []*pendingInput{ + piBudget, piRequireOutput, + } + }, + need: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Setup testing inputs. + inputs := tc.setupInputs() + + // Initialize an input set, which adds the testing + // inputs. + set := &BudgetInputSet{inputs: inputs} + + result := set.NeedWalletInput() + require.Equal(t, tc.need, result) + }) + } +} + +// TestAddWalletInputReturnErr tests the three possible errors returned from +// AddWalletInputs: +// - error from ListUnspentWitnessFromDefaultAccount. +// - error from createWalletTxInput. +// - error when wallet doesn't have utxos. +func TestAddWalletInputReturnErr(t *testing.T) { + t.Parallel() + + wallet := &MockWallet{} + defer wallet.AssertExpectations(t) + + // Initialize an empty input set. + set := &BudgetInputSet{} + + // Specify the min and max confs used in + // ListUnspentWitnessFromDefaultAccount. + min, max := int32(1), int32(math.MaxInt32) + + // Mock the wallet to return an error. + dummyErr := errors.New("dummy error") + wallet.On("ListUnspentWitnessFromDefaultAccount", + min, max).Return(nil, dummyErr).Once() + + // Check that the error is returned from + // ListUnspentWitnessFromDefaultAccount. + err := set.AddWalletInputs(wallet) + require.ErrorIs(t, err, dummyErr) + + // Create an utxo with unknown address type to trigger an error. + utxo := &lnwallet.Utxo{ + AddressType: lnwallet.UnknownAddressType, + } + + // Mock the wallet to return the above utxo. + wallet.On("ListUnspentWitnessFromDefaultAccount", + min, max).Return([]*lnwallet.Utxo{utxo}, nil).Once() + + // Check that the error is returned from createWalletTxInput. + err = set.AddWalletInputs(wallet) + require.Error(t, err) + + // Mock the wallet to return empty utxos. + wallet.On("ListUnspentWitnessFromDefaultAccount", + min, max).Return([]*lnwallet.Utxo{}, nil).Once() + + // Check that the error is returned from not having wallet inputs. + err = set.AddWalletInputs(wallet) + require.ErrorIs(t, err, ErrNotEnoughInputs) +} + +// TestAddWalletInputNotEnoughInputs checks that when there are not enough +// wallet utxos, an error is returned and the budget set is reset to its +// initial state. +func TestAddWalletInputNotEnoughInputs(t *testing.T) { + t.Parallel() + + wallet := &MockWallet{} + defer wallet.AssertExpectations(t) + + // Specify the min and max confs used in + // ListUnspentWitnessFromDefaultAccount. + min, max := int32(1), int32(math.MaxInt32) + + // Assume the desired budget is 10k satoshis. + const budget = 10_000 + + // Create a mock input that has required outputs. + mockInput := &input.MockInput{} + mockInput.On("RequiredTxOut").Return(&wire.TxOut{}) + defer mockInput.AssertExpectations(t) + + // Create a pending input that requires 10k satoshis. + pi := &pendingInput{ + Input: mockInput, + params: Params{Budget: budget}, + } + + // Create a wallet utxo that cannot cover the budget. + utxo := &lnwallet.Utxo{ + AddressType: lnwallet.WitnessPubKey, + Value: budget - 1, + } + + // Mock the wallet to return the above utxo. + wallet.On("ListUnspentWitnessFromDefaultAccount", + min, max).Return([]*lnwallet.Utxo{utxo}, nil).Once() + + // Initialize an input set with the pending input. + set := BudgetInputSet{inputs: []*pendingInput{pi}} + + // Add wallet inputs to the input set, which should give us an error as + // the wallet cannot cover the budget. + err := set.AddWalletInputs(wallet) + require.ErrorIs(t, err, ErrNotEnoughInputs) + + // Check that the budget set is reverted to its initial state. + require.Len(t, set.inputs, 1) + require.Equal(t, pi, set.inputs[0]) +} + +// TestAddWalletInputSuccess checks that when there are enough wallet utxos, +// they are added to the input set. +func TestAddWalletInputSuccess(t *testing.T) { + t.Parallel() + + wallet := &MockWallet{} + defer wallet.AssertExpectations(t) + + // Specify the min and max confs used in + // ListUnspentWitnessFromDefaultAccount. + min, max := int32(1), int32(math.MaxInt32) + + // Assume the desired budget is 10k satoshis. + const budget = 10_000 + + // Create a mock input that has required outputs. + mockInput := &input.MockInput{} + mockInput.On("RequiredTxOut").Return(&wire.TxOut{}) + defer mockInput.AssertExpectations(t) + + // Create a pending input that requires 10k satoshis. + deadline := int32(1000) + pi := &pendingInput{ + Input: mockInput, + params: Params{ + Budget: budget, + DeadlineHeight: fn.Some(deadline), + }, + } + + // Mock methods used in loggings. + // + // NOTE: these methods are not functional as they are only used for + // loggings in debug or trace mode so we use arbitrary values. + mockInput.On("OutPoint").Return(&wire.OutPoint{Hash: chainhash.Hash{1}}) + mockInput.On("WitnessType").Return(input.CommitmentAnchor) + + // Create a wallet utxo that cannot cover the budget. + utxo := &lnwallet.Utxo{ + AddressType: lnwallet.WitnessPubKey, + Value: budget - 1, + } + + // Mock the wallet to return the two utxos which can cover the budget. + wallet.On("ListUnspentWitnessFromDefaultAccount", + min, max).Return([]*lnwallet.Utxo{utxo, utxo}, nil).Once() + + // Initialize an input set with the pending input. + set, err := NewBudgetInputSet([]pendingInput{*pi}) + require.NoError(t, err) + + // Add wallet inputs to the input set, which should give us an error as + // the wallet cannot cover the budget. + err = set.AddWalletInputs(wallet) + require.NoError(t, err) + + // Check that the budget set is updated. + require.Len(t, set.inputs, 3) + + // The first input is the pending input. + require.Equal(t, pi, set.inputs[0]) + + // The second and third inputs are wallet inputs that have + // DeadlineHeight set. + input2Deadline := set.inputs[1].params.DeadlineHeight + require.Equal(t, deadline, input2Deadline.UnsafeFromSome()) + input3Deadline := set.inputs[2].params.DeadlineHeight + require.Equal(t, deadline, input3Deadline.UnsafeFromSome()) + + // Finally, check the interface methods. + require.EqualValues(t, budget, set.Budget()) + require.Equal(t, deadline, set.DeadlineHeight().UnsafeFromSome()) + // Weak check, a strong check is to open the slice and check each item. + require.Len(t, set.inputs, 3) +} From 9565c3b820afe3c078ed86e43b017280ccb20c1d Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Tue, 27 Feb 2024 21:54:48 +0800 Subject: [PATCH 05/19] sweep: introduce `BudgetAggregator` to cluster inputs by deadlines This commit adds `BudgetAggregator` as a new implementation of `UtxoAggregator`. This aggregator will group inputs by their deadline heights and create input sets that can be used directly by the fee bumper for fee calculations. --- input/mocks.go | 45 ++++ sweep/aggregator.go | 232 ++++++++++++++++++ sweep/aggregator_test.go | 510 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 787 insertions(+) diff --git a/input/mocks.go b/input/mocks.go index 965489effb..1fe6eb7656 100644 --- a/input/mocks.go +++ b/input/mocks.go @@ -123,3 +123,48 @@ func (m *MockInput) UnconfParent() *TxInfo { return info.(*TxInfo) } + +// MockWitnessType implements the `WitnessType` interface and is used by other +// packages for mock testing. +type MockWitnessType struct { + mock.Mock +} + +// Compile time assertion that MockWitnessType implements WitnessType. +var _ WitnessType = (*MockWitnessType)(nil) + +// String returns a human readable version of the WitnessType. +func (m *MockWitnessType) String() string { + args := m.Called() + + return args.String(0) +} + +// WitnessGenerator will return a WitnessGenerator function that an output uses +// to generate the witness and optionally the sigScript for a sweep +// transaction. +func (m *MockWitnessType) WitnessGenerator(signer Signer, + descriptor *SignDescriptor) WitnessGenerator { + + args := m.Called() + + return args.Get(0).(WitnessGenerator) +} + +// SizeUpperBound returns the maximum length of the witness of this WitnessType +// if it would be included in a tx. It also returns if the output itself is a +// nested p2sh output, if so then we need to take into account the extra +// sigScript data size. +func (m *MockWitnessType) SizeUpperBound() (int, bool, error) { + args := m.Called() + + return args.Int(0), args.Bool(1), args.Error(2) +} + +// AddWeightEstimation adds the estimated size of the witness in bytes to the +// given weight estimator. +func (m *MockWitnessType) AddWeightEstimation(e *TxWeightEstimator) error { + args := m.Called() + + return args.Error(0) +} diff --git a/sweep/aggregator.go b/sweep/aggregator.go index 379ff98296..58ac511320 100644 --- a/sweep/aggregator.go +++ b/sweep/aggregator.go @@ -3,7 +3,10 @@ package sweep import ( "sort" + "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" ) @@ -461,3 +464,232 @@ func zipClusters(as, bs []inputCluster) []inputCluster { return finalClusters } + +// BudgetAggregator is a budget-based aggregator that creates clusters based on +// deadlines and budgets of inputs. +type BudgetAggregator struct { + // estimator is used when crafting sweep transactions to estimate the + // necessary fee relative to the expected size of the sweep + // transaction. + estimator chainfee.Estimator + + // maxInputs specifies the maximum number of inputs allowed in a single + // sweep tx. + maxInputs uint32 +} + +// Compile-time constraint to ensure BudgetAggregator implements UtxoAggregator. +var _ UtxoAggregator = (*BudgetAggregator)(nil) + +// NewBudgetAggregator creates a new instance of a BudgetAggregator. +func NewBudgetAggregator(estimator chainfee.Estimator, + maxInputs uint32) *BudgetAggregator { + + return &BudgetAggregator{ + estimator: estimator, + maxInputs: maxInputs, + } +} + +// clusterGroup defines an alias for a set of inputs that are to be grouped. +type clusterGroup map[fn.Option[int32]][]pendingInput + +// ClusterInputs creates a list of input sets from pending inputs. +// 1. filter out inputs whose budget cannot cover min relay fee. +// 2. group the inputs into clusters based on their deadline height. +// 3. sort the inputs in each cluster by their budget. +// 4. optionally split a cluster if it exceeds the max input limit. +// 5. create input sets from each of the clusters. +func (b *BudgetAggregator) ClusterInputs(inputs pendingInputs) []InputSet { + // Filter out inputs that have a budget below min relay fee. + filteredInputs := b.filterInputs(inputs) + + // Create clusters to group inputs based on their deadline height. + clusters := make(clusterGroup, len(filteredInputs)) + + // Iterate all the inputs and group them based on their specified + // deadline heights. + for _, input := range filteredInputs { + height := input.params.DeadlineHeight + cluster, ok := clusters[height] + if !ok { + cluster = make([]pendingInput, 0) + } + + cluster = append(cluster, *input) + clusters[height] = cluster + } + + // Now that we have the clusters, we can create the input sets. + // + // NOTE: cannot pre-allocate the slice since we don't know the number + // of input sets in advance. + inputSets := make([]InputSet, 0) + for _, cluster := range clusters { + // Sort the inputs by their economical value. + sortedInputs := b.sortInputs(cluster) + + // Create input sets from the cluster. + sets := b.createInputSets(sortedInputs) + inputSets = append(inputSets, sets...) + } + + return inputSets +} + +// createInputSet takes a set of inputs which share the same deadline height +// and turns them into a list of `InputSet`, each set is then used to create a +// sweep transaction. +func (b *BudgetAggregator) createInputSets(inputs []pendingInput) []InputSet { + // sets holds the InputSets that we will return. + sets := make([]InputSet, 0) + + // Copy the inputs to a new slice so we can modify it. + remainingInputs := make([]pendingInput, len(inputs)) + copy(remainingInputs, inputs) + + // If the number of inputs is greater than the max inputs allowed, we + // will split them into smaller clusters. + for uint32(len(remainingInputs)) > b.maxInputs { + log.Tracef("Cluster has %v inputs, max is %v, dividing...", + len(inputs), b.maxInputs) + + // Copy the inputs to be put into the new set, and update the + // remaining inputs by removing currentInputs. + currentInputs := make([]pendingInput, b.maxInputs) + copy(currentInputs, remainingInputs[:b.maxInputs]) + remainingInputs = remainingInputs[b.maxInputs:] + + // Create an InputSet using the max allowed number of inputs. + set, err := NewBudgetInputSet(currentInputs) + if err != nil { + log.Errorf("unable to create input set: %v", err) + + continue + } + + sets = append(sets, set) + } + + // Create an InputSet from the remaining inputs. + if len(remainingInputs) > 0 { + set, err := NewBudgetInputSet(remainingInputs) + if err != nil { + log.Errorf("unable to create input set: %v", err) + return nil + } + + sets = append(sets, set) + } + + return sets +} + +// filterInputs filters out inputs that have a budget below the min relay fee +// or have a required output that's below the dust. +func (b *BudgetAggregator) filterInputs(inputs pendingInputs) pendingInputs { + // Get the current min relay fee for this round. + minFeeRate := b.estimator.RelayFeePerKW() + + // filterInputs stores a map of inputs that has a budget that at least + // can pay the minimal fee. + filteredInputs := make(pendingInputs, len(inputs)) + + // Iterate all the inputs and filter out the ones whose budget cannot + // cover the min fee. + for _, pi := range inputs { + op := pi.OutPoint() + + // Get the size and skip if there's an error. + size, _, err := pi.WitnessType().SizeUpperBound() + if err != nil { + log.Warnf("Skipped input=%v: cannot get its size: %v", + op, err) + + continue + } + + // Skip inputs that has too little budget. + minFee := minFeeRate.FeeForWeight(int64(size)) + if pi.params.Budget < minFee { + log.Warnf("Skipped input=%v: has budget=%v, but the "+ + "min fee requires %v", op, pi.params.Budget, + minFee) + + continue + } + + // If the input comes with a required tx out that is below + // dust, we won't add it. + // + // NOTE: only HtlcSecondLevelAnchorInput returns non-nil + // RequiredTxOut. + reqOut := pi.RequiredTxOut() + if reqOut != nil { + if isDustOutput(reqOut) { + log.Errorf("Rejected input=%v due to dust "+ + "required output=%v", op, reqOut.Value) + + continue + } + } + + filteredInputs[*op] = pi + } + + return filteredInputs +} + +// sortInputs sorts the inputs based on their economical value. +// +// NOTE: besides the forced inputs, the sorting won't make any difference +// because all the inputs are added to the same set. The exception is when the +// number of inputs exceeds the maxInputs limit, it requires us to split them +// into smaller clusters. In that case, the sorting will make a difference as +// the budgets of the clusters will be different. +func (b *BudgetAggregator) sortInputs(inputs []pendingInput) []pendingInput { + // sortedInputs is the final list of inputs sorted by their economical + // value. + sortedInputs := make([]pendingInput, 0, len(inputs)) + + // Copy the inputs. + sortedInputs = append(sortedInputs, inputs...) + + // Sort the inputs based on their budgets. + // + // NOTE: We can implement more sophisticated algorithm as the budget + // left is a function f(minFeeRate, size) = b1 - s1 * r > b2 - s2 * r, + // where b1 and b2 are budgets, s1 and s2 are sizes of the inputs. + sort.Slice(sortedInputs, func(i, j int) bool { + left := sortedInputs[i].params.Budget + right := sortedInputs[j].params.Budget + + // Make sure forced inputs are always put in the front. + leftForce := sortedInputs[i].params.Force + rightForce := sortedInputs[j].params.Force + + // If both are forced inputs, we return the one with the higher + // budget. If neither are forced inputs, we also return the one + // with the higher budget. + if leftForce == rightForce { + return left > right + } + + // Otherwise, it's either the left or the right is forced. We + // can simply return `leftForce` here as, if it's true, the + // left is forced and should be put in the front. Otherwise, + // the right is forced and should be put in the front. + return leftForce + }) + + return sortedInputs +} + +// isDustOutput checks if the given output is considered as dust. +func isDustOutput(output *wire.TxOut) bool { + // Fetch the dust limit for this output. + dustLimit := lnwallet.DustLimitForSize(len(output.PkScript)) + + // If the output is below the dust limit, we consider it dust. + return btcutil.Amount(output.Value) < dustLimit +} diff --git a/sweep/aggregator_test.go b/sweep/aggregator_test.go index 2058464ad2..bee1db5293 100644 --- a/sweep/aggregator_test.go +++ b/sweep/aggregator_test.go @@ -1,14 +1,17 @@ package sweep import ( + "bytes" "errors" "reflect" "sort" "testing" + "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/stretchr/testify/require" @@ -421,3 +424,510 @@ func TestClusterByLockTime(t *testing.T) { }) } } + +// TestBudgetAggregatorFilterInputs checks that inputs with low budget are +// filtered out. +func TestBudgetAggregatorFilterInputs(t *testing.T) { + t.Parallel() + + // Create a mock fee estimator. + estimator := &chainfee.MockEstimator{} + defer estimator.AssertExpectations(t) + + // Create a mock WitnessType that always return an error when trying to + // get its size upper bound. + wtErr := &input.MockWitnessType{} + defer wtErr.AssertExpectations(t) + + // Mock the `SizeUpperBound` method to return an error exactly once. + dummyErr := errors.New("dummy error") + wtErr.On("SizeUpperBound").Return(0, false, dummyErr).Once() + + // Create a mock WitnessType that gives the size. + wt := &input.MockWitnessType{} + defer wt.AssertExpectations(t) + + // Mock the `SizeUpperBound` method to return the size four times. + const wtSize = 100 + wt.On("SizeUpperBound").Return(wtSize, true, nil).Times(4) + + // Create a mock input that will be filtered out due to error. + inpErr := &input.MockInput{} + defer inpErr.AssertExpectations(t) + + // Mock the `WitnessType` method to return the erroring witness type. + inpErr.On("WitnessType").Return(wtErr).Once() + + // Mock the `OutPoint` method to return a unique outpoint. + opErr := wire.OutPoint{Hash: chainhash.Hash{1}} + inpErr.On("OutPoint").Return(&opErr).Once() + + // Mock the estimator to return a constant fee rate. + const minFeeRate = chainfee.SatPerKWeight(1000) + estimator.On("RelayFeePerKW").Return(minFeeRate).Once() + + var ( + // Define three budget values, one below the min fee rate, one + // above and one equal to it. + budgetLow = minFeeRate.FeeForWeight(wtSize) - 1 + budgetEqual = minFeeRate.FeeForWeight(wtSize) + budgetHigh = minFeeRate.FeeForWeight(wtSize) + 1 + + // Define three outpoints with different budget values. + opLow = wire.OutPoint{Hash: chainhash.Hash{2}} + opEqual = wire.OutPoint{Hash: chainhash.Hash{3}} + opHigh = wire.OutPoint{Hash: chainhash.Hash{4}} + + // Define an outpoint that has a dust required output. + opDust = wire.OutPoint{Hash: chainhash.Hash{5}} + ) + + // Create three mock inputs. + inpLow := &input.MockInput{} + defer inpLow.AssertExpectations(t) + + inpEqual := &input.MockInput{} + defer inpEqual.AssertExpectations(t) + + inpHigh := &input.MockInput{} + defer inpHigh.AssertExpectations(t) + + inpDust := &input.MockInput{} + defer inpDust.AssertExpectations(t) + + // Mock the `WitnessType` method to return the witness type. + inpLow.On("WitnessType").Return(wt) + inpEqual.On("WitnessType").Return(wt) + inpHigh.On("WitnessType").Return(wt) + inpDust.On("WitnessType").Return(wt) + + // Mock the `OutPoint` method to return the unique outpoint. + inpLow.On("OutPoint").Return(&opLow) + inpEqual.On("OutPoint").Return(&opEqual) + inpHigh.On("OutPoint").Return(&opHigh) + inpDust.On("OutPoint").Return(&opDust) + + // Mock the `RequiredTxOut` to return nil. + inpEqual.On("RequiredTxOut").Return(nil) + inpHigh.On("RequiredTxOut").Return(nil) + + // Mock the dust required output. + inpDust.On("RequiredTxOut").Return(&wire.TxOut{ + Value: 0, + PkScript: bytes.Repeat([]byte{0}, input.P2WSHSize), + }) + + // Create testing pending inputs. + inputs := pendingInputs{ + // The first input will be filtered out due to the error. + opErr: &pendingInput{ + Input: inpErr, + }, + + // The second input will be filtered out due to the budget. + opLow: &pendingInput{ + Input: inpLow, + params: Params{Budget: budgetLow}, + }, + + // The third input will be included. + opEqual: &pendingInput{ + Input: inpEqual, + params: Params{Budget: budgetEqual}, + }, + + // The fourth input will be included. + opHigh: &pendingInput{ + Input: inpHigh, + params: Params{Budget: budgetHigh}, + }, + + // The fifth input will be filtered out due to the dust + // required. + opDust: &pendingInput{ + Input: inpDust, + params: Params{Budget: budgetHigh}, + }, + } + + // Init the budget aggregator with the mocked estimator and zero max + // num of inputs. + b := NewBudgetAggregator(estimator, 0) + + // Call the method under test. + result := b.filterInputs(inputs) + + // Validate the expected inputs are returned. + require.Len(t, result, 2) + + // We expect only the inputs with budget equal or above the min fee to + // be included. + require.Contains(t, result, opEqual) + require.Contains(t, result, opHigh) +} + +// TestBudgetAggregatorSortInputs checks that inputs are sorted by based on +// their budgets and force flag. +func TestBudgetAggregatorSortInputs(t *testing.T) { + t.Parallel() + + var ( + // Create two budgets. + budgetLow = btcutil.Amount(1000) + budgetHight = budgetLow + btcutil.Amount(1000) + ) + + // Create an input with the low budget but forced. + inputLowForce := pendingInput{ + params: Params{ + Budget: budgetLow, + Force: true, + }, + } + + // Create an input with the low budget. + inputLow := pendingInput{ + params: Params{ + Budget: budgetLow, + }, + } + + // Create an input with the high budget and forced. + inputHighForce := pendingInput{ + params: Params{ + Budget: budgetHight, + Force: true, + }, + } + + // Create an input with the high budget. + inputHigh := pendingInput{ + params: Params{ + Budget: budgetHight, + }, + } + + // Create a testing pending inputs. + inputs := []pendingInput{ + inputLowForce, + inputLow, + inputHighForce, + inputHigh, + } + + // Init the budget aggregator with zero max num of inputs. + b := NewBudgetAggregator(nil, 0) + + // Call the method under test. + result := b.sortInputs(inputs) + require.Len(t, result, 4) + + // The first input should be the forced input with the high budget. + require.Equal(t, inputHighForce, result[0]) + + // The second input should be the forced input with the low budget. + require.Equal(t, inputLowForce, result[1]) + + // The third input should be the input with the high budget. + require.Equal(t, inputHigh, result[2]) + + // The fourth input should be the input with the low budget. + require.Equal(t, inputLow, result[3]) +} + +// TestBudgetAggregatorCreateInputSets checks that the budget aggregator +// creates input sets when the number of inputs exceeds the max number +// configed. +func TestBudgetAggregatorCreateInputSets(t *testing.T) { + t.Parallel() + + // Create mocks input that doesn't have required outputs. + mockInput1 := &input.MockInput{} + defer mockInput1.AssertExpectations(t) + mockInput2 := &input.MockInput{} + defer mockInput2.AssertExpectations(t) + mockInput3 := &input.MockInput{} + defer mockInput3.AssertExpectations(t) + mockInput4 := &input.MockInput{} + defer mockInput4.AssertExpectations(t) + + // Create testing pending inputs. + pi1 := pendingInput{ + Input: mockInput1, + params: Params{ + DeadlineHeight: fn.Some(int32(1)), + }, + } + pi2 := pendingInput{ + Input: mockInput2, + params: Params{ + DeadlineHeight: fn.Some(int32(1)), + }, + } + pi3 := pendingInput{ + Input: mockInput3, + params: Params{ + DeadlineHeight: fn.Some(int32(1)), + }, + } + pi4 := pendingInput{ + Input: mockInput4, + params: Params{ + // This input has a deadline height that is different + // from the other inputs. When grouped with other + // inputs, it will cause an error to be returned. + DeadlineHeight: fn.Some(int32(2)), + }, + } + + // Create a budget aggregator with max number of inputs set to 2. + b := NewBudgetAggregator(nil, 2) + + // Create test cases. + testCases := []struct { + name string + inputs []pendingInput + setupMock func() + expectedNumSets int + }{ + { + // When the number of inputs is below the max, a single + // input set is returned. + name: "num inputs below max", + inputs: []pendingInput{pi1}, + setupMock: func() { + // Mock methods used in loggings. + mockInput1.On("WitnessType").Return( + input.CommitmentAnchor) + mockInput1.On("OutPoint").Return( + &wire.OutPoint{Hash: chainhash.Hash{1}}) + }, + expectedNumSets: 1, + }, + { + // When the number of inputs is equal to the max, a + // single input set is returned. + name: "num inputs equal to max", + inputs: []pendingInput{pi1, pi2}, + setupMock: func() { + // Mock methods used in loggings. + mockInput1.On("WitnessType").Return( + input.CommitmentAnchor) + mockInput2.On("WitnessType").Return( + input.CommitmentAnchor) + + mockInput1.On("OutPoint").Return( + &wire.OutPoint{Hash: chainhash.Hash{1}}) + mockInput2.On("OutPoint").Return( + &wire.OutPoint{Hash: chainhash.Hash{2}}) + }, + expectedNumSets: 1, + }, + { + // When the number of inputs is above the max, multiple + // input sets are returned. + name: "num inputs above max", + inputs: []pendingInput{pi1, pi2, pi3}, + setupMock: func() { + // Mock methods used in loggings. + mockInput1.On("WitnessType").Return( + input.CommitmentAnchor) + mockInput2.On("WitnessType").Return( + input.CommitmentAnchor) + mockInput3.On("WitnessType").Return( + input.CommitmentAnchor) + + mockInput1.On("OutPoint").Return( + &wire.OutPoint{Hash: chainhash.Hash{1}}) + mockInput2.On("OutPoint").Return( + &wire.OutPoint{Hash: chainhash.Hash{2}}) + mockInput3.On("OutPoint").Return( + &wire.OutPoint{Hash: chainhash.Hash{3}}) + }, + expectedNumSets: 2, + }, + { + // When the number of inputs is above the max, but an + // error is returned from creating the first set, it + // shouldn't affect the remaining inputs. + name: "num inputs above max with error", + inputs: []pendingInput{pi1, pi4, pi3}, + setupMock: func() { + // Mock methods used in loggings. + mockInput1.On("WitnessType").Return( + input.CommitmentAnchor) + mockInput3.On("WitnessType").Return( + input.CommitmentAnchor) + + mockInput1.On("OutPoint").Return( + &wire.OutPoint{Hash: chainhash.Hash{1}}) + mockInput3.On("OutPoint").Return( + &wire.OutPoint{Hash: chainhash.Hash{3}}) + mockInput4.On("OutPoint").Return( + &wire.OutPoint{Hash: chainhash.Hash{2}}) + }, + expectedNumSets: 1, + }, + } + + // Iterate over the test cases. + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + // Setup the mocks. + tc.setupMock() + + // Call the method under test. + result := b.createInputSets(tc.inputs) + + // Validate the expected number of input sets are + // returned. + require.Len(t, result, tc.expectedNumSets) + }) + } +} + +// TestBudgetInputSetClusterInputs checks that the budget aggregator clusters +// inputs into input sets based on their deadline heights. +func TestBudgetInputSetClusterInputs(t *testing.T) { + t.Parallel() + + // Create a mock fee estimator. + estimator := &chainfee.MockEstimator{} + defer estimator.AssertExpectations(t) + + // Create a mock WitnessType that gives the size. + wt := &input.MockWitnessType{} + defer wt.AssertExpectations(t) + + // Mock the `SizeUpperBound` method to return the size six times since + // we are using nine inputs. + const wtSize = 100 + wt.On("SizeUpperBound").Return(wtSize, true, nil).Times(9) + wt.On("String").Return("mock witness type") + + // Mock the estimator to return a constant fee rate. + const minFeeRate = chainfee.SatPerKWeight(1000) + estimator.On("RelayFeePerKW").Return(minFeeRate).Once() + + var ( + // Define two budget values, one below the min fee rate and one + // above it. + budgetLow = minFeeRate.FeeForWeight(wtSize) - 1 + budgetHigh = minFeeRate.FeeForWeight(wtSize) + 1 + + // Create three deadline heights, which means there are three + // groups of inputs to be expected. + deadlineNone = fn.None[int32]() + deadline1 = fn.Some(int32(1)) + deadline2 = fn.Some(int32(2)) + ) + + // Create testing pending inputs. + inputs := make(pendingInputs) + + // For each deadline height, create two inputs with different budgets, + // one below the min fee rate and one above it. We should see the lower + // one being filtered out. + for i, deadline := range []fn.Option[int32]{ + deadlineNone, deadline1, deadline2, + } { + // Define three outpoints. + opLow := wire.OutPoint{ + Hash: chainhash.Hash{byte(i)}, + Index: uint32(i), + } + opHigh1 := wire.OutPoint{ + Hash: chainhash.Hash{byte(i + 1000)}, + Index: uint32(i + 1000), + } + opHigh2 := wire.OutPoint{ + Hash: chainhash.Hash{byte(i + 2000)}, + Index: uint32(i + 2000), + } + + // Create mock inputs. + inpLow := &input.MockInput{} + defer inpLow.AssertExpectations(t) + + inpHigh1 := &input.MockInput{} + defer inpHigh1.AssertExpectations(t) + + inpHigh2 := &input.MockInput{} + defer inpHigh2.AssertExpectations(t) + + // Mock the `OutPoint` method to return the unique outpoint. + // + // We expect the low budget input to call this method once in + // `filterInputs`. + inpLow.On("OutPoint").Return(&opLow).Once() + + // We expect the high budget input to call this method three + // times, one in `filterInputs` and one in `createInputSet`, + // and one in `NewBudgetInputSet`. + inpHigh1.On("OutPoint").Return(&opHigh1).Times(3) + inpHigh2.On("OutPoint").Return(&opHigh2).Times(3) + + // Mock the `WitnessType` method to return the witness type. + inpLow.On("WitnessType").Return(wt) + inpHigh1.On("WitnessType").Return(wt) + inpHigh2.On("WitnessType").Return(wt) + + // Mock the `RequiredTxOut` to return nil. + inpHigh1.On("RequiredTxOut").Return(nil) + inpHigh2.On("RequiredTxOut").Return(nil) + + // Add the low input, which should be filtered out. + inputs[opLow] = &pendingInput{ + Input: inpLow, + params: Params{ + Budget: budgetLow, + DeadlineHeight: deadline, + }, + } + + // Add the high inputs, which should be included. + inputs[opHigh1] = &pendingInput{ + Input: inpHigh1, + params: Params{ + Budget: budgetHigh, + DeadlineHeight: deadline, + }, + } + inputs[opHigh2] = &pendingInput{ + Input: inpHigh2, + params: Params{ + Budget: budgetHigh, + DeadlineHeight: deadline, + }, + } + } + + // Create a budget aggregator with a max number of inputs set to 100. + b := NewBudgetAggregator(estimator, DefaultMaxInputsPerTx) + + // Call the method under test. + result := b.ClusterInputs(inputs) + + // We expect three input sets to be returned, one for each deadline. + require.Len(t, result, 3) + + // Check each input set has exactly two inputs. + deadlines := make(map[fn.Option[int32]]struct{}) + for _, set := range result { + // We expect two inputs in each set. + require.Len(t, set.Inputs(), 2) + + // We expect each set to have the expected budget. + require.Equal(t, budgetHigh*2, set.Budget()) + + // Save the deadlines. + deadlines[set.DeadlineHeight()] = struct{}{} + } + + // We expect to see all three deadlines. + require.Contains(t, deadlines, deadlineNone) + require.Contains(t, deadlines, deadline1) + require.Contains(t, deadlines, deadline2) +} From 481216f5032d30af4a613c4eada04961f222e8f5 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Wed, 17 Jan 2024 17:21:09 +0800 Subject: [PATCH 06/19] sweep: introduce `Bumper` interface to handle RBF This commit adds a new interface, `Bumper`, to handle RBF for a given input set. It's responsible for creating the sweeping tx using the input set, and monitors its confirmation status to decide whether a RBF should be attempted or not. We leave implementation details to future commits, and focus on mounting this `Bumper` interface to our sweeper in this commit. --- sweep/fee_bumper.go | 142 ++++++++++++++ sweep/fee_bumper_test.go | 52 +++++ sweep/mock_test.go | 19 ++ sweep/sweeper.go | 253 +++++++++++++++++++++---- sweep/sweeper_test.go | 397 +++++++++++++++++++++++++++++++++++++-- 5 files changed, 812 insertions(+), 51 deletions(-) create mode 100644 sweep/fee_bumper.go create mode 100644 sweep/fee_bumper_test.go diff --git a/sweep/fee_bumper.go b/sweep/fee_bumper.go new file mode 100644 index 0000000000..c406149615 --- /dev/null +++ b/sweep/fee_bumper.go @@ -0,0 +1,142 @@ +package sweep + +import ( + "errors" + "fmt" + + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/lnwallet/chainfee" +) + +var ( + // ErrInvalidBumpResult is returned when the bump result is invalid. + ErrInvalidBumpResult = errors.New("invalid bump result") +) + +// Bumper defines an interface that can be used by other subsystems for fee +// bumping. +type Bumper interface { + // Broadcast is used to publish the tx created from the given inputs + // specified in the request. It handles the tx creation, broadcasts it, + // and monitors its confirmation status for potential fee bumping. It + // returns a chan that the caller can use to receive updates about the + // broadcast result and potential RBF attempts. + Broadcast(req *BumpRequest) (<-chan *BumpResult, error) +} + +// BumpEvent represents the event of a fee bumping attempt. +type BumpEvent uint8 + +const ( + // TxPublished is sent when the broadcast attempt is finished. + TxPublished BumpEvent = iota + + // TxFailed is sent when the broadcast attempt fails. + TxFailed + + // TxReplaced is sent when the original tx is replaced by a new one. + TxReplaced + + // TxConfirmed is sent when the tx is confirmed. + TxConfirmed + + // sentinalEvent is used to check if an event is unknown. + sentinalEvent +) + +// String returns a human-readable string for the event. +func (e BumpEvent) String() string { + switch e { + case TxPublished: + return "Published" + case TxFailed: + return "Failed" + case TxReplaced: + return "Replaced" + case TxConfirmed: + return "Confirmed" + default: + return "Unknown" + } +} + +// Unknown returns true if the event is unknown. +func (e BumpEvent) Unknown() bool { + return e >= sentinalEvent +} + +// BumpRequest is used by the caller to give the Bumper the necessary info to +// create and manage potential fee bumps for a set of inputs. +type BumpRequest struct { + // Budget givens the total amount that can be used as fees by these + // inputs. + Budget btcutil.Amount + + // Inputs is the set of inputs to sweep. + Inputs []input.Input + + // DeadlineHeight is the block height at which the tx should be + // confirmed. + DeadlineHeight int32 + + // DeliveryAddress is the script to send the change output to. + DeliveryAddress []byte + + // MaxFeeRate is the maximum fee rate that can be used for fee bumping. + MaxFeeRate chainfee.SatPerKWeight +} + +// BumpResult is used by the Bumper to send updates about the tx being +// broadcast. +type BumpResult struct { + // Event is the type of event that the result is for. + Event BumpEvent + + // Tx is the tx being broadcast. + Tx *wire.MsgTx + + // ReplacedTx is the old, replaced tx if a fee bump is attempted. + ReplacedTx *wire.MsgTx + + // FeeRate is the fee rate used for the new tx. + FeeRate chainfee.SatPerKWeight + + // Fee is the fee paid by the new tx. + Fee btcutil.Amount + + // Err is the error that occurred during the broadcast. + Err error +} + +// Validate validates the BumpResult so it's safe to use. +func (b *BumpResult) Validate() error { + // Every result must have a tx. + if b.Tx == nil { + return fmt.Errorf("%w: nil tx", ErrInvalidBumpResult) + } + + // Every result must have a known event. + if b.Event.Unknown() { + return fmt.Errorf("%w: unknown event", ErrInvalidBumpResult) + } + + // If it's a replacing event, it must have a replaced tx. + if b.Event == TxReplaced && b.ReplacedTx == nil { + return fmt.Errorf("%w: nil replacing tx", ErrInvalidBumpResult) + } + + // If it's a failed event, it must have an error. + if b.Event == TxFailed && b.Err == nil { + return fmt.Errorf("%w: nil error", ErrInvalidBumpResult) + } + + // If it's a confirmed event, it must have a fee rate and fee. + if b.Event == TxConfirmed && (b.FeeRate == 0 || b.Fee == 0) { + return fmt.Errorf("%w: missing fee rate or fee", + ErrInvalidBumpResult) + } + + return nil +} diff --git a/sweep/fee_bumper_test.go b/sweep/fee_bumper_test.go new file mode 100644 index 0000000000..22c247b2c5 --- /dev/null +++ b/sweep/fee_bumper_test.go @@ -0,0 +1,52 @@ +package sweep + +import ( + "testing" + + "github.com/btcsuite/btcd/wire" + "github.com/stretchr/testify/require" +) + +// TestBumpResultValidate tests the validate method of the BumpResult struct. +func TestBumpResultValidate(t *testing.T) { + t.Parallel() + + // An empty result will give an error. + b := BumpResult{} + require.ErrorIs(t, b.Validate(), ErrInvalidBumpResult) + + // Unknown event type will give an error. + b = BumpResult{ + Tx: &wire.MsgTx{}, + Event: sentinalEvent, + } + require.ErrorIs(t, b.Validate(), ErrInvalidBumpResult) + + // A replacing event without a new tx will give an error. + b = BumpResult{ + Tx: &wire.MsgTx{}, + Event: TxReplaced, + } + require.ErrorIs(t, b.Validate(), ErrInvalidBumpResult) + + // A failed event without a failure reason will give an error. + b = BumpResult{ + Tx: &wire.MsgTx{}, + Event: TxFailed, + } + require.ErrorIs(t, b.Validate(), ErrInvalidBumpResult) + + // A confirmed event without fee info will give an error. + b = BumpResult{ + Tx: &wire.MsgTx{}, + Event: TxConfirmed, + } + require.ErrorIs(t, b.Validate(), ErrInvalidBumpResult) + + // Test a valid result. + b = BumpResult{ + Tx: &wire.MsgTx{}, + Event: TxPublished, + } + require.NoError(t, b.Validate()) +} diff --git a/sweep/mock_test.go b/sweep/mock_test.go index f908cf6dbf..270c3844eb 100644 --- a/sweep/mock_test.go +++ b/sweep/mock_test.go @@ -462,3 +462,22 @@ func (m *MockInputSet) Budget() btcutil.Amount { return args.Get(0).(btcutil.Amount) } + +// MockBumper is a mock implementation of the interface Bumper. +type MockBumper struct { + mock.Mock +} + +// Compile-time constraint to ensure MockBumper implements Bumper. +var _ Bumper = (*MockBumper)(nil) + +// Broadcast broadcasts the transaction to the network. +func (m *MockBumper) Broadcast(req *BumpRequest) (<-chan *BumpResult, error) { + args := m.Called(req) + + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).(chan *BumpResult), args.Error(1) +} diff --git a/sweep/sweeper.go b/sweep/sweeper.go index 659c1d3408..505292912a 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -13,7 +13,6 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" - "github.com/lightningnetwork/lnd/labels" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" ) @@ -41,6 +40,12 @@ var ( // an input is included in a publish attempt before giving up and // returning an error to the caller. DefaultMaxSweepAttempts = 10 + + // DefaultDeadlineDelta defines a default deadline delta (1 week) to be + // used when sweeping inputs with no deadline pressure. + // + // TODO(yy): make this configurable. + DefaultDeadlineDelta = int32(1008) ) // Params contains the parameters that control the sweeping process. @@ -317,6 +322,10 @@ type UtxoSweeper struct { // currentHeight is the best known height of the main chain. This is // updated whenever a new block epoch is received. currentHeight int32 + + // bumpResultChan is a channel that receives broadcast results from the + // TxPublisher. + bumpResultChan chan *BumpResult } // UtxoSweeperConfig contains dependencies of UtxoSweeper. @@ -364,6 +373,10 @@ type UtxoSweeperConfig struct { // Aggregator is used to group inputs into clusters based on its // implemention-specific strategy. Aggregator UtxoAggregator + + // Publisher is used to publish the sweep tx crafted here and monitors + // it for potential fee bumps. + Publisher Bumper } // Result is the struct that is pushed through the result channel. Callers can @@ -397,6 +410,7 @@ func New(cfg *UtxoSweeperConfig) *UtxoSweeper { pendingSweepsReqs: make(chan *pendingSweepsReq), quit: make(chan struct{}), pendingInputs: make(pendingInputs), + bumpResultChan: make(chan *BumpResult, 100), } } @@ -670,11 +684,16 @@ func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) { err: err, } - // A new block comes in, update the bestHeight. - // - // TODO(yy): this is where we check our published transactions - // and perform RBF if needed. We'd also like to consult our fee - // bumper to get an updated fee rate. + case result := <-s.bumpResultChan: + // Handle the bump event. + err := s.handleBumpEvent(result) + if err != nil { + log.Errorf("Failed to handle bump event: %v", + err) + } + + // A new block comes in, update the bestHeight, perform a check + // over all pending inputs and publish sweeping txns if needed. case epoch, ok := <-blockEpochs: if !ok { // We should stop the sweeper before stopping @@ -779,8 +798,8 @@ func (s *UtxoSweeper) signalResult(pi *pendingInput, result Result) { } } -// sweep takes a set of preselected inputs, creates a sweep tx and publishes the -// tx. The output address is only marked as used if the publish succeeds. +// sweep takes a set of preselected inputs, creates a sweep tx and publishes +// the tx. The output address is only marked as used if the publish succeeds. func (s *UtxoSweeper) sweep(set InputSet) error { // Generate an output script if there isn't an unused script available. if s.currentOutputScript == nil { @@ -791,20 +810,21 @@ func (s *UtxoSweeper) sweep(set InputSet) error { s.currentOutputScript = pkScript } - // Create sweep tx. - tx, fee, err := createSweepTx( - set.Inputs(), nil, s.currentOutputScript, - uint32(s.currentHeight), set.FeeRate(), - s.cfg.MaxFeeRate.FeePerKWeight(), s.cfg.Signer, - ) - if err != nil { - return fmt.Errorf("create sweep tx: %w", err) - } - - tr := &TxRecord{ - Txid: tx.TxHash(), - FeeRate: uint64(set.FeeRate()), - Fee: uint64(fee), + // Create a default deadline height, and replace it with set's + // DeadlineHeight if it's set. + deadlineHeight := s.currentHeight + DefaultDeadlineDelta + deadlineHeight = set.DeadlineHeight().UnwrapOr(deadlineHeight) + + // Create a fee bump request and ask the publisher to broadcast it. The + // publisher will then take over and start monitoring the tx for + // potential fee bump. + req := &BumpRequest{ + Inputs: set.Inputs(), + Budget: set.Budget(), + DeadlineHeight: deadlineHeight, + DeliveryAddress: s.currentOutputScript, + MaxFeeRate: s.cfg.MaxFeeRate.FeePerKWeight(), + // TODO(yy): pass the strategy here. } // Reschedule the inputs that we just tried to sweep. This is done in @@ -812,13 +832,9 @@ func (s *UtxoSweeper) sweep(set InputSet) error { // publish attempts and rescue them in the next sweep. s.markInputsPendingPublish(set) - log.Debugf("Publishing sweep tx %v, num_inputs=%v, height=%v", - tx.TxHash(), len(tx.TxIn), s.currentHeight) - - // Publish the sweeping tx with customized label. - err = s.cfg.Wallet.PublishTransaction( - tx, labels.MakeLabel(labels.LabelTypeSweepTransaction, nil), - ) + // Broadcast will return a read-only chan that we will listen to for + // this publish result and future RBF attempt. + resp, err := s.cfg.Publisher.Broadcast(req) if err != nil { outpoints := make([]wire.OutPoint, len(set.Inputs())) for i, inp := range set.Inputs() { @@ -831,16 +847,11 @@ func (s *UtxoSweeper) sweep(set InputSet) error { return err } - // Inputs have been successfully published so we update their states. - err = s.markInputsPublished(tr, tx.TxIn) - if err != nil { - return err - } - - // If there's no error, remove the output script. Otherwise keep it so - // that it can be reused for the next transaction and causes no address - // inflation. - s.currentOutputScript = nil + // Successfully sent the broadcast attempt, we now handle the result by + // subscribing to the result chan and listen for future updates about + // this tx. + s.wg.Add(1) + go s.monitorFeeBumpResult(resp) return nil } @@ -1557,3 +1568,167 @@ func (s *UtxoSweeper) sweepPendingInputs(inputs pendingInputs) { } } } + +// monitorFeeBumpResult subscribes to the passed result chan to listen for +// future updates about the sweeping tx. +// +// NOTE: must run as a goroutine. +func (s *UtxoSweeper) monitorFeeBumpResult(resultChan <-chan *BumpResult) { + defer s.wg.Done() + + for { + select { + case r := <-resultChan: + // Validate the result is valid. + if err := r.Validate(); err != nil { + log.Errorf("Received invalid result: %v", err) + continue + } + + // Send the result back to the main event loop. + select { + case s.bumpResultChan <- r: + case <-s.quit: + log.Debug("Sweeper shutting down, skip " + + "sending bump result") + + return + } + + // The sweeping tx has been confirmed, we can exit the + // monitor now. + // + // TODO(yy): can instead remove the spend subscription + // in sweeper and rely solely on this event to mark + // inputs as Swept? + if r.Event == TxConfirmed || r.Event == TxFailed { + log.Debugf("Received %v for sweep tx %v, exit "+ + "fee bump monitor", r.Event, + r.Tx.TxHash()) + + return + } + + case <-s.quit: + log.Debugf("Sweeper shutting down, exit fee " + + "bump handler") + + return + } + } +} + +// handleBumpEventTxFailed handles the case where the tx has been failed to +// publish. +func (s *UtxoSweeper) handleBumpEventTxFailed(r *BumpResult) error { + tx, err := r.Tx, r.Err + + log.Errorf("Fee bump attempt failed for tx=%v: %v", tx.TxHash(), err) + + outpoints := make([]wire.OutPoint, 0, len(tx.TxIn)) + for _, inp := range tx.TxIn { + outpoints = append(outpoints, inp.PreviousOutPoint) + } + + // TODO(yy): should we also remove the failed tx from db? + s.markInputsPublishFailed(outpoints) + + return err +} + +// handleBumpEventTxReplaced handles the case where the sweeping tx has been +// replaced by a new one. +func (s *UtxoSweeper) handleBumpEventTxReplaced(r *BumpResult) error { + oldTx := r.ReplacedTx + newTx := r.Tx + + // Prepare a new record to replace the old one. + tr := &TxRecord{ + Txid: newTx.TxHash(), + FeeRate: uint64(r.FeeRate), + Fee: uint64(r.Fee), + } + + // Get the old record for logging purpose. + oldTxid := oldTx.TxHash() + record, err := s.cfg.Store.GetTx(oldTxid) + if err != nil { + log.Errorf("Fetch tx record for %v: %v", oldTxid, err) + return err + } + + log.Infof("RBFed tx=%v(fee=%v, feerate=%v) with new tx=%v(fee=%v, "+ + "feerate=%v)", record.Txid, record.Fee, record.FeeRate, + tr.Txid, tr.Fee, tr.FeeRate) + + // The old sweeping tx has been replaced by a new one, we will update + // the tx record in the sweeper db. + // + // TODO(yy): we may also need to update the inputs in this tx to a new + // state. Suppose a replacing tx only spends a subset of the inputs + // here, we'd end up with the rest being marked as `StatePublished` and + // won't be aggregated in the next sweep. Atm it's fine as we always + // RBF the same input set. + if err := s.cfg.Store.DeleteTx(oldTxid); err != nil { + log.Errorf("Delete tx record for %v: %v", oldTxid, err) + return err + } + + // Mark the inputs as published using the replacing tx. + return s.markInputsPublished(tr, r.Tx.TxIn) +} + +// handleBumpEventTxPublished handles the case where the sweeping tx has been +// successfully published. +func (s *UtxoSweeper) handleBumpEventTxPublished(r *BumpResult) error { + tx := r.Tx + tr := &TxRecord{ + Txid: tx.TxHash(), + FeeRate: uint64(r.FeeRate), + Fee: uint64(r.Fee), + } + + // Inputs have been successfully published so we update their + // states. + err := s.markInputsPublished(tr, tx.TxIn) + if err != nil { + return err + } + + log.Debugf("Published sweep tx %v, num_inputs=%v, height=%v", + tx.TxHash(), len(tx.TxIn), s.currentHeight) + + // If there's no error, remove the output script. Otherwise + // keep it so that it can be reused for the next transaction + // and causes no address inflation. + s.currentOutputScript = nil + + return nil +} + +// handleBumpEvent handles the result sent from the bumper based on its event +// type. +// +// NOTE: TxConfirmed event is not handled, since we already subscribe to the +// input's spending event, we don't need to do anything here. +func (s *UtxoSweeper) handleBumpEvent(r *BumpResult) error { + log.Debugf("Received bump event [%v] for tx %v", r.Event, r.Tx.TxHash()) + + switch r.Event { + // The tx has been published, we update the inputs' state and create a + // record to be stored in the sweeper db. + case TxPublished: + return s.handleBumpEventTxPublished(r) + + // The tx has failed, we update the inputs' state. + case TxFailed: + return s.handleBumpEventTxFailed(r) + + // The tx has been replaced, we will remove the old tx and replace it + // with the new one. + case TxReplaced: + return s.handleBumpEventTxReplaced(r) + } + + return nil +} diff --git a/sweep/sweeper_test.go b/sweep/sweeper_test.go index f6891db92a..db26a4e75f 100644 --- a/sweep/sweeper_test.go +++ b/sweep/sweeper_test.go @@ -33,6 +33,8 @@ var ( testMaxInputsPerTx = uint32(3) defaultFeePref = Params{Fee: FeeEstimateInfo{ConfTarget: 1}} + + errDummy = errors.New("dummy error") ) type sweeperTestContext struct { @@ -137,6 +139,12 @@ func createSweeperTestContext(t *testing.T) *sweeperTestContext { currentHeight: mockChainHeight, } + // Create a mock fee bumper. + mockBumper := &MockBumper{} + t.Cleanup(func() { + mockBumper.AssertExpectations(t) + }) + ctx.sweeper = New(&UtxoSweeperConfig{ Notifier: notifier, Wallet: backend, @@ -153,6 +161,7 @@ func createSweeperTestContext(t *testing.T) *sweeperTestContext { MaxSweepAttempts: testMaxSweepAttempts, MaxFeeRate: DefaultMaxFeeRate, Aggregator: aggregator, + Publisher: mockBumper, }) ctx.sweeper.Start() @@ -2410,16 +2419,27 @@ func TestSweepPendingInputs(t *testing.T) { // Create a mock wallet and aggregator. wallet := &MockWallet{} + defer wallet.AssertExpectations(t) + aggregator := &mockUtxoAggregator{} + defer aggregator.AssertExpectations(t) + + publisher := &MockBumper{} + defer publisher.AssertExpectations(t) // Create a test sweeper. s := New(&UtxoSweeperConfig{ Wallet: wallet, Aggregator: aggregator, + Publisher: publisher, + GenSweepScript: func() ([]byte, error) { + return testPubKey.SerializeCompressed(), nil + }, }) // Create an input set that needs wallet inputs. setNeedWallet := &MockInputSet{} + defer setNeedWallet.AssertExpectations(t) // Mock this set to ask for wallet input. setNeedWallet.On("NeedWalletInput").Return(true).Once() @@ -2430,15 +2450,18 @@ func TestSweepPendingInputs(t *testing.T) { // Create an input set that doesn't need wallet inputs. normalSet := &MockInputSet{} + defer normalSet.AssertExpectations(t) + normalSet.On("NeedWalletInput").Return(false).Once() // Mock the methods used in `sweep`. This is not important for this // unit test. - feeRate := chainfee.SatPerKWeight(1000) - setNeedWallet.On("Inputs").Return(nil).Once() - setNeedWallet.On("FeeRate").Return(feeRate).Once() - normalSet.On("Inputs").Return(nil).Once() - normalSet.On("FeeRate").Return(feeRate).Once() + setNeedWallet.On("Inputs").Return(nil).Times(4) + setNeedWallet.On("DeadlineHeight").Return(fn.None[int32]()).Once() + setNeedWallet.On("Budget").Return(btcutil.Amount(1)).Once() + normalSet.On("Inputs").Return(nil).Times(4) + normalSet.On("DeadlineHeight").Return(fn.None[int32]()).Once() + normalSet.On("Budget").Return(btcutil.Amount(1)).Once() // Make pending inputs for testing. We don't need real values here as // the returned clusters are mocked. @@ -2449,19 +2472,369 @@ func TestSweepPendingInputs(t *testing.T) { setNeedWallet, normalSet, }) - // Set change output script to an invalid value. This should cause the + // Mock `Broadcast` to return an error. This should cause the // `createSweepTx` inside `sweep` to fail. This is done so we can // terminate the method early as we are only interested in testing the // workflow in `sweepPendingInputs`. We don't need to test `sweep` here // as it should be tested in its own unit test. - s.currentOutputScript = []byte{1} + dummyErr := errors.New("dummy error") + publisher.On("Broadcast", mock.Anything).Return(nil, dummyErr).Twice() // Call the method under test. s.sweepPendingInputs(pis) +} + +// TestHandleBumpEventTxFailed checks that the sweeper correctly handles the +// case where the bump event tx fails to be published. +func TestHandleBumpEventTxFailed(t *testing.T) { + t.Parallel() + + // Create a test sweeper. + s := New(&UtxoSweeperConfig{}) + + var ( + // Create four testing outpoints. + op1 = wire.OutPoint{Hash: chainhash.Hash{1}} + op2 = wire.OutPoint{Hash: chainhash.Hash{2}} + op3 = wire.OutPoint{Hash: chainhash.Hash{3}} + opNotExist = wire.OutPoint{Hash: chainhash.Hash{4}} + ) + + // Create three mock inputs. + input1 := &input.MockInput{} + defer input1.AssertExpectations(t) + + input2 := &input.MockInput{} + defer input2.AssertExpectations(t) + + input3 := &input.MockInput{} + defer input3.AssertExpectations(t) + + // Construct the initial state for the sweeper. + s.pendingInputs = pendingInputs{ + op1: &pendingInput{Input: input1, state: StatePendingPublish}, + op2: &pendingInput{Input: input2, state: StatePendingPublish}, + op3: &pendingInput{Input: input3, state: StatePendingPublish}, + } + + // Create a testing tx that spends the first two inputs. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: op1}, + {PreviousOutPoint: op2}, + {PreviousOutPoint: opNotExist}, + }, + } + + // Create a testing bump result. + br := &BumpResult{ + Tx: tx, + Event: TxFailed, + Err: errDummy, + } + + // Call the method under test. + err := s.handleBumpEvent(br) + require.ErrorIs(t, err, errDummy) + + // Assert the states of the first two inputs are updated. + require.Equal(t, StatePublishFailed, s.pendingInputs[op1].state) + require.Equal(t, StatePublishFailed, s.pendingInputs[op2].state) + + // Assert the state of the third input is not updated. + require.Equal(t, StatePendingPublish, s.pendingInputs[op3].state) + + // Assert the non-existing input is not added to the pending inputs. + require.NotContains(t, s.pendingInputs, opNotExist) +} + +// TestHandleBumpEventTxReplaced checks that the sweeper correctly handles the +// case where the bump event tx is replaced. +func TestHandleBumpEventTxReplaced(t *testing.T) { + t.Parallel() + + // Create a mock store. + store := &MockSweeperStore{} + defer store.AssertExpectations(t) + + // Create a test sweeper. + s := New(&UtxoSweeperConfig{ + Store: store, + }) + + // Create a testing outpoint. + op := wire.OutPoint{Hash: chainhash.Hash{1}} + + // Create a mock input. + inp := &input.MockInput{} + defer inp.AssertExpectations(t) + + // Construct the initial state for the sweeper. + s.pendingInputs = pendingInputs{ + op: &pendingInput{Input: inp, state: StatePendingPublish}, + } + + // Create a testing tx that spends the input. + tx := &wire.MsgTx{ + LockTime: 1, + TxIn: []*wire.TxIn{ + {PreviousOutPoint: op}, + }, + } + + // Create a replacement tx. + replacementTx := &wire.MsgTx{ + LockTime: 2, + TxIn: []*wire.TxIn{ + {PreviousOutPoint: op}, + }, + } + + // Create a testing bump result. + br := &BumpResult{ + Tx: replacementTx, + ReplacedTx: tx, + Event: TxReplaced, + } + + // Mock the store to return an error. + dummyErr := errors.New("dummy error") + store.On("GetTx", tx.TxHash()).Return(nil, dummyErr).Once() + + // Call the method under test and assert the error is returned. + err := s.handleBumpEventTxReplaced(br) + require.ErrorIs(t, err, dummyErr) + + // Mock the store to return the old tx record. + store.On("GetTx", tx.TxHash()).Return(&TxRecord{ + Txid: tx.TxHash(), + }, nil).Once() + + // Mock an error returned when deleting the old tx record. + store.On("DeleteTx", tx.TxHash()).Return(dummyErr).Once() + + // Call the method under test and assert the error is returned. + err = s.handleBumpEventTxReplaced(br) + require.ErrorIs(t, err, dummyErr) + + // Mock the store to return the old tx record and delete it without + // error. + store.On("GetTx", tx.TxHash()).Return(&TxRecord{ + Txid: tx.TxHash(), + }, nil).Once() + store.On("DeleteTx", tx.TxHash()).Return(nil).Once() + + // Mock the store to save the new tx record. + store.On("StoreTx", &TxRecord{ + Txid: replacementTx.TxHash(), + Published: true, + }).Return(nil).Once() + + // Call the method under test. + err = s.handleBumpEventTxReplaced(br) + require.NoError(t, err) + + // Assert the state of the input is updated. + require.Equal(t, StatePublished, s.pendingInputs[op].state) +} - // Assert mocked methods are called as expected. - wallet.AssertExpectations(t) - aggregator.AssertExpectations(t) - setNeedWallet.AssertExpectations(t) - normalSet.AssertExpectations(t) +// TestHandleBumpEventTxPublished checks that the sweeper correctly handles the +// case where the bump event tx is published. +func TestHandleBumpEventTxPublished(t *testing.T) { + t.Parallel() + + // Create a mock store. + store := &MockSweeperStore{} + defer store.AssertExpectations(t) + + // Create a test sweeper. + s := New(&UtxoSweeperConfig{ + Store: store, + }) + + // Create a testing outpoint. + op := wire.OutPoint{Hash: chainhash.Hash{1}} + + // Create a mock input. + inp := &input.MockInput{} + defer inp.AssertExpectations(t) + + // Construct the initial state for the sweeper. + s.pendingInputs = pendingInputs{ + op: &pendingInput{Input: inp, state: StatePendingPublish}, + } + + // Create a testing tx that spends the input. + tx := &wire.MsgTx{ + LockTime: 1, + TxIn: []*wire.TxIn{ + {PreviousOutPoint: op}, + }, + } + + // Create a testing bump result. + br := &BumpResult{ + Tx: tx, + Event: TxPublished, + } + + // Mock the store to save the new tx record. + store.On("StoreTx", &TxRecord{ + Txid: tx.TxHash(), + Published: true, + }).Return(nil).Once() + + // Call the method under test. + err := s.handleBumpEventTxPublished(br) + require.NoError(t, err) + + // Assert the state of the input is updated. + require.Equal(t, StatePublished, s.pendingInputs[op].state) +} + +// TestMonitorFeeBumpResult checks that the fee bump monitor loop correctly +// exits when the sweeper is stopped, the tx is confirmed or failed. +func TestMonitorFeeBumpResult(t *testing.T) { + // Create a mock store. + store := &MockSweeperStore{} + defer store.AssertExpectations(t) + + // Create a test sweeper. + s := New(&UtxoSweeperConfig{ + Store: store, + }) + + // Create a testing outpoint. + op := wire.OutPoint{Hash: chainhash.Hash{1}} + + // Create a mock input. + inp := &input.MockInput{} + defer inp.AssertExpectations(t) + + // Construct the initial state for the sweeper. + s.pendingInputs = pendingInputs{ + op: &pendingInput{Input: inp, state: StatePendingPublish}, + } + + // Create a testing tx that spends the input. + tx := &wire.MsgTx{ + LockTime: 1, + TxIn: []*wire.TxIn{ + {PreviousOutPoint: op}, + }, + } + + testCases := []struct { + name string + setupResultChan func() <-chan *BumpResult + shouldExit bool + }{ + { + // When a tx confirmed event is received, we expect to + // exit the monitor loop. + name: "tx confirmed", + // We send a result with TxConfirmed event to the + // result channel. + setupResultChan: func() <-chan *BumpResult { + // Create a result chan. + resultChan := make(chan *BumpResult, 1) + resultChan <- &BumpResult{ + Tx: tx, + Event: TxConfirmed, + Fee: 10000, + FeeRate: 100, + } + + return resultChan + }, + shouldExit: true, + }, + { + // When a tx failed event is received, we expect to + // exit the monitor loop. + name: "tx failed", + // We send a result with TxConfirmed event to the + // result channel. + setupResultChan: func() <-chan *BumpResult { + // Create a result chan. + resultChan := make(chan *BumpResult, 1) + resultChan <- &BumpResult{ + Tx: tx, + Event: TxFailed, + Err: errDummy, + } + + return resultChan + }, + shouldExit: true, + }, + { + // When processing non-confirmed events, the monitor + // should not exit. + name: "no exit on normal event", + // We send a result with TxPublished and mock the + // method `StoreTx` to return nil. + setupResultChan: func() <-chan *BumpResult { + // Create a result chan. + resultChan := make(chan *BumpResult, 1) + resultChan <- &BumpResult{ + Tx: tx, + Event: TxPublished, + } + + return resultChan + }, + shouldExit: false, + }, { + // When the sweeper is shutting down, the monitor loop + // should exit. + name: "exit on sweeper shutdown", + // We don't send anything but quit the sweeper. + setupResultChan: func() <-chan *BumpResult { + close(s.quit) + + return nil + }, + shouldExit: true, + }, + } + + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + // Setup the testing result channel. + resultChan := tc.setupResultChan() + + // Create a done chan that's used to signal the monitor + // has exited. + done := make(chan struct{}) + + s.wg.Add(1) + go func() { + s.monitorFeeBumpResult(resultChan) + close(done) + }() + + // The monitor is expected to exit, we check it's done + // in one second or fail. + if tc.shouldExit { + select { + case <-done: + case <-time.After(1 * time.Second): + require.Fail(t, "monitor not exited") + } + + return + } + + // The monitor should not exit, check it doesn't close + // the `done` channel within one second. + select { + case <-done: + require.Fail(t, "monitor exited") + case <-time.After(1 * time.Second): + } + }) + } } From 818d1c2cb05bdcfcb0b58d733fd9ea42aadce918 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Wed, 28 Feb 2024 23:00:43 +0800 Subject: [PATCH 07/19] sweeper: fix existing sweeper tests --- sweep/sweeper_test.go | 1250 ++++++++++++++++++++++++++++++++++------- 1 file changed, 1032 insertions(+), 218 deletions(-) diff --git a/sweep/sweeper_test.go b/sweep/sweeper_test.go index db26a4e75f..f18cc49177 100644 --- a/sweep/sweeper_test.go +++ b/sweep/sweeper_test.go @@ -2,8 +2,10 @@ package sweep import ( "errors" + "fmt" "os" "runtime/pprof" + "sync/atomic" "testing" "time" @@ -19,6 +21,7 @@ import ( "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" lnmock "github.com/lightningnetwork/lnd/lntest/mock" + "github.com/lightningnetwork/lnd/lntest/wait" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/stretchr/testify/mock" @@ -45,6 +48,7 @@ type sweeperTestContext struct { estimator *mockFeeEstimator backend *mockBackend store SweeperStore + publisher *MockBumper publishChan chan wire.MsgTx currentHeight int32 @@ -52,7 +56,7 @@ type sweeperTestContext struct { var ( spendableInputs []*input.BaseInput - testInputCount int + testInputCount atomic.Uint64 testPubKey, _ = btcec.ParsePubKey([]byte{ 0x04, 0x11, 0xdb, 0x93, 0xe1, 0xdc, 0xdb, 0x8a, @@ -69,7 +73,7 @@ var ( func createTestInput(value int64, witnessType input.WitnessType) input.BaseInput { hash := chainhash.Hash{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - byte(testInputCount + 1)} + byte(testInputCount.Add(1))} input := input.MakeBaseInput( &wire.OutPoint{ @@ -88,8 +92,6 @@ func createTestInput(value int64, witnessType input.WitnessType) input.BaseInput nil, ) - testInputCount++ - return input } @@ -129,6 +131,12 @@ func createSweeperTestContext(t *testing.T) *sweeperTestContext { testMaxInputsPerTx, ) + // Create a mock fee bumper. + mockBumper := &MockBumper{} + t.Cleanup(func() { + mockBumper.AssertExpectations(t) + }) + ctx := &sweeperTestContext{ notifier: notifier, publishChan: backend.publishChan, @@ -137,14 +145,9 @@ func createSweeperTestContext(t *testing.T) *sweeperTestContext { backend: backend, store: store, currentHeight: mockChainHeight, + publisher: mockBumper, } - // Create a mock fee bumper. - mockBumper := &MockBumper{} - t.Cleanup(func() { - mockBumper.AssertExpectations(t) - }) - ctx.sweeper = New(&UtxoSweeperConfig{ Notifier: notifier, Wallet: backend, @@ -347,27 +350,80 @@ func assertTxFeeRate(t *testing.T, tx *wire.MsgTx, } } +// assertNumSweeps asserts that the expected number of sweeps has been found in +// the sweeper's store. +func assertNumSweeps(t *testing.T, sweeper *UtxoSweeper, num int) { + err := wait.NoError(func() error { + sweeps, err := sweeper.ListSweeps() + if err != nil { + return err + } + + if len(sweeps) != num { + return fmt.Errorf("want %d sweeps, got %d", + num, len(sweeps)) + } + + return nil + }, 5*time.Second) + require.NoError(t, err, "timeout checking num of sweeps") +} + // TestSuccess tests the sweeper happy flow. func TestSuccess(t *testing.T) { ctx := createSweeperTestContext(t) + inp := spendableInputs[0] + // Sweeping an input without a fee preference should result in an error. - _, err := ctx.sweeper.SweepInput(spendableInputs[0], Params{ + _, err := ctx.sweeper.SweepInput(inp, Params{ Fee: &FeeEstimateInfo{}, }) - if err != ErrNoFeePreference { - t.Fatalf("expected ErrNoFeePreference, got %v", err) - } + require.ErrorIs(t, err, ErrNoFeePreference) + + // Mock the Broadcast method to succeed. + bumpResultChan := make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan, nil).Run(func(args mock.Arguments) { + // Create a fake sweep tx. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{{ + PreviousOutPoint: *inp.OutPoint(), + }}, + } - resultChan, err := ctx.sweeper.SweepInput( - spendableInputs[0], defaultFeePref, - ) - if err != nil { - t.Fatal(err) - } + // Send the first event. + bumpResultChan <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }) + + resultChan, err := ctx.sweeper.SweepInput(inp, defaultFeePref) + require.NoError(t, err) sweepTx := ctx.receiveTx() + // Wait until the sweep tx has been saved to db. + assertNumSweeps(t, ctx.sweeper, 1) + + // Mock a confirmed event. + bumpResultChan <- &BumpResult{ + Event: TxConfirmed, + Tx: &sweepTx, + FeeRate: 10, + Fee: 100, + } + + // Mine a block to confirm the sweep tx. ctx.backend.mine() select { @@ -411,15 +467,52 @@ func TestDust(t *testing.T) { // Sweep another input that brings the tx output above the dust limit. largeInput := createTestInput(100000, input.CommitmentTimeLock) + // Mock the Broadcast method to succeed. + bumpResultChan := make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan, nil).Run(func(args mock.Arguments) { + // Create a fake sweep tx. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: *largeInput.OutPoint()}, + {PreviousOutPoint: *dustInput.OutPoint()}, + }, + } + + // Send the first event. + bumpResultChan <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }) + _, err = ctx.sweeper.SweepInput(&largeInput, defaultFeePref) require.NoError(t, err) // The second input brings the sweep output above the dust limit. We // expect a sweep tx now. - sweepTx := ctx.receiveTx() require.Len(t, sweepTx.TxIn, 2, "unexpected num of tx inputs") + // Wait until the sweep tx has been saved to db. + assertNumSweeps(t, ctx.sweeper, 1) + + // Mock a confirmed event. + bumpResultChan <- &BumpResult{ + Event: TxConfirmed, + Tx: &sweepTx, + FeeRate: 10, + Fee: 100, + } + ctx.backend.mine() ctx.finish(1) @@ -442,29 +535,53 @@ func TestWalletUtxo(t *testing.T) { // sats. The tx yield becomes then 294-180 = 114 sats. dustInput := createTestInput(294, input.WitnessKeyHash) + // Mock the Broadcast method to succeed. + bumpResultChan := make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan, nil).Run(func(args mock.Arguments) { + // Create a fake sweep tx. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: *dustInput.OutPoint()}, + }, + } + + // Send the first event. + bumpResultChan <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }) + _, err := ctx.sweeper.SweepInput( &dustInput, Params{Fee: FeeEstimateInfo{FeeRate: chainfee.FeePerKwFloor}}, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) sweepTx := ctx.receiveTx() - if len(sweepTx.TxIn) != 2 { - t.Fatalf("Expected tx to sweep 2 inputs, but contains %v "+ - "inputs instead", len(sweepTx.TxIn)) - } - // Calculate expected output value based on wallet utxo of 1_000_000 - // sats. - expectedOutputValue := int64(294 + 1_000_000 - 180) - if sweepTx.TxOut[0].Value != expectedOutputValue { - t.Fatalf("Expected output value of %v, but got %v", - expectedOutputValue, sweepTx.TxOut[0].Value) - } + // Wait until the sweep tx has been saved to db. + assertNumSweeps(t, ctx.sweeper, 1) ctx.backend.mine() + + // Mock a confirmed event. + bumpResultChan <- &BumpResult{ + Event: TxConfirmed, + Tx: &sweepTx, + FeeRate: 10, + Fee: 100, + } + ctx.finish(1) } @@ -479,28 +596,50 @@ func TestNegativeInput(t *testing.T) { largeInputResult, err := ctx.sweeper.SweepInput( &largeInput, defaultFeePref, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // Sweep an additional input with a negative net yield. The weight of // the HtlcAcceptedRemoteSuccess input type adds more in fees than its // value at the current fee level. negInput := createTestInput(2900, input.HtlcOfferedRemoteTimeout) negInputResult, err := ctx.sweeper.SweepInput(&negInput, defaultFeePref) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // Sweep a third input that has a smaller output than the previous one, // but yields positively because of its lower weight. positiveInput := createTestInput(2800, input.CommitmentNoDelay) + + // Mock the Broadcast method to succeed. + bumpResultChan := make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan, nil).Run(func(args mock.Arguments) { + // Create a fake sweep tx. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: *largeInput.OutPoint()}, + {PreviousOutPoint: *positiveInput.OutPoint()}, + }, + } + + // Send the first event. + bumpResultChan <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }).Once() + positiveInputResult, err := ctx.sweeper.SweepInput( &positiveInput, defaultFeePref, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // We expect that a sweep tx is published now, but it should only // contain the large input. The negative input should stay out of sweeps @@ -508,8 +647,19 @@ func TestNegativeInput(t *testing.T) { sweepTx1 := ctx.receiveTx() assertTxSweepsInputs(t, &sweepTx1, &largeInput, &positiveInput) + // Wait until the sweep tx has been saved to db. + assertNumSweeps(t, ctx.sweeper, 1) + ctx.backend.mine() + // Mock a confirmed event. + bumpResultChan <- &BumpResult{ + Event: TxConfirmed, + Tx: &sweepTx1, + FeeRate: 10, + Fee: 100, + } + ctx.expectResult(largeInputResult, nil) ctx.expectResult(positiveInputResult, nil) @@ -518,18 +668,56 @@ func TestNegativeInput(t *testing.T) { // Create another large input. secondLargeInput := createTestInput(100000, input.CommitmentNoDelay) + + // Mock the Broadcast method to succeed. + bumpResultChan = make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan, nil).Run(func(args mock.Arguments) { + // Create a fake sweep tx. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: *negInput.OutPoint()}, + {PreviousOutPoint: *secondLargeInput. + OutPoint()}, + }, + } + + // Send the first event. + bumpResultChan <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }).Once() + secondLargeInputResult, err := ctx.sweeper.SweepInput( &secondLargeInput, defaultFeePref, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) sweepTx2 := ctx.receiveTx() assertTxSweepsInputs(t, &sweepTx2, &secondLargeInput, &negInput) + // Wait until the sweep tx has been saved to db. + assertNumSweeps(t, ctx.sweeper, 2) + ctx.backend.mine() + // Mock a confirmed event. + bumpResultChan <- &BumpResult{ + Event: TxConfirmed, + Tx: &sweepTx2, + FeeRate: 10, + Fee: 100, + } + ctx.expectResult(secondLargeInputResult, nil) ctx.expectResult(negInputResult, nil) @@ -540,30 +728,96 @@ func TestNegativeInput(t *testing.T) { func TestChunks(t *testing.T) { ctx := createSweeperTestContext(t) + // Mock the Broadcast method to succeed on the first chunk. + bumpResultChan1 := make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan1, nil).Run(func(args mock.Arguments) { + // Create a fake sweep tx. + //nolint:lll + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: *spendableInputs[0].OutPoint()}, + {PreviousOutPoint: *spendableInputs[1].OutPoint()}, + {PreviousOutPoint: *spendableInputs[2].OutPoint()}, + }, + } + + // Send the first event. + bumpResultChan1 <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }).Once() + + // Mock the Broadcast method to succeed on the second chunk. + bumpResultChan2 := make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan2, nil).Run(func(args mock.Arguments) { + // Create a fake sweep tx. + //nolint:lll + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: *spendableInputs[3].OutPoint()}, + {PreviousOutPoint: *spendableInputs[4].OutPoint()}, + }, + } + + // Send the first event. + bumpResultChan2 <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }).Once() + // Sweep five inputs. for _, input := range spendableInputs[:5] { _, err := ctx.sweeper.SweepInput(input, defaultFeePref) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) } // We expect two txes to be published because of the max input count of // three. sweepTx1 := ctx.receiveTx() - if len(sweepTx1.TxIn) != 3 { - t.Fatalf("Expected first tx to sweep 3 inputs, but contains %v "+ - "inputs instead", len(sweepTx1.TxIn)) - } + require.Len(t, sweepTx1.TxIn, 3) sweepTx2 := ctx.receiveTx() - if len(sweepTx2.TxIn) != 2 { - t.Fatalf("Expected first tx to sweep 2 inputs, but contains %v "+ - "inputs instead", len(sweepTx1.TxIn)) - } + require.Len(t, sweepTx2.TxIn, 2) + + // Wait until the sweep tx has been saved to db. + assertNumSweeps(t, ctx.sweeper, 2) ctx.backend.mine() + // Mock a confirmed event. + bumpResultChan1 <- &BumpResult{ + Event: TxConfirmed, + Tx: &sweepTx1, + FeeRate: 10, + Fee: 100, + } + bumpResultChan2 <- &BumpResult{ + Event: TxConfirmed, + Tx: &sweepTx2, + FeeRate: 10, + Fee: 100, + } + ctx.finish(1) } @@ -581,39 +835,60 @@ func TestRemoteSpend(t *testing.T) { func testRemoteSpend(t *testing.T, postSweep bool) { ctx := createSweeperTestContext(t) + // Create a fake sweep tx that spends the second input as the first + // will be spent by the remote. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: *spendableInputs[1].OutPoint()}, + }, + } + + // Mock the Broadcast method to succeed. + bumpResultChan := make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan, nil).Run(func(args mock.Arguments) { + // Send the first event. + bumpResultChan <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }).Once() + resultChan1, err := ctx.sweeper.SweepInput( spendableInputs[0], defaultFeePref, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) resultChan2, err := ctx.sweeper.SweepInput( spendableInputs[1], defaultFeePref, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // Spend the input with an unknown tx. remoteTx := &wire.MsgTx{ TxIn: []*wire.TxIn{ - { - PreviousOutPoint: *(spendableInputs[0].OutPoint()), - }, + {PreviousOutPoint: *(spendableInputs[0].OutPoint())}, }, } err = ctx.backend.publishTransaction(remoteTx) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) if postSweep { - // Tx publication by sweeper returns ErrDoubleSpend. Sweeper // will retry the inputs without reporting a result. It could be // spent by the remote party. ctx.receiveTx() + + // Wait until the sweep tx has been saved to db. + assertNumSweeps(t, ctx.sweeper, 1) } ctx.backend.mine() @@ -633,13 +908,21 @@ func testRemoteSpend(t *testing.T, postSweep bool) { if !postSweep { // Assert that the sweeper sweeps the remaining input. sweepTx := ctx.receiveTx() + require.Len(t, sweepTx.TxIn, 1) - if len(sweepTx.TxIn) != 1 { - t.Fatal("expected sweep to only sweep the one remaining output") - } + // Wait until the sweep tx has been saved to db. + assertNumSweeps(t, ctx.sweeper, 1) ctx.backend.mine() + // Mock a confirmed event. + bumpResultChan <- &BumpResult{ + Event: TxConfirmed, + Tx: &sweepTx, + FeeRate: 10, + Fee: 100, + } + ctx.expectResult(resultChan2, nil) ctx.finish(1) @@ -649,8 +932,10 @@ func testRemoteSpend(t *testing.T, postSweep bool) { ctx.finish(2) select { - case <-resultChan2: - t.Fatalf("no result expected for error input") + case r := <-resultChan2: + require.NoError(t, r.Err) + require.Equal(t, r.Tx.TxHash(), tx.TxHash()) + default: } } @@ -662,26 +947,58 @@ func TestIdempotency(t *testing.T) { ctx := createSweeperTestContext(t) input := spendableInputs[0] + + // Mock the Broadcast method to succeed. + bumpResultChan := make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan, nil).Run(func(args mock.Arguments) { + // Create a fake sweep tx. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: *input.OutPoint()}, + }, + } + + // Send the first event. + bumpResultChan <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }).Once() + resultChan1, err := ctx.sweeper.SweepInput(input, defaultFeePref) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) resultChan2, err := ctx.sweeper.SweepInput(input, defaultFeePref) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) + + sweepTx := ctx.receiveTx() - ctx.receiveTx() + // Wait until the sweep tx has been saved to db. + assertNumSweeps(t, ctx.sweeper, 1) resultChan3, err := ctx.sweeper.SweepInput(input, defaultFeePref) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // Spend the input of the sweep tx. ctx.backend.mine() + // Mock a confirmed event. + bumpResultChan <- &BumpResult{ + Event: TxConfirmed, + Tx: &sweepTx, + FeeRate: 10, + Fee: 100, + } + ctx.expectResult(resultChan1, nil) ctx.expectResult(resultChan2, nil) ctx.expectResult(resultChan3, nil) @@ -692,9 +1009,7 @@ func TestIdempotency(t *testing.T) { // Because the sweeper kept track of all of its sweep txes, it will // recognize the spend as its own. resultChan4, err := ctx.sweeper.SweepInput(input, defaultFeePref) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) ctx.expectResult(resultChan4, nil) // Timer is still running, but spend notification was delivered before @@ -717,25 +1032,78 @@ func TestRestart(t *testing.T) { // Sweep input and expect sweep tx. input1 := spendableInputs[0] + + // Mock the Broadcast method to succeed. + bumpResultChan1 := make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan1, nil).Run(func(args mock.Arguments) { + // Create a fake sweep tx. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: *input1.OutPoint()}, + }, + } + + // Send the first event. + bumpResultChan1 <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }).Once() + _, err := ctx.sweeper.SweepInput(input1, defaultFeePref) require.NoError(t, err) - ctx.receiveTx() + sweepTx1 := ctx.receiveTx() + + // Wait until the sweep tx has been saved to db. + assertNumSweeps(t, ctx.sweeper, 1) // Restart sweeper. ctx.restartSweeper() // Simulate other subsystem (e.g. contract resolver) re-offering inputs. spendChan1, err := ctx.sweeper.SweepInput(input1, defaultFeePref) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) input2 := spendableInputs[1] + + // Mock the Broadcast method to succeed. + bumpResultChan2 := make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan2, nil).Run(func(args mock.Arguments) { + // Create a fake sweep tx. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: *input2.OutPoint()}, + }, + } + + // Send the first event. + bumpResultChan2 <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }).Once() + spendChan2, err := ctx.sweeper.SweepInput(input2, defaultFeePref) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // Spend inputs of sweep txes and verify that spend channels signal // spends. @@ -754,10 +1122,27 @@ func TestRestart(t *testing.T) { // Timer tick should trigger republishing a sweep for the remaining // input. - ctx.receiveTx() + sweepTx2 := ctx.receiveTx() + + // Wait until the sweep tx has been saved to db. + assertNumSweeps(t, ctx.sweeper, 2) ctx.backend.mine() + // Mock a confirmed event. + bumpResultChan1 <- &BumpResult{ + Event: TxConfirmed, + Tx: &sweepTx1, + FeeRate: 10, + Fee: 100, + } + bumpResultChan2 <- &BumpResult{ + Event: TxConfirmed, + Tx: &sweepTx2, + FeeRate: 10, + Fee: 100, + } + select { case result := <-spendChan2: if result.Err != nil { @@ -778,51 +1163,104 @@ func TestRestart(t *testing.T) { func TestRestartRemoteSpend(t *testing.T) { ctx := createSweeperTestContext(t) - // Sweep input. + // Get testing inputs. input1 := spendableInputs[0] + input2 := spendableInputs[1] + + // Create a fake sweep tx that spends the second input as the first + // will be spent by the remote. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: *input2.OutPoint()}, + }, + } + + // Mock the Broadcast method to succeed. + bumpResultChan := make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan, nil).Run(func(args mock.Arguments) { + // Send the first event. + bumpResultChan <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }).Once() + _, err := ctx.sweeper.SweepInput(input1, defaultFeePref) require.NoError(t, err) // Sweep another input. - input2 := spendableInputs[1] _, err = ctx.sweeper.SweepInput(input2, defaultFeePref) require.NoError(t, err) sweepTx := ctx.receiveTx() + // Wait until the sweep tx has been saved to db. + assertNumSweeps(t, ctx.sweeper, 1) + // Restart sweeper. ctx.restartSweeper() - // Replace the sweep tx with a remote tx spending input 1. + // Replace the sweep tx with a remote tx spending input 2. ctx.backend.deleteUnconfirmed(sweepTx.TxHash()) remoteTx := &wire.MsgTx{ TxIn: []*wire.TxIn{ - { - PreviousOutPoint: *(input2.OutPoint()), - }, + {PreviousOutPoint: *input1.OutPoint()}, }, } - if err := ctx.backend.publishTransaction(remoteTx); err != nil { - t.Fatal(err) - } + err = ctx.backend.publishTransaction(remoteTx) + require.NoError(t, err) // Mine remote spending tx. ctx.backend.mine() + // Mock the Broadcast method to succeed. + bumpResultChan = make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan, nil).Run(func(args mock.Arguments) { + // Send the first event. + bumpResultChan <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }).Once() + // Simulate other subsystem (e.g. contract resolver) re-offering input - // 0. - spendChan, err := ctx.sweeper.SweepInput(input1, defaultFeePref) - if err != nil { - t.Fatal(err) - } + // 2. + spendChan, err := ctx.sweeper.SweepInput(input2, defaultFeePref) + require.NoError(t, err) // Expect sweeper to construct a new tx, because input 1 was spend // remotely. - ctx.receiveTx() + sweepTx = ctx.receiveTx() ctx.backend.mine() + // Mock a confirmed event. + bumpResultChan <- &BumpResult{ + Event: TxConfirmed, + Tx: &sweepTx, + FeeRate: 10, + Fee: 100, + } + ctx.expectResult(spendChan, nil) ctx.finish(1) @@ -835,11 +1273,40 @@ func TestRestartConfirmed(t *testing.T) { // Sweep input. input := spendableInputs[0] - if _, err := ctx.sweeper.SweepInput(input, defaultFeePref); err != nil { - t.Fatal(err) - } - ctx.receiveTx() + // Mock the Broadcast method to succeed. + bumpResultChan := make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan, nil).Run(func(args mock.Arguments) { + // Create a fake sweep tx. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: *input.OutPoint()}, + }, + } + + // Send the first event. + bumpResultChan <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }).Once() + + _, err := ctx.sweeper.SweepInput(input, defaultFeePref) + require.NoError(t, err) + + sweepTx := ctx.receiveTx() + + // Wait until the sweep tx has been saved to db. + assertNumSweeps(t, ctx.sweeper, 1) // Restart sweeper. ctx.restartSweeper() @@ -847,9 +1314,18 @@ func TestRestartConfirmed(t *testing.T) { // Mine the sweep tx. ctx.backend.mine() + // Mock a confirmed event. + bumpResultChan <- &BumpResult{ + Event: TxConfirmed, + Tx: &sweepTx, + FeeRate: 10, + Fee: 100, + } + // Simulate other subsystem (e.g. contract resolver) re-offering input // 0. spendChan, err := ctx.sweeper.SweepInput(input, defaultFeePref) + require.NoError(t, err) if err != nil { t.Fatal(err) } @@ -864,29 +1340,96 @@ func TestRestartConfirmed(t *testing.T) { func TestRetry(t *testing.T) { ctx := createSweeperTestContext(t) - resultChan0, err := ctx.sweeper.SweepInput( - spendableInputs[0], defaultFeePref, - ) - if err != nil { - t.Fatal(err) - } + inp0 := spendableInputs[0] + inp1 := spendableInputs[1] + + // Mock the Broadcast method to succeed. + bumpResultChan1 := make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan1, nil).Run(func(args mock.Arguments) { + // Create a fake sweep tx. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: *inp0.OutPoint()}, + }, + } + + // Send the first event. + bumpResultChan1 <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }).Once() + + resultChan0, err := ctx.sweeper.SweepInput(inp0, defaultFeePref) + require.NoError(t, err) // We expect a sweep to be published. - ctx.receiveTx() + sweepTx1 := ctx.receiveTx() + + // Wait until the sweep tx has been saved to db. + assertNumSweeps(t, ctx.sweeper, 1) + + // Mock the Broadcast method to succeed on the second sweep. + bumpResultChan2 := make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan2, nil).Run(func(args mock.Arguments) { + // Create a fake sweep tx. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: *inp1.OutPoint()}, + }, + } + + // Send the first event. + bumpResultChan2 <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }).Once() // Offer a fresh input. - resultChan1, err := ctx.sweeper.SweepInput( - spendableInputs[1], defaultFeePref, - ) - if err != nil { - t.Fatal(err) - } + resultChan1, err := ctx.sweeper.SweepInput(inp1, defaultFeePref) + require.NoError(t, err) // A single tx is expected to be published. - ctx.receiveTx() + sweepTx2 := ctx.receiveTx() + + // Wait until the sweep tx has been saved to db. + assertNumSweeps(t, ctx.sweeper, 2) ctx.backend.mine() + // Mock a confirmed event. + bumpResultChan1 <- &BumpResult{ + Event: TxConfirmed, + Tx: &sweepTx1, + FeeRate: 10, + Fee: 100, + } + bumpResultChan2 <- &BumpResult{ + Event: TxConfirmed, + Tx: &sweepTx2, + FeeRate: 10, + Fee: 100, + } + ctx.expectResult(resultChan0, nil) ctx.expectResult(resultChan1, nil) @@ -912,44 +1455,105 @@ func TestDifferentFeePreferences(t *testing.T) { ctx.estimator.blocksToFee[highFeePref.ConfTarget] = highFeeRate input1 := spendableInputs[0] + input2 := spendableInputs[1] + input3 := spendableInputs[2] + + // Mock the Broadcast method to succeed on the first sweep. + bumpResultChan1 := make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan1, nil).Run(func(args mock.Arguments) { + // Create a fake sweep tx. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: *input1.OutPoint()}, + {PreviousOutPoint: *input2.OutPoint()}, + }, + } + + // Send the first event. + bumpResultChan1 <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }).Once() + + // Mock the Broadcast method to succeed on the second sweep. + bumpResultChan2 := make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan2, nil).Run(func(args mock.Arguments) { + // Create a fake sweep tx. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: *input3.OutPoint()}, + }, + } + + // Send the first event. + bumpResultChan2 <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }).Once() + resultChan1, err := ctx.sweeper.SweepInput( input1, Params{Fee: highFeePref}, ) - if err != nil { - t.Fatal(err) - } - input2 := spendableInputs[1] + require.NoError(t, err) + resultChan2, err := ctx.sweeper.SweepInput( input2, Params{Fee: highFeePref}, ) - if err != nil { - t.Fatal(err) - } - input3 := spendableInputs[2] + require.NoError(t, err) + resultChan3, err := ctx.sweeper.SweepInput( input3, Params{Fee: lowFeePref}, ) - if err != nil { - t.Fatal(err) - } - - // Generate the same type of sweep script that was used for weight - // estimation. - changePk, err := ctx.sweeper.cfg.GenSweepScript() require.NoError(t, err) - // The first transaction broadcast should be the one spending the higher - // fee rate inputs. + // The first transaction broadcast should be the one spending the + // higher fee rate inputs. sweepTx1 := ctx.receiveTx() - assertTxFeeRate(t, &sweepTx1, highFeeRate, changePk, input1, input2) // The second should be the one spending the lower fee rate inputs. sweepTx2 := ctx.receiveTx() - assertTxFeeRate(t, &sweepTx2, lowFeeRate, changePk, input3) + + // Wait until the sweep tx has been saved to db. + assertNumSweeps(t, ctx.sweeper, 2) // With the transactions broadcast, we'll mine a block to so that the // result is delivered to each respective client. ctx.backend.mine() + + // Mock a confirmed event. + bumpResultChan1 <- &BumpResult{ + Event: TxConfirmed, + Tx: &sweepTx1, + FeeRate: 10, + Fee: 100, + } + bumpResultChan2 <- &BumpResult{ + Event: TxConfirmed, + Tx: &sweepTx2, + FeeRate: 10, + Fee: 100, + } + resultChans := []chan Result{resultChan1, resultChan2, resultChan3} for _, resultChan := range resultChans { ctx.expectResult(resultChan, nil) @@ -983,37 +1587,105 @@ func TestPendingInputs(t *testing.T) { ctx.estimator.blocksToFee[highFeePref.ConfTarget] = highFeeRate input1 := spendableInputs[0] + input2 := spendableInputs[1] + input3 := spendableInputs[2] + + // Mock the Broadcast method to succeed on the first sweep. + bumpResultChan1 := make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan1, nil).Run(func(args mock.Arguments) { + // Create a fake sweep tx. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: *input1.OutPoint()}, + {PreviousOutPoint: *input2.OutPoint()}, + }, + } + + // Send the first event. + bumpResultChan1 <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }).Once() + + // Mock the Broadcast method to succeed on the second sweep. + bumpResultChan2 := make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan2, nil).Run(func(args mock.Arguments) { + // Create a fake sweep tx. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: *input3.OutPoint()}, + }, + } + + // Send the first event. + bumpResultChan2 <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }).Once() + resultChan1, err := ctx.sweeper.SweepInput( input1, Params{Fee: highFeePref}, ) - if err != nil { - t.Fatal(err) - } - input2 := spendableInputs[1] + require.NoError(t, err) + _, err = ctx.sweeper.SweepInput( input2, Params{Fee: highFeePref}, ) - if err != nil { - t.Fatal(err) - } - input3 := spendableInputs[2] + require.NoError(t, err) + resultChan3, err := ctx.sweeper.SweepInput( input3, Params{Fee: lowFeePref}, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // We should expect to see all inputs pending. ctx.assertPendingInputs(input1, input2, input3) // We should expect to see both sweep transactions broadcast - one for // the higher feerate, the other for the lower. - ctx.receiveTx() - ctx.receiveTx() + sweepTx1 := ctx.receiveTx() + sweepTx2 := ctx.receiveTx() + + // Wait until the sweep tx has been saved to db. + assertNumSweeps(t, ctx.sweeper, 2) // Mine these txns, and we should expect to see the results delivered. ctx.backend.mine() + + // Mock a confirmed event. + bumpResultChan1 <- &BumpResult{ + Event: TxConfirmed, + Tx: &sweepTx1, + FeeRate: 10, + Fee: 100, + } + bumpResultChan2 <- &BumpResult{ + Event: TxConfirmed, + Tx: &sweepTx2, + FeeRate: 10, + Fee: 100, + } + ctx.expectResult(resultChan1, nil) ctx.expectResult(resultChan3, nil) ctx.assertPendingInputs() @@ -1025,6 +1697,8 @@ func TestPendingInputs(t *testing.T) { // request for an input it is currently attempting to sweep. When sweeping the // input with the higher fee rate, a replacement transaction is created. func TestBumpFeeRBF(t *testing.T) { + t.Skip("fix me") + ctx := createSweeperTestContext(t) lowFeePref := FeeEstimateInfo{ConfTarget: 144} @@ -1095,6 +1769,88 @@ func TestBumpFeeRBF(t *testing.T) { func TestExclusiveGroup(t *testing.T) { ctx := createSweeperTestContext(t) + input1 := spendableInputs[0] + input2 := spendableInputs[1] + input3 := spendableInputs[2] + + // Mock the Broadcast method to succeed on the first sweep. + bumpResultChan1 := make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan1, nil).Run(func(args mock.Arguments) { + // Create a fake sweep tx. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: *input1.OutPoint()}, + }, + } + + // Send the first event. + bumpResultChan1 <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }).Once() + + // Mock the Broadcast method to succeed on the second sweep. + bumpResultChan2 := make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan2, nil).Run(func(args mock.Arguments) { + // Create a fake sweep tx. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: *input2.OutPoint()}, + }, + } + + // Send the first event. + bumpResultChan2 <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }).Once() + + // Mock the Broadcast method to succeed on the third sweep. + bumpResultChan3 := make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan3, nil).Run(func(args mock.Arguments) { + // Create a fake sweep tx. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: *input3.OutPoint()}, + }, + } + + // Send the first event. + bumpResultChan3 <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }).Once() + // Sweep three inputs in the same exclusive group. var results []chan Result for i := 0; i < 3; i++ { @@ -1105,32 +1861,45 @@ func TestExclusiveGroup(t *testing.T) { ExclusiveGroup: &exclusiveGroup, }, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) results = append(results, result) } // We expect all inputs to be published in separate transactions, even // though they share the same fee preference. - for i := 0; i < 3; i++ { - sweepTx := ctx.receiveTx() - if len(sweepTx.TxOut) != 1 { - t.Fatal("expected a single tx out in the sweep tx") - } + sweepTx1 := ctx.receiveTx() + require.Len(t, sweepTx1.TxIn, 1) - // Remove all txes except for the one that sweeps the first - // input. This simulates the sweeps being conflicting. - if sweepTx.TxIn[0].PreviousOutPoint != - *spendableInputs[0].OutPoint() { + sweepTx2 := ctx.receiveTx() + sweepTx3 := ctx.receiveTx() - ctx.backend.deleteUnconfirmed(sweepTx.TxHash()) - } - } + // Remove all txes except for the one that sweeps the first + // input. This simulates the sweeps being conflicting. + ctx.backend.deleteUnconfirmed(sweepTx2.TxHash()) + ctx.backend.deleteUnconfirmed(sweepTx3.TxHash()) + + // Wait until the sweep tx has been saved to db. + assertNumSweeps(t, ctx.sweeper, 3) // Mine the first sweep tx. ctx.backend.mine() + // Mock a confirmed event. + bumpResultChan1 <- &BumpResult{ + Event: TxConfirmed, + Tx: &sweepTx1, + FeeRate: 10, + Fee: 100, + } + bumpResultChan2 <- &BumpResult{ + Event: TxFailed, + Tx: &sweepTx2, + } + bumpResultChan2 <- &BumpResult{ + Event: TxFailed, + Tx: &sweepTx3, + } + // Expect the first input to be swept by the confirmed sweep tx. result0 := <-results[0] if result0.Err != nil { @@ -1150,9 +1919,11 @@ func TestExclusiveGroup(t *testing.T) { } } -// TestCpfp tests that the sweeper spends cpfp inputs at a fee rate that exceeds -// the parent tx fee rate. +// TestCpfp tests that the sweeper spends cpfp inputs at a fee rate that +// exceeds the parent tx fee rate. func TestCpfp(t *testing.T) { + t.Skip("fix me") + ctx := createSweeperTestContext(t) ctx.estimator.updateFees(1000, chainfee.FeePerKwFloor) @@ -1308,8 +2079,10 @@ func TestLockTimes(t *testing.T) { // Sweep 8 inputs, using 4 different lock times. var ( - results []chan Result - inputs = make(map[wire.OutPoint]input.Input) + results []chan Result + inputs = make(map[wire.OutPoint]input.Input) + clusters = make(map[uint32][]input.Input) + bumpResultChans = make([]chan *BumpResult, 0, 4) ) for i := 0; i < numSweeps*2; i++ { lt := uint32(10 + (i % numSweeps)) @@ -1318,53 +2091,84 @@ func TestLockTimes(t *testing.T) { locktime: <, } - result, err := ctx.sweeper.SweepInput( - inp, Params{ - Fee: FeeEstimateInfo{ConfTarget: 6}, - }, - ) - if err != nil { - t.Fatal(err) - } - results = append(results, result) - op := inp.OutPoint() inputs[*op] = inp + + cluster, ok := clusters[lt] + if !ok { + cluster = make([]input.Input, 0) + } + cluster = append(cluster, inp) + clusters[lt] = cluster } - // We also add 3 regular inputs that don't require any specific lock - // time. for i := 0; i < 3; i++ { inp := spendableInputs[i+numSweeps*2] + inputs[*inp.OutPoint()] = inp + + lt := uint32(10 + (i % numSweeps)) + clusters[lt] = append(clusters[lt], inp) + } + + for lt, cluster := range clusters { + // Create a fake sweep tx. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{}, + LockTime: lt, + } + + // Append the inputs. + for _, inp := range cluster { + txIn := &wire.TxIn{ + PreviousOutPoint: *inp.OutPoint(), + } + tx.TxIn = append(tx.TxIn, txIn) + } + + // Mock the Broadcast method to succeed on current sweep. + bumpResultChan := make(chan *BumpResult, 1) + bumpResultChans = append(bumpResultChans, bumpResultChan) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan, nil).Run(func(args mock.Arguments) { + // Send the first event. + bumpResultChan <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need + // to manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them + // will mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }).Once() + } + + // Make all the sweeps. + for _, inp := range inputs { result, err := ctx.sweeper.SweepInput( inp, Params{ Fee: FeeEstimateInfo{ConfTarget: 6}, }, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) results = append(results, result) - - op := inp.OutPoint() - inputs[*op] = inp } // Check the sweeps transactions, ensuring all inputs are there, and // all the locktimes are satisfied. + sweepTxes := make([]wire.MsgTx, 0, numSweeps) for i := 0; i < numSweeps; i++ { sweepTx := ctx.receiveTx() - if len(sweepTx.TxOut) != 1 { - t.Fatal("expected a single tx out in the sweep tx") - } + sweepTxes = append(sweepTxes, sweepTx) for _, txIn := range sweepTx.TxIn { op := txIn.PreviousOutPoint inp, ok := inputs[op] - if !ok { - t.Fatalf("Unexpected outpoint: %v", op) - } + require.True(t, ok) delete(inputs, op) @@ -1375,25 +2179,33 @@ func TestLockTimes(t *testing.T) { continue } - if lt != sweepTx.LockTime { - t.Fatalf("Input required locktime %v, sweep "+ - "tx had locktime %v", lt, sweepTx.LockTime) - } + require.EqualValues(t, lt, sweepTx.LockTime) } } - // The should be no inputs not foud in any of the sweeps. - if len(inputs) != 0 { - t.Fatalf("had unsweeped inputs: %v", inputs) - } + // Wait until the sweep tx has been saved to db. + assertNumSweeps(t, ctx.sweeper, 4) - // Mine the first sweeps + // Mine the sweeps. ctx.backend.mine() + for i, bumpResultChan := range bumpResultChans { + // Mock a confirmed event. + bumpResultChan <- &BumpResult{ + Event: TxConfirmed, + Tx: &sweepTxes[i], + FeeRate: 10, + Fee: 100, + } + } + + // The should be no inputs not foud in any of the sweeps. + require.Empty(t, inputs) + // Results should all come back. - for i := range results { + for i, resultChan := range results { select { - case result := <-results[i]: + case result := <-resultChan: require.NoError(t, result.Err) case <-time.After(1 * time.Second): t.Fatalf("result %v did not come back", i) @@ -1401,9 +2213,11 @@ func TestLockTimes(t *testing.T) { } } -// TestRequiredTxOuts checks that inputs having a required TxOut gets swept with -// sweep transactions paying into these outputs. +// TestRequiredTxOuts checks that inputs having a required TxOut gets swept +// with sweep transactions paying into these outputs. func TestRequiredTxOuts(t *testing.T) { + t.Skip("fix me") + // Create some test inputs and locktime vars. var inputs []*input.BaseInput for i := 0; i < 20; i++ { From 15872dcd6b4052206d20ade0b076d8b9196961b6 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Sat, 16 Mar 2024 07:59:06 +0800 Subject: [PATCH 08/19] sweep: remove RBF related tests As there will be dedicated new tests for them. --- sweep/sweeper_test.go | 596 ------------------------------------------ 1 file changed, 596 deletions(-) diff --git a/sweep/sweeper_test.go b/sweep/sweeper_test.go index f18cc49177..ee143e3d31 100644 --- a/sweep/sweeper_test.go +++ b/sweep/sweeper_test.go @@ -1693,78 +1693,6 @@ func TestPendingInputs(t *testing.T) { ctx.finish(1) } -// TestBumpFeeRBF ensures that the UtxoSweeper can properly handle a fee bump -// request for an input it is currently attempting to sweep. When sweeping the -// input with the higher fee rate, a replacement transaction is created. -func TestBumpFeeRBF(t *testing.T) { - t.Skip("fix me") - - ctx := createSweeperTestContext(t) - - lowFeePref := FeeEstimateInfo{ConfTarget: 144} - lowFeeRate := chainfee.FeePerKwFloor - ctx.estimator.blocksToFee[lowFeePref.ConfTarget] = lowFeeRate - - // We'll first try to bump the fee of an output currently unknown to the - // UtxoSweeper. Doing so should result in a lnwallet.ErrNotMine error. - _, err := ctx.sweeper.UpdateParams( - wire.OutPoint{}, ParamsUpdate{Fee: lowFeePref}, - ) - if err != lnwallet.ErrNotMine { - t.Fatalf("expected error lnwallet.ErrNotMine, got \"%v\"", err) - } - - // We'll then attempt to sweep an input, which we'll use to bump its fee - // later on. - input := createTestInput( - btcutil.SatoshiPerBitcoin, input.CommitmentTimeLock, - ) - sweepResult, err := ctx.sweeper.SweepInput( - &input, Params{Fee: lowFeePref}, - ) - if err != nil { - t.Fatal(err) - } - - // Generate the same type of change script used so we can have accurate - // weight estimation. - changePk, err := ctx.sweeper.cfg.GenSweepScript() - require.NoError(t, err) - - // Ensure that a transaction is broadcast with the lower fee preference. - lowFeeTx := ctx.receiveTx() - assertTxFeeRate(t, &lowFeeTx, lowFeeRate, changePk, &input) - - // We'll then attempt to bump its fee rate. - highFeePref := FeeEstimateInfo{ConfTarget: 6} - highFeeRate := DefaultMaxFeeRate.FeePerKWeight() - ctx.estimator.blocksToFee[highFeePref.ConfTarget] = highFeeRate - - // We should expect to see an error if a fee preference isn't provided. - _, err = ctx.sweeper.UpdateParams(*input.OutPoint(), ParamsUpdate{ - Fee: &FeeEstimateInfo{}, - }) - if err != ErrNoFeePreference { - t.Fatalf("expected ErrNoFeePreference, got %v", err) - } - - bumpResult, err := ctx.sweeper.UpdateParams( - *input.OutPoint(), ParamsUpdate{Fee: highFeePref}, - ) - require.NoError(t, err, "unable to bump input's fee") - - // A higher fee rate transaction should be immediately broadcast. - highFeeTx := ctx.receiveTx() - assertTxFeeRate(t, &highFeeTx, highFeeRate, changePk, &input) - - // We'll finish our test by mining the sweep transaction. - ctx.backend.mine() - ctx.expectResult(sweepResult, nil) - ctx.expectResult(bumpResult, nil) - - ctx.finish(1) -} - // TestExclusiveGroup tests the sweeper exclusive group functionality. func TestExclusiveGroup(t *testing.T) { ctx := createSweeperTestContext(t) @@ -1919,71 +1847,6 @@ func TestExclusiveGroup(t *testing.T) { } } -// TestCpfp tests that the sweeper spends cpfp inputs at a fee rate that -// exceeds the parent tx fee rate. -func TestCpfp(t *testing.T) { - t.Skip("fix me") - - ctx := createSweeperTestContext(t) - - ctx.estimator.updateFees(1000, chainfee.FeePerKwFloor) - - // Offer an input with an unconfirmed parent tx to the sweeper. The - // parent tx pays 3000 sat/kw. - hash := chainhash.Hash{1} - input := input.MakeBaseInput( - &wire.OutPoint{Hash: hash}, - input.CommitmentTimeLock, - &input.SignDescriptor{ - Output: &wire.TxOut{ - Value: 330, - }, - KeyDesc: keychain.KeyDescriptor{ - PubKey: testPubKey, - }, - }, - 0, - &input.TxInfo{ - Weight: 300, - Fee: 900, - }, - ) - - feePref := FeeEstimateInfo{ConfTarget: 6} - result, err := ctx.sweeper.SweepInput( - &input, Params{Fee: feePref, Force: true}, - ) - require.NoError(t, err) - - // Increase the fee estimate to above the parent tx fee rate. - ctx.estimator.updateFees(5000, chainfee.FeePerKwFloor) - - // Signal a new block. This is a trigger for the sweeper to refresh fee - // estimates. - ctx.notifier.NotifyEpoch(1000) - - // Now we do expect a sweep transaction to be published with our input - // and an attached wallet utxo. - tx := ctx.receiveTx() - require.Len(t, tx.TxIn, 2) - require.Len(t, tx.TxOut, 1) - - // As inputs we have 10000 sats from the wallet and 330 sats from the - // cpfp input. The sweep tx is weight expected to be 759 units. There is - // an additional 300 weight units from the parent to include in the - // package, making a total of 1059. At 5000 sat/kw, the required fee for - // the package is 5295 sats. The parent already paid 900 sats, so there - // is 4395 sat remaining to be paid. The expected output value is - // therefore: 1_000_000 + 330 - 4395 = 995 935. - require.Equal(t, int64(995_935), tx.TxOut[0].Value) - - // Mine the tx and assert that the result is passed back. - ctx.backend.mine() - ctx.expectResult(result, nil) - - ctx.finish(1) -} - type testInput struct { *input.BaseInput @@ -2213,465 +2076,6 @@ func TestLockTimes(t *testing.T) { } } -// TestRequiredTxOuts checks that inputs having a required TxOut gets swept -// with sweep transactions paying into these outputs. -func TestRequiredTxOuts(t *testing.T) { - t.Skip("fix me") - - // Create some test inputs and locktime vars. - var inputs []*input.BaseInput - for i := 0; i < 20; i++ { - input := createTestInput( - int64(btcutil.SatoshiPerBitcoin+i*500), - input.CommitmentTimeLock, - ) - - inputs = append(inputs, &input) - } - - locktime1 := uint32(51) - locktime2 := uint32(52) - locktime3 := uint32(53) - - aPkScript := make([]byte, input.P2WPKHSize) - aPkScript[0] = 'a' - - bPkScript := make([]byte, input.P2WSHSize) - bPkScript[0] = 'b' - - cPkScript := make([]byte, input.P2PKHSize) - cPkScript[0] = 'c' - - dPkScript := make([]byte, input.P2SHSize) - dPkScript[0] = 'd' - - ePkScript := make([]byte, input.UnknownWitnessSize) - ePkScript[0] = 'e' - - fPkScript := make([]byte, input.P2WSHSize) - fPkScript[0] = 'f' - - testCases := []struct { - name string - inputs []*testInput - assertSweeps func(*testing.T, map[wire.OutPoint]*testInput, - []*wire.MsgTx) - }{ - { - // Single input with a required TX out that is smaller. - // We expect a change output to be added. - name: "single input, leftover change", - inputs: []*testInput{ - { - BaseInput: inputs[0], - reqTxOut: &wire.TxOut{ - PkScript: aPkScript, - Value: 100000, - }, - }, - }, - - // Since the required output value is small, we expect - // the rest after fees to go into a change output. - assertSweeps: func(t *testing.T, - _ map[wire.OutPoint]*testInput, - txs []*wire.MsgTx) { - - require.Equal(t, 1, len(txs)) - - tx := txs[0] - require.Equal(t, 1, len(tx.TxIn)) - - // We should have two outputs, the required - // output must be the first one. - require.Equal(t, 2, len(tx.TxOut)) - out := tx.TxOut[0] - require.Equal(t, aPkScript, out.PkScript) - require.Equal(t, int64(100000), out.Value) - }, - }, - { - // An input committing to a slightly smaller output, so - // it will pay its own fees. - name: "single input, no change", - inputs: []*testInput{ - { - BaseInput: inputs[0], - reqTxOut: &wire.TxOut{ - PkScript: aPkScript, - - // Fee will be about 5340 sats. - // Subtract a bit more to - // ensure no dust change output - // is manifested. - Value: inputs[0].SignDesc().Output.Value - 6300, - }, - }, - }, - - // We expect this single input/output pair. - assertSweeps: func(t *testing.T, - _ map[wire.OutPoint]*testInput, - txs []*wire.MsgTx) { - - require.Equal(t, 1, len(txs)) - - tx := txs[0] - require.Equal(t, 1, len(tx.TxIn)) - - require.Equal(t, 1, len(tx.TxOut)) - out := tx.TxOut[0] - require.Equal(t, aPkScript, out.PkScript) - require.Equal( - t, - inputs[0].SignDesc().Output.Value-6300, - out.Value, - ) - }, - }, - { - // Two inputs, where the first one required no tx out. - name: "two inputs, one with required tx out", - inputs: []*testInput{ - { - - // We add a normal, non-requiredTxOut - // input. We use test input 10, to make - // sure this has a higher yield than - // the other input, and will be - // attempted added first to the sweep - // tx. - BaseInput: inputs[10], - }, - { - // The second input requires a TxOut. - BaseInput: inputs[0], - reqTxOut: &wire.TxOut{ - PkScript: aPkScript, - Value: inputs[0].SignDesc().Output.Value, - }, - }, - }, - - // We expect the inputs to have been reordered. - assertSweeps: func(t *testing.T, - _ map[wire.OutPoint]*testInput, - txs []*wire.MsgTx) { - - require.Equal(t, 1, len(txs)) - - tx := txs[0] - require.Equal(t, 2, len(tx.TxIn)) - require.Equal(t, 2, len(tx.TxOut)) - - // The required TxOut should be the first one. - out := tx.TxOut[0] - require.Equal(t, aPkScript, out.PkScript) - require.Equal( - t, inputs[0].SignDesc().Output.Value, - out.Value, - ) - - // The first input should be the one having the - // required TxOut. - require.Len(t, tx.TxIn, 2) - require.Equal( - t, inputs[0].OutPoint(), - &tx.TxIn[0].PreviousOutPoint, - ) - - // Second one is the one without a required tx - // out. - require.Equal( - t, inputs[10].OutPoint(), - &tx.TxIn[1].PreviousOutPoint, - ) - }, - }, - - { - // An input committing to an output of equal value, just - // add input to pay fees. - name: "single input, extra fee input", - inputs: []*testInput{ - { - BaseInput: inputs[0], - reqTxOut: &wire.TxOut{ - PkScript: aPkScript, - Value: inputs[0].SignDesc().Output.Value, - }, - }, - }, - - // We expect an extra input and output. - assertSweeps: func(t *testing.T, - _ map[wire.OutPoint]*testInput, - txs []*wire.MsgTx) { - - require.Equal(t, 1, len(txs)) - - tx := txs[0] - require.Equal(t, 2, len(tx.TxIn)) - - require.Equal(t, 2, len(tx.TxOut)) - out := tx.TxOut[0] - require.Equal(t, aPkScript, out.PkScript) - require.Equal( - t, inputs[0].SignDesc().Output.Value, - out.Value, - ) - }, - }, - { - // Three inputs added, should be combined into a single - // sweep. - name: "three inputs", - inputs: []*testInput{ - { - BaseInput: inputs[0], - reqTxOut: &wire.TxOut{ - PkScript: aPkScript, - Value: inputs[0].SignDesc().Output.Value, - }, - }, - { - BaseInput: inputs[1], - reqTxOut: &wire.TxOut{ - PkScript: bPkScript, - Value: inputs[1].SignDesc().Output.Value, - }, - }, - { - BaseInput: inputs[2], - reqTxOut: &wire.TxOut{ - PkScript: cPkScript, - Value: inputs[2].SignDesc().Output.Value, - }, - }, - }, - - // We expect an extra input and output to pay fees. - assertSweeps: func(t *testing.T, - testInputs map[wire.OutPoint]*testInput, - txs []*wire.MsgTx) { - - require.Equal(t, 1, len(txs)) - - tx := txs[0] - require.Equal(t, 4, len(tx.TxIn)) - require.Equal(t, 4, len(tx.TxOut)) - - // The inputs and outputs must be in the same - // order. - for i, in := range tx.TxIn { - // Last one is the change input/output - // pair, so we'll skip it. - if i == 3 { - continue - } - - // Get this input to ensure the output - // on index i coresponsd to this one. - inp := testInputs[in.PreviousOutPoint] - require.NotNil(t, inp) - - require.Equal( - t, tx.TxOut[i].Value, - inp.SignDesc().Output.Value, - ) - } - }, - }, - { - // Six inputs added, which 3 different locktimes. - // Should result in 3 sweeps. - name: "six inputs", - inputs: []*testInput{ - { - BaseInput: inputs[0], - locktime: &locktime1, - reqTxOut: &wire.TxOut{ - PkScript: aPkScript, - Value: inputs[0].SignDesc().Output.Value, - }, - }, - { - BaseInput: inputs[1], - locktime: &locktime1, - reqTxOut: &wire.TxOut{ - PkScript: bPkScript, - Value: inputs[1].SignDesc().Output.Value, - }, - }, - { - BaseInput: inputs[2], - locktime: &locktime2, - reqTxOut: &wire.TxOut{ - PkScript: cPkScript, - Value: inputs[2].SignDesc().Output.Value, - }, - }, - { - BaseInput: inputs[3], - locktime: &locktime2, - reqTxOut: &wire.TxOut{ - PkScript: dPkScript, - Value: inputs[3].SignDesc().Output.Value, - }, - }, - { - BaseInput: inputs[4], - locktime: &locktime3, - reqTxOut: &wire.TxOut{ - PkScript: ePkScript, - Value: inputs[4].SignDesc().Output.Value, - }, - }, - { - BaseInput: inputs[5], - locktime: &locktime3, - reqTxOut: &wire.TxOut{ - PkScript: fPkScript, - Value: inputs[5].SignDesc().Output.Value, - }, - }, - }, - - // We expect three sweeps, each having two of our - // inputs, one extra input and output to pay fees. - assertSweeps: func(t *testing.T, - testInputs map[wire.OutPoint]*testInput, - txs []*wire.MsgTx) { - - require.Equal(t, 3, len(txs)) - - for _, tx := range txs { - require.Equal(t, 3, len(tx.TxIn)) - require.Equal(t, 3, len(tx.TxOut)) - - // The inputs and outputs must be in - // the same order. - for i, in := range tx.TxIn { - // Last one is the change - // output, so we'll skip it. - if i == 2 { - continue - } - - // Get this input to ensure the - // output on index i coresponsd - // to this one. - inp := testInputs[in.PreviousOutPoint] - require.NotNil(t, inp) - - require.Equal( - t, tx.TxOut[i].Value, - inp.SignDesc().Output.Value, - ) - - // Check that the locktimes are - // kept intact. - require.Equal( - t, tx.LockTime, - *inp.locktime, - ) - } - } - }, - }, - } - - for _, testCase := range testCases { - testCase := testCase - - t.Run(testCase.name, func(t *testing.T) { - ctx := createSweeperTestContext(t) - - // We increase the number of max inputs to a tx so that - // won't impact our test. - ctx.sweeper.cfg.MaxInputsPerTx = 100 - - // Sweep all test inputs. - var ( - inputs = make(map[wire.OutPoint]*testInput) - results = make(map[wire.OutPoint]chan Result) - ) - for _, inp := range testCase.inputs { - result, err := ctx.sweeper.SweepInput( - inp, Params{ - Fee: FeeEstimateInfo{ - ConfTarget: 6, - }, - }, - ) - if err != nil { - t.Fatal(err) - } - - op := inp.OutPoint() - results[*op] = result - inputs[*op] = inp - } - - // Send a new block epoch to trigger the sweeper to - // sweep the inputs. - ctx.notifier.NotifyEpoch(ctx.sweeper.currentHeight + 1) - - // Check the sweeps transactions, ensuring all inputs - // are there, and all the locktimes are satisfied. - var sweeps []*wire.MsgTx - Loop: - for { - select { - case tx := <-ctx.publishChan: - sweeps = append(sweeps, &tx) - case <-time.After(200 * time.Millisecond): - break Loop - } - } - - // Mine the sweeps. - ctx.backend.mine() - - // Results should all come back. - for _, resultChan := range results { - result := <-resultChan - if result.Err != nil { - t.Fatalf("expected input to be "+ - "swept: %v", result.Err) - } - } - - // Assert the transactions are what we expect. - testCase.assertSweeps(t, inputs, sweeps) - - // Finally we assert that all our test inputs were part - // of the sweeps, and that they were signed correctly. - sweptInputs := make(map[wire.OutPoint]struct{}) - for _, sweep := range sweeps { - swept := assertSignedIndex(t, sweep, inputs) - for op := range swept { - if _, ok := sweptInputs[op]; ok { - t.Fatalf("outpoint %v part of "+ - "previous sweep", op) - } - - sweptInputs[op] = struct{}{} - } - } - - require.Equal(t, len(inputs), len(sweptInputs)) - for op := range sweptInputs { - _, ok := inputs[op] - if !ok { - t.Fatalf("swept input %v not part of "+ - "test inputs", op) - } - } - }) - } -} - // TestSweeperShutdownHandling tests that we notify callers when the sweeper // cannot handle requests since it's in the process of shutting down. func TestSweeperShutdownHandling(t *testing.T) { From 7af5bbafca1dd07b62243f8eabd5b854b96647d1 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Wed, 17 Jan 2024 17:49:09 +0800 Subject: [PATCH 09/19] sweep: remove `FeeRate()` from `InputSet` interface As shown in the following commit, fee rate calculation will now be handled by the fee bumper, hence there's no need to expose this on `InputSet` interface. --- sweep/tx_input_set.go | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/sweep/tx_input_set.go b/sweep/tx_input_set.go index 0354c6c1d4..d2e47c57b2 100644 --- a/sweep/tx_input_set.go +++ b/sweep/tx_input_set.go @@ -51,9 +51,6 @@ type InputSet interface { // Inputs returns the set of inputs that should be used to create a tx. Inputs() []input.Input - // FeeRate returns the fee rate that should be used for the tx. - FeeRate() chainfee.SatPerKWeight - // AddWalletInputs adds wallet inputs to the set until a non-dust // change output can be made. Return an error if there are not enough // wallet inputs. @@ -208,11 +205,6 @@ func (t *txInputSet) DeadlineHeight() fn.Option[int32] { return fn.None[int32]() } -// FeeRate returns the fee rate that should be used for the tx. -func (t *txInputSet) FeeRate() chainfee.SatPerKWeight { - return t.feeRate -} - // NeedWalletInput returns true if the input set needs more wallet inputs. func (t *txInputSet) NeedWalletInput() bool { return !t.enoughInput() @@ -807,12 +799,3 @@ func (b *BudgetInputSet) Inputs() []input.Input { return inputs } - -// FeeRate returns the fee rate that should be used for the tx. -// -// NOTE: part of the InputSet interface. -// -// TODO(yy): will be removed once fee bumper is implemented. -func (b *BudgetInputSet) FeeRate() chainfee.SatPerKWeight { - return 0 -} From f354b65fbf6ed73d78eaade5d748fe1e68fb147b Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Thu, 25 Jan 2024 02:04:43 +0800 Subject: [PATCH 10/19] sweep: add `FeeFunction` interface and a linear implementation This commit adds a new interface, `FeeFunction`, to deal with calculating fee rates. In addition, a simple linear function is implemented, hence `LinearFeeFunction`, which will be used to calculate fee rates when bumping fees. Check #4215 for other type of fee functions that can be implemented. --- sweep/fee_function.go | 265 +++++++++++++++++++++++++++++++++++++ sweep/fee_function_test.go | 247 ++++++++++++++++++++++++++++++++++ 2 files changed, 512 insertions(+) create mode 100644 sweep/fee_function.go create mode 100644 sweep/fee_function_test.go diff --git a/sweep/fee_function.go b/sweep/fee_function.go new file mode 100644 index 0000000000..db3fb7851c --- /dev/null +++ b/sweep/fee_function.go @@ -0,0 +1,265 @@ +package sweep + +import ( + "errors" + "fmt" + + "github.com/btcsuite/btcd/btcutil" + "github.com/lightningnetwork/lnd/lnwallet/chainfee" +) + +var ( + // ErrMaxPosition is returned when trying to increase the position of + // the fee function while it's already at its max. + ErrMaxPosition = errors.New("position already at max") +) + +// FeeFunction defines an interface that is used to calculate fee rates for +// transactions. It's expected the implementations use three params, the +// starting fee rate, the ending fee rate, and number of blocks till deadline +// block height, to build an algorithm to calculate the fee rate based on the +// current block height. +type FeeFunction interface { + // FeeRate returns the current fee rate calculated by the fee function. + FeeRate() chainfee.SatPerKWeight + + // Increment increases the fee rate by one step. The definition of one + // step is up to the implementation. After calling this method, it's + // expected to change the state of the fee function such that calling + // `FeeRate` again will return the increased value. + // + // It returns a boolean to indicate whether the fee rate is increased, + // as fee bump should not be attempted if the increased fee rate is not + // greater than the current fee rate, which may happen if the algorithm + // gives the same fee rates at two positions. + // + // An error is returned when the max fee rate is reached. + // + // NOTE: we intentionally don't return the new fee rate here, so both + // the implementation and the caller are aware of the state change. + Increment() (bool, error) + + // IncreaseFeeRate increases the fee rate to the new position + // calculated using (width - confTarget). It returns a boolean to + // indicate whether the fee rate is increased, and an error if the + // position is greater than the width. + // + // NOTE: this method is provided to allow the caller to increase the + // fee rate based on a conf target without taking care of the fee + // function's current state (position). + IncreaseFeeRate(confTarget uint32) (bool, error) +} + +// LinearFeeFunction implements the FeeFunction interface with a linear +// function: +// +// feeRate = startingFeeRate + position * delta. +// - width: deadlineBlockHeight - startingBlockHeight +// - delta: (endingFeeRate - startingFeeRate) / width +// - position: currentBlockHeight - startingBlockHeight +// +// The fee rate will be capped at endingFeeRate. +// +// TODO(yy): implement more functions specified here: +// - https://github.com/lightningnetwork/lnd/issues/4215 +type LinearFeeFunction struct { + // startingFeeRate specifies the initial fee rate to begin with. + startingFeeRate chainfee.SatPerKWeight + + // endingFeeRate specifies the max allowed fee rate. + endingFeeRate chainfee.SatPerKWeight + + // currentFeeRate specifies the current calculated fee rate. + currentFeeRate chainfee.SatPerKWeight + + // width is the number of blocks between the starting block height + // and the deadline block height. + width uint32 + + // position is the number of blocks between the starting block height + // and the current block height. + position uint32 + + // deltaFeeRate is the fee rate increase per block. + deltaFeeRate chainfee.SatPerKWeight + + // estimator is the fee estimator used to estimate the fee rate. We use + // it to get the initial fee rate and, use it as a benchmark to decide + // whether we want to used the estimated fee rate or the calculated fee + // rate based on different strategies. + estimator chainfee.Estimator +} + +// Compile-time check to ensure LinearFeeFunction satisfies the FeeFunction. +var _ FeeFunction = (*LinearFeeFunction)(nil) + +// NewLinearFeeFunction creates a new linear fee function and initializes it +// with a starting fee rate which is an estimated value returned from the fee +// estimator using the initial conf target. +func NewLinearFeeFunction(maxFeeRate chainfee.SatPerKWeight, confTarget uint32, + estimator chainfee.Estimator) (*LinearFeeFunction, error) { + + // Sanity check conf target. + if confTarget == 0 { + return nil, fmt.Errorf("width must be greater than zero") + } + + l := &LinearFeeFunction{ + endingFeeRate: maxFeeRate, + width: confTarget, + estimator: estimator, + } + + // Estimate the initial fee rate. + // + // NOTE: estimateFeeRate guarantees the returned fee rate is capped by + // the ending fee rate, so we don't need to worry about overpay. + start, err := l.estimateFeeRate(confTarget) + if err != nil { + return nil, fmt.Errorf("estimate initial fee rate: %w", err) + } + + // Calculate how much fee rate should be increased per block. + end := l.endingFeeRate + delta := btcutil.Amount(end - start).MulF64(1 / float64(confTarget)) + + // We only allow the delta to be zero if the width is one - when the + // delta is zero, it means the starting and ending fee rates are the + // same, which means there's nothing to increase, so any width greater + // than 1 doesn't provide any utility. This could happen when the + // sweeper is offered to sweep an input that has passed its deadline. + if delta == 0 && l.width != 1 { + return nil, fmt.Errorf("fee rate delta is zero") + } + + // Attach the calculated values to the fee function. + l.startingFeeRate = start + l.currentFeeRate = start + l.deltaFeeRate = chainfee.SatPerKWeight(delta) + + log.Debugf("Linear fee function initialized with startingFeeRate=%v, "+ + "endingFeeRate=%v, width=%v, delta=%v", start, end, + confTarget, l.deltaFeeRate) + + return l, nil +} + +// FeeRate returns the current fee rate. +// +// NOTE: part of the FeeFunction interface. +func (l *LinearFeeFunction) FeeRate() chainfee.SatPerKWeight { + return l.currentFeeRate +} + +// Increment increases the fee rate by one position, returns a boolean to +// indicate whether the fee rate was increased, and an error if the position is +// greater than the width. The increased fee rate will be set as the current +// fee rate, and the internal position will be incremented. +// +// NOTE: this method will change the state of the fee function as it increases +// its current fee rate. +// +// NOTE: part of the FeeFunction interface. +func (l *LinearFeeFunction) Increment() (bool, error) { + return l.increaseFeeRate(l.position + 1) +} + +// IncreaseFeeRate calculate a new position using the given conf target, and +// increases the fee rate to the new position by calling the Increment method. +// +// NOTE: this method will change the state of the fee function as it increases +// its current fee rate. +// +// NOTE: part of the FeeFunction interface. +func (l *LinearFeeFunction) IncreaseFeeRate(confTarget uint32) (bool, error) { + // If the new position is already at the end, we return an error. + if confTarget == 0 { + return false, ErrMaxPosition + } + + newPosition := uint32(0) + + // Only calculate the new position when the conf target is less than + // the function's width - the width is the initial conf target, and we + // expect the current conf target to decrease over time. However, we + // still allow the supplied conf target to be greater than the width, + // and we won't increase the fee rate in that case. + if confTarget < l.width { + newPosition = l.width - confTarget + log.Tracef("Increasing position from %v to %v", l.position, + newPosition) + } + + if newPosition <= l.position { + log.Tracef("Skipped increase feerate: position=%v, "+ + "newPosition=%v ", l.position, newPosition) + + return false, nil + } + + return l.increaseFeeRate(newPosition) +} + +// increaseFeeRate increases the fee rate by the specified position, returns a +// boolean to indicate whether the fee rate was increased, and an error if the +// position is greater than the width. The increased fee rate will be set as +// the current fee rate, and the internal position will be set to the specified +// position. +// +// NOTE: this method will change the state of the fee function as it increases +// its current fee rate. +func (l *LinearFeeFunction) increaseFeeRate(position uint32) (bool, error) { + // If the new position is already at the end, we return an error. + if l.position >= l.width { + return false, ErrMaxPosition + } + + // Get the old fee rate. + oldFeeRate := l.currentFeeRate + + // Update its internal state. + l.position = position + l.currentFeeRate = l.feeRateAtPosition(position) + + log.Tracef("Fee rate increased from %v to %v at position %v", + oldFeeRate, l.currentFeeRate, l.position) + + return l.currentFeeRate > oldFeeRate, nil +} + +// feeRateAtPosition calculates the fee rate at a given position and caps it at +// the ending fee rate. +func (l *LinearFeeFunction) feeRateAtPosition(p uint32) chainfee.SatPerKWeight { + if p >= l.width { + return l.endingFeeRate + } + + feeRateDelta := btcutil.Amount(l.deltaFeeRate).MulF64(float64(p)) + + feeRate := l.startingFeeRate + chainfee.SatPerKWeight(feeRateDelta) + if feeRate > l.endingFeeRate { + return l.endingFeeRate + } + + return feeRate +} + +// estimateFeeRate asks the fee estimator to estimate the fee rate based on its +// conf target. +func (l *LinearFeeFunction) estimateFeeRate( + confTarget uint32) (chainfee.SatPerKWeight, error) { + + fee := FeeEstimateInfo{ + ConfTarget: confTarget, + } + + // endingFeeRate comes from budget/txWeight, which means the returned + // fee rate will always be capped by this value, hence we don't need to + // worry about overpay. + estimatedFeeRate, err := fee.Estimate(l.estimator, l.endingFeeRate) + if err != nil { + return 0, err + } + + return estimatedFeeRate, nil +} diff --git a/sweep/fee_function_test.go b/sweep/fee_function_test.go new file mode 100644 index 0000000000..e7f80819aa --- /dev/null +++ b/sweep/fee_function_test.go @@ -0,0 +1,247 @@ +package sweep + +import ( + "testing" + + "github.com/lightningnetwork/lnd/lnwallet/chainfee" + "github.com/stretchr/testify/require" +) + +// TestLinearFeeFunctionNew tests the NewLinearFeeFunction function. +func TestLinearFeeFunctionNew(t *testing.T) { + t.Parallel() + + rt := require.New(t) + + // Create a mock fee estimator. + estimator := &chainfee.MockEstimator{} + + // Create testing params. + maxFeeRate := chainfee.SatPerKWeight(10000) + estimatedFeeRate := chainfee.SatPerKWeight(500) + confTarget := uint32(6) + + // Assert init fee function with zero conf value returns an error. + f, err := NewLinearFeeFunction(maxFeeRate, 0, estimator) + rt.ErrorContains(err, "width must be greater than zero") + rt.Nil(f) + + // When the fee estimator returns an error, it's returned. + // + // Mock the fee estimator to return an error. + estimator.On("EstimateFeePerKW", confTarget).Return( + chainfee.SatPerKWeight(0), errDummy).Once() + + f, err = NewLinearFeeFunction(maxFeeRate, confTarget, estimator) + rt.ErrorIs(err, errDummy) + rt.Nil(f) + + // When the starting feerate is greater than the ending feerate, the + // starting feerate is capped. + // + // Mock the fee estimator to return the fee rate. + smallConf := uint32(1) + estimator.On("EstimateFeePerKW", smallConf).Return( + // The fee rate is greater than the max fee rate. + maxFeeRate+1, nil).Once() + estimator.On("RelayFeePerKW").Return(estimatedFeeRate).Once() + + f, err = NewLinearFeeFunction(maxFeeRate, smallConf, estimator) + rt.NoError(err) + rt.NotNil(f) + + // When the calculated fee rate delta is 0, an error should be returned. + // + // Mock the fee estimator to return the fee rate. + estimator.On("EstimateFeePerKW", confTarget).Return( + // The starting fee rate is 1 sat/kw less than the max fee rate. + maxFeeRate-1, nil).Once() + estimator.On("RelayFeePerKW").Return(estimatedFeeRate).Once() + + f, err = NewLinearFeeFunction(maxFeeRate, confTarget, estimator) + rt.ErrorContains(err, "fee rate delta is zero") + rt.Nil(f) + + // Check a successfully created fee function. + // + // Mock the fee estimator to return the fee rate. + estimator.On("EstimateFeePerKW", confTarget).Return( + estimatedFeeRate, nil).Once() + estimator.On("RelayFeePerKW").Return(estimatedFeeRate).Once() + + f, err = NewLinearFeeFunction(maxFeeRate, confTarget, estimator) + rt.NoError(err) + rt.NotNil(f) + + // Assert the internal state. + rt.Equal(estimatedFeeRate, f.startingFeeRate) + rt.Equal(maxFeeRate, f.endingFeeRate) + rt.Equal(estimatedFeeRate, f.currentFeeRate) + rt.NotZero(f.deltaFeeRate) + rt.Equal(confTarget, f.width) +} + +// TestLinearFeeFunctionFeeRateAtPosition checks the expected feerate is +// calculated and returned. +func TestLinearFeeFunctionFeeRateAtPosition(t *testing.T) { + t.Parallel() + + rt := require.New(t) + + // Create a fee func which has three positions: + // - position 0: 1000 + // - position 1: 2000 + // - position 2: 3000 + f := &LinearFeeFunction{ + startingFeeRate: 1000, + endingFeeRate: 3000, + position: 0, + deltaFeeRate: 1000, + width: 3, + } + + testCases := []struct { + name string + pos uint32 + expectedFeerate chainfee.SatPerKWeight + }{ + { + name: "position 0", + pos: 0, + expectedFeerate: 1000, + }, + { + name: "position 1", + pos: 1, + expectedFeerate: 2000, + }, + { + name: "position 2", + pos: 2, + expectedFeerate: 3000, + }, + { + name: "position 3", + pos: 3, + expectedFeerate: 3000, + }, + } + + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + result := f.feeRateAtPosition(tc.pos) + rt.Equal(tc.expectedFeerate, result) + }) + } +} + +// TestLinearFeeFunctionIncrement checks the internal state is updated +// correctly when the fee rate is incremented. +func TestLinearFeeFunctionIncrement(t *testing.T) { + t.Parallel() + + rt := require.New(t) + + // Create a mock fee estimator. + estimator := &chainfee.MockEstimator{} + + // Create testing params. These params are chosen so the delta value is + // 100. + maxFeeRate := chainfee.SatPerKWeight(1000) + estimatedFeeRate := chainfee.SatPerKWeight(100) + confTarget := uint32(9) + + // Mock the fee estimator to return the fee rate. + estimator.On("EstimateFeePerKW", confTarget).Return( + estimatedFeeRate, nil).Once() + estimator.On("RelayFeePerKW").Return(estimatedFeeRate).Once() + + f, err := NewLinearFeeFunction(maxFeeRate, confTarget, estimator) + rt.NoError(err) + + // We now increase the position from 1 to 9. + for i := uint32(1); i <= confTarget; i++ { + // Increase the fee rate. + increased, err := f.Increment() + rt.NoError(err) + rt.True(increased) + + // Assert the internal state. + rt.Equal(i, f.position) + + delta := chainfee.SatPerKWeight(i * 100) + rt.Equal(estimatedFeeRate+delta, f.currentFeeRate) + + // Check public method returns the expected fee rate. + rt.Equal(estimatedFeeRate+delta, f.FeeRate()) + } + + // Now the position is at 9th, increase it again should give us an + // error. + increased, err := f.Increment() + rt.ErrorIs(err, ErrMaxPosition) + rt.False(increased) +} + +// TestLinearFeeFunctionIncreaseFeeRate checks the internal state is updated +// correctly when the fee rate is increased using conf targets. +func TestLinearFeeFunctionIncreaseFeeRate(t *testing.T) { + t.Parallel() + + rt := require.New(t) + + // Create a mock fee estimator. + estimator := &chainfee.MockEstimator{} + + // Create testing params. These params are chosen so the delta value is + // 100. + maxFeeRate := chainfee.SatPerKWeight(1000) + estimatedFeeRate := chainfee.SatPerKWeight(100) + confTarget := uint32(9) + + // Mock the fee estimator to return the fee rate. + estimator.On("EstimateFeePerKW", confTarget).Return( + estimatedFeeRate, nil).Once() + estimator.On("RelayFeePerKW").Return(estimatedFeeRate).Once() + + f, err := NewLinearFeeFunction(maxFeeRate, confTarget, estimator) + rt.NoError(err) + + // If we are increasing the fee rate using the initial conf target, we + // should get a nil error and false. + increased, err := f.IncreaseFeeRate(confTarget) + rt.NoError(err) + rt.False(increased) + + // Test that we are allowed to use a larger conf target. + increased, err = f.IncreaseFeeRate(confTarget + 1) + rt.NoError(err) + rt.False(increased) + + // Test that when we use a conf target of 0, we get an error. + increased, err = f.IncreaseFeeRate(0) + rt.ErrorIs(err, ErrMaxPosition) + rt.False(increased) + + // We now increase the fee rate from conf target 8 to 1 and assert we + // get no error and true. + for i := uint32(1); i < confTarget; i++ { + // Increase the fee rate. + increased, err := f.IncreaseFeeRate(confTarget - i) + rt.NoError(err) + rt.True(increased) + + // Assert the internal state. + rt.Equal(i, f.position) + + delta := chainfee.SatPerKWeight(i * 100) + rt.Equal(estimatedFeeRate+delta, f.currentFeeRate) + + // Check public method returns the expected fee rate. + rt.Equal(estimatedFeeRate+delta, f.FeeRate()) + } +} From 447cf5c30247d5c35023c7affe740616c7c9c7f7 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Thu, 29 Feb 2024 03:07:22 +0800 Subject: [PATCH 11/19] lnwallet+sweep: add new method `CheckMempoolAcceptance` --- lnmock/chain.go | 159 +++++++++++++++++++++++++++ lntest/mock/walletcontroller.go | 4 + lnwallet/btcwallet/btcwallet.go | 31 ++++++ lnwallet/btcwallet/btcwallet_test.go | 90 +++++++++++++++ lnwallet/interface.go | 5 + lnwallet/mock.go | 4 + sweep/interface.go | 5 + sweep/mock_test.go | 12 ++ 8 files changed, 310 insertions(+) create mode 100644 lnmock/chain.go diff --git a/lnmock/chain.go b/lnmock/chain.go new file mode 100644 index 0000000000..dd208c33e2 --- /dev/null +++ b/lnmock/chain.go @@ -0,0 +1,159 @@ +package lnmock + +import ( + "github.com/btcsuite/btcd/btcjson" + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btcwallet/chain" + "github.com/btcsuite/btcwallet/waddrmgr" + "github.com/stretchr/testify/mock" +) + +// MockChain is a mock implementation of the Chain interface. +type MockChain struct { + mock.Mock +} + +// Compile-time constraint to ensure MockChain implements the Chain interface. +var _ chain.Interface = (*MockChain)(nil) + +func (m *MockChain) Start() error { + args := m.Called() + + return args.Error(0) +} + +func (m *MockChain) Stop() { + m.Called() +} + +func (m *MockChain) WaitForShutdown() { + m.Called() +} + +func (m *MockChain) GetBestBlock() (*chainhash.Hash, int32, error) { + args := m.Called() + + if args.Get(0) == nil { + return nil, args.Get(1).(int32), args.Error(2) + } + + return args.Get(0).(*chainhash.Hash), args.Get(1).(int32), args.Error(2) +} + +func (m *MockChain) GetBlock(hash *chainhash.Hash) (*wire.MsgBlock, error) { + args := m.Called(hash) + + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).(*wire.MsgBlock), args.Error(1) +} + +func (m *MockChain) GetBlockHash(height int64) (*chainhash.Hash, error) { + args := m.Called(height) + + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).(*chainhash.Hash), args.Error(1) +} + +func (m *MockChain) GetBlockHeader(hash *chainhash.Hash) ( + *wire.BlockHeader, error) { + + args := m.Called(hash) + + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).(*wire.BlockHeader), args.Error(1) +} + +func (m *MockChain) IsCurrent() bool { + args := m.Called() + + return args.Bool(0) +} + +func (m *MockChain) FilterBlocks(req *chain.FilterBlocksRequest) ( + *chain.FilterBlocksResponse, error) { + + args := m.Called(req) + + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).(*chain.FilterBlocksResponse), args.Error(1) +} + +func (m *MockChain) BlockStamp() (*waddrmgr.BlockStamp, error) { + args := m.Called() + + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).(*waddrmgr.BlockStamp), args.Error(1) +} + +func (m *MockChain) SendRawTransaction(tx *wire.MsgTx, allowHighFees bool) ( + *chainhash.Hash, error) { + + args := m.Called(tx, allowHighFees) + + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).(*chainhash.Hash), args.Error(1) +} + +func (m *MockChain) Rescan(startHash *chainhash.Hash, addrs []btcutil.Address, + outPoints map[wire.OutPoint]btcutil.Address) error { + + args := m.Called(startHash, addrs, outPoints) + + return args.Error(0) +} + +func (m *MockChain) NotifyReceived(addrs []btcutil.Address) error { + args := m.Called(addrs) + + return args.Error(0) +} + +func (m *MockChain) NotifyBlocks() error { + args := m.Called() + + return args.Error(0) +} + +func (m *MockChain) Notifications() <-chan interface{} { + args := m.Called() + + return args.Get(0).(<-chan interface{}) +} + +func (m *MockChain) BackEnd() string { + args := m.Called() + + return args.String(0) +} + +func (m *MockChain) TestMempoolAccept(txns []*wire.MsgTx, maxFeeRate float64) ( + []*btcjson.TestMempoolAcceptResult, error) { + + args := m.Called(txns, maxFeeRate) + + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).([]*btcjson.TestMempoolAcceptResult), args.Error(1) +} diff --git a/lntest/mock/walletcontroller.go b/lntest/mock/walletcontroller.go index 6d09acd54f..21d78add37 100644 --- a/lntest/mock/walletcontroller.go +++ b/lntest/mock/walletcontroller.go @@ -282,3 +282,7 @@ func (w *WalletController) FetchTx(chainhash.Hash) (*wire.MsgTx, error) { func (w *WalletController) RemoveDescendants(*wire.MsgTx) error { return nil } + +func (w *WalletController) CheckMempoolAcceptance(tx *wire.MsgTx) error { + return nil +} diff --git a/lnwallet/btcwallet/btcwallet.go b/lnwallet/btcwallet/btcwallet.go index ec4bc5d9b5..ebca031c54 100644 --- a/lnwallet/btcwallet/btcwallet.go +++ b/lnwallet/btcwallet/btcwallet.go @@ -1898,3 +1898,34 @@ func (b *BtcWallet) RemoveDescendants(tx *wire.MsgTx) error { return b.wallet.TxStore.RemoveUnminedTx(wtxmgrNs, txRecord) }) } + +// CheckMempoolAcceptance is a wrapper around `TestMempoolAccept` which checks +// the mempool acceptance of a transaction. +func (b *BtcWallet) CheckMempoolAcceptance(tx *wire.MsgTx) error { + // Use a max feerate of 0 means the default value will be used when + // testing mempool acceptance. The default max feerate is 0.10 BTC/kvb, + // or 10,000 sat/vb. + results, err := b.chain.TestMempoolAccept([]*wire.MsgTx{tx}, 0) + if err != nil { + return err + } + + // Sanity check that the expected single result is returned. + if len(results) != 1 { + return fmt.Errorf("expected 1 result from TestMempoolAccept, "+ + "instead got %v", len(results)) + } + + result := results[0] + log.Debugf("TestMempoolAccept result: %s", spew.Sdump(result)) + + // Mempool check failed, we now map the reject reason to a proper RPC + // error and return it. + if !result.Allowed { + err := rpcclient.MapRPCErr(errors.New(result.RejectReason)) + + return fmt.Errorf("mempool rejection: %w", err) + } + + return nil +} diff --git a/lnwallet/btcwallet/btcwallet_test.go b/lnwallet/btcwallet/btcwallet_test.go index 28b783acc5..892ec25fdf 100644 --- a/lnwallet/btcwallet/btcwallet_test.go +++ b/lnwallet/btcwallet/btcwallet_test.go @@ -3,8 +3,12 @@ package btcwallet import ( "testing" + "github.com/btcsuite/btcd/btcjson" + "github.com/btcsuite/btcd/rpcclient" "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btcwallet/chain" "github.com/btcsuite/btcwallet/wallet" + "github.com/lightningnetwork/lnd/lnmock" "github.com/lightningnetwork/lnd/lnwallet" "github.com/stretchr/testify/require" ) @@ -132,3 +136,89 @@ func TestPreviousOutpoints(t *testing.T) { }) } } + +// TestCheckMempoolAcceptance asserts the CheckMempoolAcceptance behaves as +// expected. +func TestCheckMempoolAcceptance(t *testing.T) { + t.Parallel() + + rt := require.New(t) + + // Create a mock chain.Interface. + mockChain := &lnmock.MockChain{} + defer mockChain.AssertExpectations(t) + + // Create a test tx and a test max feerate. + tx := wire.NewMsgTx(2) + maxFeeRate := float64(0) + + // Create a test wallet. + wallet := &BtcWallet{ + chain: mockChain, + } + + // Assert that when the chain backend doesn't support + // `TestMempoolAccept`, an error is returned. + // + // Mock the chain backend to not support `TestMempoolAccept`. + mockChain.On("TestMempoolAccept", []*wire.MsgTx{tx}, maxFeeRate).Return( + nil, rpcclient.ErrBackendVersion).Once() + + err := wallet.CheckMempoolAcceptance(tx) + rt.ErrorIs(err, rpcclient.ErrBackendVersion) + + // Assert that when the chain backend doesn't implement + // `TestMempoolAccept`, an error is returned. + // + // Mock the chain backend to not support `TestMempoolAccept`. + mockChain.On("TestMempoolAccept", []*wire.MsgTx{tx}, maxFeeRate).Return( + nil, chain.ErrUnimplemented).Once() + + // Now call the method under test. + err = wallet.CheckMempoolAcceptance(tx) + rt.ErrorIs(err, chain.ErrUnimplemented) + + // Assert that when the returned results are not as expected, an error + // is returned. + // + // Mock the chain backend to return more than one result. + results := []*btcjson.TestMempoolAcceptResult{ + {Txid: "txid1", Allowed: true}, + {Txid: "txid2", Allowed: false}, + } + mockChain.On("TestMempoolAccept", []*wire.MsgTx{tx}, maxFeeRate).Return( + results, nil).Once() + + // Now call the method under test. + err = wallet.CheckMempoolAcceptance(tx) + rt.ErrorContains(err, "expected 1 result from TestMempoolAccept") + + // Assert that when the tx is rejected, the reason is converted to an + // RPC error and returned. + // + // Mock the chain backend to return one result. + results = []*btcjson.TestMempoolAcceptResult{{ + Txid: tx.TxHash().String(), + Allowed: false, + RejectReason: "insufficient fee", + }} + mockChain.On("TestMempoolAccept", []*wire.MsgTx{tx}, maxFeeRate).Return( + results, nil).Once() + + // Now call the method under test. + err = wallet.CheckMempoolAcceptance(tx) + rt.ErrorIs(err, rpcclient.ErrInsufficientFee) + + // Assert that when the tx is accepted, no error is returned. + // + // Mock the chain backend to return one result. + results = []*btcjson.TestMempoolAcceptResult{ + {Txid: tx.TxHash().String(), Allowed: true}, + } + mockChain.On("TestMempoolAccept", []*wire.MsgTx{tx}, maxFeeRate).Return( + results, nil).Once() + + // Now call the method under test. + err = wallet.CheckMempoolAcceptance(tx) + rt.NoError(err) +} diff --git a/lnwallet/interface.go b/lnwallet/interface.go index e26f4f2910..59e6f5aab0 100644 --- a/lnwallet/interface.go +++ b/lnwallet/interface.go @@ -536,6 +536,11 @@ type WalletController interface { // which could be e.g. btcd, bitcoind, neutrino, or another consensus // service. BackEnd() string + + // CheckMempoolAcceptance checks whether a transaction follows mempool + // policies and returns an error if it cannot be accepted into the + // mempool. + CheckMempoolAcceptance(tx *wire.MsgTx) error } // BlockChainIO is a dedicated source which will be used to obtain queries diff --git a/lnwallet/mock.go b/lnwallet/mock.go index f0f257ef0c..0146df57ea 100644 --- a/lnwallet/mock.go +++ b/lnwallet/mock.go @@ -294,6 +294,10 @@ func (w *mockWalletController) RemoveDescendants(*wire.MsgTx) error { return nil } +func (w *mockWalletController) CheckMempoolAcceptance(tx *wire.MsgTx) error { + return nil +} + // mockChainNotifier is a mock implementation of the ChainNotifier interface. type mockChainNotifier struct { SpendChan chan *chainntnfs.SpendDetail diff --git a/sweep/interface.go b/sweep/interface.go index a9de8bc570..e58cc8507c 100644 --- a/sweep/interface.go +++ b/sweep/interface.go @@ -41,4 +41,9 @@ type Wallet interface { // used to ensure that invalid transactions (inputs spent) aren't // retried in the background. CancelRebroadcast(tx chainhash.Hash) + + // CheckMempoolAcceptance checks whether a transaction follows mempool + // policies and returns an error if it cannot be accepted into the + // mempool. + CheckMempoolAcceptance(tx *wire.MsgTx) error } diff --git a/sweep/mock_test.go b/sweep/mock_test.go index 270c3844eb..3688db72c3 100644 --- a/sweep/mock_test.go +++ b/sweep/mock_test.go @@ -46,6 +46,10 @@ func newMockBackend(t *testing.T, notifier *MockNotifier) *mockBackend { } } +func (b *mockBackend) CheckMempoolAcceptance(tx *wire.MsgTx) error { + return nil +} + func (b *mockBackend) publishTransaction(tx *wire.MsgTx) error { b.lock.Lock() defer b.lock.Unlock() @@ -344,6 +348,14 @@ type MockWallet struct { // Compile-time constraint to ensure MockWallet implements Wallet. var _ Wallet = (*MockWallet)(nil) +// CheckMempoolAcceptance checks if the transaction can be accepted to the +// mempool. +func (m *MockWallet) CheckMempoolAcceptance(tx *wire.MsgTx) error { + args := m.Called(tx) + + return args.Error(0) +} + // PublishTransaction performs cursory validation (dust checks, etc) and // broadcasts the passed transaction to the Bitcoin network. func (m *MockWallet) PublishTransaction(tx *wire.MsgTx, label string) error { From 6e84fe6223a3a0128609efe641a7be4a9c6c16be Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Thu, 29 Feb 2024 13:18:23 +0800 Subject: [PATCH 12/19] lnwallet+sweep: calculate max allowed feerate on `BumpResult` This commit adds the method `MaxFeeRateAllowed` to calculate the max fee rate. The caller may specify a large MaxFeeRate value, which cannot be cover by the budget. In that case, we default to use the max feerate calculated using `budget/weight`. --- lnwallet/chainfee/rates.go | 5 ++ sweep/fee_bumper.go | 57 +++++++++++++++++ sweep/fee_bumper_test.go | 121 +++++++++++++++++++++++++++++++++++++ 3 files changed, 183 insertions(+) diff --git a/lnwallet/chainfee/rates.go b/lnwallet/chainfee/rates.go index 6496b39c0d..98cefc13b5 100644 --- a/lnwallet/chainfee/rates.go +++ b/lnwallet/chainfee/rates.go @@ -58,6 +58,11 @@ func (s SatPerKVByte) String() string { // SatPerKWeight represents a fee rate in sat/kw. type SatPerKWeight btcutil.Amount +// NewSatPerKWeight creates a new fee rate in sat/kw. +func NewSatPerKWeight(fee btcutil.Amount, weight uint64) SatPerKWeight { + return SatPerKWeight(fee.MulF64(1000 / float64(weight))) +} + // FeeForWeight calculates the fee resulting from this fee rate and the given // weight in weight units (wu). func (s SatPerKWeight) FeeForWeight(wu int64) btcutil.Amount { diff --git a/sweep/fee_bumper.go b/sweep/fee_bumper.go index c406149615..b5515d10e9 100644 --- a/sweep/fee_bumper.go +++ b/sweep/fee_bumper.go @@ -88,6 +88,63 @@ type BumpRequest struct { MaxFeeRate chainfee.SatPerKWeight } +// MaxFeeRateAllowed returns the maximum fee rate allowed for the given +// request. It calculates the feerate using the supplied budget and the weight, +// compares it with the specified MaxFeeRate, and returns the smaller of the +// two. +func (r *BumpRequest) MaxFeeRateAllowed() (chainfee.SatPerKWeight, error) { + // Get the size of the sweep tx, which will be used to calculate the + // budget fee rate. + size, err := calcSweepTxWeight(r.Inputs, r.DeliveryAddress) + if err != nil { + return 0, err + } + + // Use the budget and MaxFeeRate to decide the max allowed fee rate. + // This is needed as, when the input has a large value and the user + // sets the budget to be proportional to the input value, the fee rate + // can be very high and we need to make sure it doesn't exceed the max + // fee rate. + maxFeeRateAllowed := chainfee.NewSatPerKWeight(r.Budget, size) + if maxFeeRateAllowed > r.MaxFeeRate { + log.Debugf("Budget feerate %v exceeds MaxFeeRate %v, use "+ + "MaxFeeRate instead", maxFeeRateAllowed, r.MaxFeeRate) + + return r.MaxFeeRate, nil + } + + log.Debugf("Budget feerate %v below MaxFeeRate %v, use budget feerate "+ + "instead", maxFeeRateAllowed, r.MaxFeeRate) + + return maxFeeRateAllowed, nil +} + +// calcSweepTxWeight calculates the weight of the sweep tx. It assumes a +// sweeping tx always has a single output(change). +func calcSweepTxWeight(inputs []input.Input, + outputPkScript []byte) (uint64, error) { + + // Use a const fee rate as we only use the weight estimator to + // calculate the size. + const feeRate = 1 + + // Initialize the tx weight estimator with, + // - nil outputs as we only have one single change output. + // - const fee rate as we don't care about the fees here. + // - 0 maxfeerate as we don't care about fees here. + // + // TODO(yy): we should refactor the weight estimator to not require a + // fee rate and max fee rate and make it a pure tx weight calculator. + _, estimator, err := getWeightEstimate( + inputs, nil, feeRate, 0, outputPkScript, + ) + if err != nil { + return 0, err + } + + return uint64(estimator.weight()), nil +} + // BumpResult is used by the Bumper to send updates about the tx being // broadcast. type BumpResult struct { diff --git a/sweep/fee_bumper_test.go b/sweep/fee_bumper_test.go index 22c247b2c5..099e0aacd0 100644 --- a/sweep/fee_bumper_test.go +++ b/sweep/fee_bumper_test.go @@ -3,10 +3,24 @@ package sweep import ( "testing" + "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/stretchr/testify/require" ) +var ( + // Create a taproot change script. + changePkScript = []byte{ + 0x51, 0x20, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + } +) + // TestBumpResultValidate tests the validate method of the BumpResult struct. func TestBumpResultValidate(t *testing.T) { t.Parallel() @@ -50,3 +64,110 @@ func TestBumpResultValidate(t *testing.T) { } require.NoError(t, b.Validate()) } + +// TestCalcSweepTxWeight checks that the weight of the sweep tx is calculated +// correctly. +func TestCalcSweepTxWeight(t *testing.T) { + t.Parallel() + + // Create an input. + inp := createTestInput(100, input.WitnessKeyHash) + + // Use a wrong change script to test the error case. + weight, err := calcSweepTxWeight([]input.Input{&inp}, []byte{0}) + require.Error(t, err) + require.Zero(t, weight) + + // Use a correct change script to test the success case. + weight, err = calcSweepTxWeight([]input.Input{&inp}, changePkScript) + require.NoError(t, err) + + // BaseTxSize 8 bytes + // InputSize 1+41 bytes + // One P2TROutputSize 1+43 bytes + // One P2WKHWitnessSize 2+109 bytes + // Total weight = (8+42+44) * 4 + 111 = 487 + require.EqualValuesf(t, 487, weight, "unexpected weight %v", weight) +} + +// TestBumpRequestMaxFeeRateAllowed tests the max fee rate allowed for a bump +// request. +func TestBumpRequestMaxFeeRateAllowed(t *testing.T) { + t.Parallel() + + // Create a test input. + inp := createTestInput(100, input.WitnessKeyHash) + + // The weight is 487. + weight, err := calcSweepTxWeight([]input.Input{&inp}, changePkScript) + require.NoError(t, err) + + // Define a test budget and calculates its fee rate. + budget := btcutil.Amount(1000) + budgetFeeRate := chainfee.NewSatPerKWeight(budget, weight) + + testCases := []struct { + name string + req *BumpRequest + expectedMaxFeeRate chainfee.SatPerKWeight + expectedErr bool + }{ + { + // Use a wrong change script to test the error case. + name: "error calc weight", + req: &BumpRequest{ + DeliveryAddress: []byte{1}, + }, + expectedMaxFeeRate: 0, + expectedErr: true, + }, + { + // When the budget cannot give a fee rate that matches + // the supplied MaxFeeRate, the max allowed feerate is + // capped by the budget. + name: "use budget as max fee rate", + req: &BumpRequest{ + DeliveryAddress: changePkScript, + Inputs: []input.Input{&inp}, + Budget: budget, + MaxFeeRate: budgetFeeRate + 1, + }, + expectedMaxFeeRate: budgetFeeRate, + }, + { + // When the budget can give a fee rate that matches the + // supplied MaxFeeRate, the max allowed feerate is + // capped by the MaxFeeRate. + name: "use config as max fee rate", + req: &BumpRequest{ + DeliveryAddress: changePkScript, + Inputs: []input.Input{&inp}, + Budget: budget, + MaxFeeRate: budgetFeeRate - 1, + }, + expectedMaxFeeRate: budgetFeeRate - 1, + }, + } + + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + // Check the method under test. + maxFeeRate, err := tc.req.MaxFeeRateAllowed() + + // If we expect an error, check the error is returned + // and the feerate is empty. + if tc.expectedErr { + require.Error(t, err) + require.Zero(t, maxFeeRate) + + return + } + + // Otherwise, check the max fee rate is as expected. + require.NoError(t, err) + require.Equal(t, tc.expectedMaxFeeRate, maxFeeRate) + }) + } +} From 4ca1de3f8843adba9e68914d5b00e7227f491ad4 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Thu, 29 Feb 2024 13:18:59 +0800 Subject: [PATCH 13/19] lnwallet+sweep: introduce `TxPublisher` to handle fee bump This commit adds `TxPublisher` which implements `Bumper` interface. This is part one of the implementation that focuses on implementing the `Broadcast` method which guarantees a tx can be published with RBF-compliant. It does so by leveraging the `testmempoolaccept` API, keep increasing the fee rate until an RBF-compliant tx is made and broadcasts it. This tx will then be monitored by the `TxPublisher` and in the following commit, the monitoring process will be added. --- chainntnfs/mocks.go | 71 ++++ input/mocks.go | 103 +++++ sweep/fee_bumper.go | 473 +++++++++++++++++++++ sweep/fee_bumper_test.go | 868 +++++++++++++++++++++++++++++++++++++++ sweep/mock_test.go | 29 ++ sweep/txgenerator.go | 9 +- 6 files changed, 1548 insertions(+), 5 deletions(-) diff --git a/chainntnfs/mocks.go b/chainntnfs/mocks.go index 31b75d46f2..d9ab9928d0 100644 --- a/chainntnfs/mocks.go +++ b/chainntnfs/mocks.go @@ -1,6 +1,7 @@ package chainntnfs import ( + "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/fn" "github.com/stretchr/testify/mock" @@ -50,3 +51,73 @@ func (m *MockMempoolWatcher) LookupInputMempoolSpend( return args.Get(0).(fn.Option[wire.MsgTx]) } + +// MockNotifier is a mock implementation of the ChainNotifier interface. +type MockChainNotifier struct { + mock.Mock +} + +// Compile-time check to ensure MockChainNotifier implements ChainNotifier. +var _ ChainNotifier = (*MockChainNotifier)(nil) + +// RegisterConfirmationsNtfn registers an intent to be notified once txid +// reaches numConfs confirmations. +func (m *MockChainNotifier) RegisterConfirmationsNtfn(txid *chainhash.Hash, + pkScript []byte, numConfs, heightHint uint32, + opts ...NotifierOption) (*ConfirmationEvent, error) { + + args := m.Called(txid, pkScript, numConfs, heightHint) + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).(*ConfirmationEvent), args.Error(1) +} + +// RegisterSpendNtfn registers an intent to be notified once the target +// outpoint is successfully spent within a transaction. +func (m *MockChainNotifier) RegisterSpendNtfn(outpoint *wire.OutPoint, + pkScript []byte, heightHint uint32) (*SpendEvent, error) { + + args := m.Called(outpoint, pkScript, heightHint) + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).(*SpendEvent), args.Error(1) +} + +// RegisterBlockEpochNtfn registers an intent to be notified of each new block +// connected to the tip of the main chain. +func (m *MockChainNotifier) RegisterBlockEpochNtfn(epoch *BlockEpoch) ( + *BlockEpochEvent, error) { + + args := m.Called(epoch) + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).(*BlockEpochEvent), args.Error(1) +} + +// Start the ChainNotifier. Once started, the implementation should be ready, +// and able to receive notification registrations from clients. +func (m *MockChainNotifier) Start() error { + args := m.Called() + + return args.Error(0) +} + +// Started returns true if this instance has been started, and false otherwise. +func (m *MockChainNotifier) Started() bool { + args := m.Called() + + return args.Bool(0) +} + +// Stops the concrete ChainNotifier. +func (m *MockChainNotifier) Stop() error { + args := m.Called() + + return args.Error(0) +} diff --git a/input/mocks.go b/input/mocks.go index 1fe6eb7656..23ce6930ec 100644 --- a/input/mocks.go +++ b/input/mocks.go @@ -1,8 +1,14 @@ package input import ( + "crypto/sha256" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcec/v2/schnorr" + "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/keychain" "github.com/stretchr/testify/mock" ) @@ -168,3 +174,100 @@ func (m *MockWitnessType) AddWeightEstimation(e *TxWeightEstimator) error { return args.Error(0) } + +// MockInputSigner is a mock implementation of the Signer interface. +type MockInputSigner struct { + mock.Mock +} + +// Compile-time constraint to ensure MockInputSigner implements Signer. +var _ Signer = (*MockInputSigner)(nil) + +// SignOutputRaw generates a signature for the passed transaction according to +// the data within the passed SignDescriptor. +func (m *MockInputSigner) SignOutputRaw(tx *wire.MsgTx, + signDesc *SignDescriptor) (Signature, error) { + + args := m.Called(tx, signDesc) + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).(Signature), args.Error(1) +} + +// ComputeInputScript generates a complete InputIndex for the passed +// transaction with the signature as defined within the passed SignDescriptor. +func (m *MockInputSigner) ComputeInputScript(tx *wire.MsgTx, + signDesc *SignDescriptor) (*Script, error) { + + args := m.Called(tx, signDesc) + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).(*Script), args.Error(1) +} + +// MuSig2CreateSession creates a new MuSig2 signing session using the local key +// identified by the key locator. +func (m *MockInputSigner) MuSig2CreateSession(version MuSig2Version, + locator keychain.KeyLocator, pubkey []*btcec.PublicKey, + tweak *MuSig2Tweaks, pubNonces [][musig2.PubNonceSize]byte, + nonces *musig2.Nonces) (*MuSig2SessionInfo, error) { + + args := m.Called(version, locator, pubkey, tweak, pubNonces, nonces) + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).(*MuSig2SessionInfo), args.Error(1) +} + +// MuSig2RegisterNonces registers one or more public nonces of other signing +// participants for a session identified by its ID. +func (m *MockInputSigner) MuSig2RegisterNonces(versio MuSig2SessionID, + pubNonces [][musig2.PubNonceSize]byte) (bool, error) { + + args := m.Called(versio, pubNonces) + if args.Get(0) == nil { + return false, args.Error(1) + } + + return args.Bool(0), args.Error(1) +} + +// MuSig2Sign creates a partial signature using the local signing key that was +// specified when the session was created. +func (m *MockInputSigner) MuSig2Sign(sessionID MuSig2SessionID, + msg [sha256.Size]byte, withSortedKeys bool) ( + *musig2.PartialSignature, error) { + + args := m.Called(sessionID, msg, withSortedKeys) + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).(*musig2.PartialSignature), args.Error(1) +} + +// MuSig2CombineSig combines the given partial signature(s) with the local one, +// if it already exists. +func (m *MockInputSigner) MuSig2CombineSig(sessionID MuSig2SessionID, + partialSig []*musig2.PartialSignature) ( + *schnorr.Signature, bool, error) { + + args := m.Called(sessionID, partialSig) + if args.Get(0) == nil { + return nil, false, args.Error(2) + } + + return args.Get(0).(*schnorr.Signature), args.Bool(1), args.Error(2) +} + +// MuSig2Cleanup removes a session from memory to free up resources. +func (m *MockInputSigner) MuSig2Cleanup(sessionID MuSig2SessionID) error { + args := m.Called(sessionID) + + return args.Error(0) +} diff --git a/sweep/fee_bumper.go b/sweep/fee_bumper.go index b5515d10e9..e0db7eb18f 100644 --- a/sweep/fee_bumper.go +++ b/sweep/fee_bumper.go @@ -3,16 +3,29 @@ package sweep import ( "errors" "fmt" + "sync" + "sync/atomic" "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/rpcclient" "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btcwallet/chain" + "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/labels" + "github.com/lightningnetwork/lnd/lnutils" + "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" ) var ( // ErrInvalidBumpResult is returned when the bump result is invalid. ErrInvalidBumpResult = errors.New("invalid bump result") + + // ErrNotEnoughBudget is returned when the fee bumper decides the + // current budget cannot cover the fee. + ErrNotEnoughBudget = errors.New("not enough budget") ) // Bumper defines an interface that can be used by other subsystems for fee @@ -165,6 +178,9 @@ type BumpResult struct { // Err is the error that occurred during the broadcast. Err error + + // requestID is the ID of the request that created this record. + requestID uint64 } // Validate validates the BumpResult so it's safe to use. @@ -197,3 +213,460 @@ func (b *BumpResult) Validate() error { return nil } + +// TxPublisherConfig is the config used to create a new TxPublisher. +type TxPublisherConfig struct { + // Signer is used to create the tx signature. + Signer input.Signer + + // Wallet is used primarily to publish the tx. + Wallet Wallet + + // Estimator is used to estimate the fee rate for the new tx based on + // its deadline conf target. + Estimator chainfee.Estimator + + // Notifier is used to monitor the confirmation status of the tx. + Notifier chainntnfs.ChainNotifier +} + +// TxPublisher is an implementation of the Bumper interface. It utilizes the +// `testmempoolaccept` RPC to bump the fee of txns it created based on +// different fee function selected or configed by the caller. Its purpose is to +// take a list of inputs specified, and create a tx that spends them to a +// specified output. It will then monitor the confirmation status of the tx, +// and if it's not confirmed within a certain time frame, it will attempt to +// bump the fee of the tx by creating a new tx that spends the same inputs to +// the same output, but with a higher fee rate. It will continue to do this +// until the tx is confirmed or the fee rate reaches the maximum fee rate +// specified by the caller. +type TxPublisher struct { + wg sync.WaitGroup + + // cfg specifies the configuration of the TxPublisher. + cfg *TxPublisherConfig + + // currentHeight is the current block height. + currentHeight int32 + + // records is a map keyed by the requestCounter and the value is the tx + // being monitored. + records lnutils.SyncMap[uint64, *monitorRecord] + + // requestCounter is a monotonically increasing counter used to keep + // track of how many requests have been made. + requestCounter atomic.Uint64 + + // subscriberChans is a map keyed by the requestCounter, each item is + // the chan that the publisher sends the fee bump result to. + subscriberChans lnutils.SyncMap[uint64, chan *BumpResult] + + // quit is used to signal the publisher to stop. + quit chan struct{} +} + +// Compile-time constraint to ensure TxPublisher implements Bumper. +var _ Bumper = (*TxPublisher)(nil) + +// NewTxPublisher creates a new TxPublisher. +func NewTxPublisher(cfg TxPublisherConfig) *TxPublisher { + return &TxPublisher{ + cfg: &cfg, + records: lnutils.SyncMap[uint64, *monitorRecord]{}, + subscriberChans: lnutils.SyncMap[uint64, chan *BumpResult]{}, + quit: make(chan struct{}), + } +} + +// Broadcast is used to publish the tx created from the given inputs. It will, +// 1. init a fee function based on the given strategy. +// 2. create an RBF-compliant tx and monitor it for confirmation. +// 3. notify the initial broadcast result back to the caller. +// The initial broadcast is guaranteed to be RBF-compliant unless the budget +// specified cannot cover the fee. +// +// NOTE: part of the Bumper interface. +func (t *TxPublisher) Broadcast(req *BumpRequest) (<-chan *BumpResult, error) { + log.Tracef("Received broadcast request: %s", newLogClosure( + func() string { + return spew.Sdump(req) + })()) + + // Attempt an initial broadcast which is guaranteed to comply with the + // RBF rules. + result, err := t.initialBroadcast(req) + if err != nil { + log.Errorf("Initial broadcast failed: %v", err) + + return nil, err + } + + // Create a chan to send the result to the caller. + subscriber := make(chan *BumpResult, 1) + t.subscriberChans.Store(result.requestID, subscriber) + + // Send the initial broadcast result to the caller. + t.handleResult(result) + + return subscriber, nil +} + +// initialBroadcast initializes a fee function, creates an RBF-compliant tx and +// broadcasts it. +func (t *TxPublisher) initialBroadcast(req *BumpRequest) (*BumpResult, error) { + // Create a fee bumping algorithm to be used for future RBF. + feeAlgo, err := t.initializeFeeFunction(req) + if err != nil { + return nil, fmt.Errorf("init fee function: %w", err) + } + + // Create the initial tx to be broadcasted. This tx is guaranteed to + // comply with the RBF restrictions. + requestID, err := t.createRBFCompliantTx(req, feeAlgo) + if err != nil { + return nil, fmt.Errorf("create RBF-compliant tx: %w", err) + } + + // Broadcast the tx and return the monitored record. + result, err := t.broadcast(requestID) + if err != nil { + return nil, fmt.Errorf("broadcast sweep tx: %w", err) + } + + return result, nil +} + +// initializeFeeFunction initializes a fee function to be used for this request +// for future fee bumping. +func (t *TxPublisher) initializeFeeFunction( + req *BumpRequest) (FeeFunction, error) { + + // Get the max allowed feerate. + maxFeeRateAllowed, err := req.MaxFeeRateAllowed() + if err != nil { + return nil, err + } + + // Get the initial conf target. + confTarget := calcCurrentConfTarget(t.currentHeight, req.DeadlineHeight) + + // Initialize the fee function and return it. + // + // TODO(yy): return based on differet req.Strategy? + return NewLinearFeeFunction( + maxFeeRateAllowed, confTarget, t.cfg.Estimator, + ) +} + +// createRBFCompliantTx creates a tx that is compliant with RBF rules. It does +// so by creating a tx, validate it using `TestMempoolAccept`, and bump its fee +// and redo the process until the tx is valid, or return an error when non-RBF +// related errors occur or the budget has been used up. +func (t *TxPublisher) createRBFCompliantTx(req *BumpRequest, + f FeeFunction) (uint64, error) { + + for { + // Create a new tx with the given fee rate and check its + // mempool acceptance. + tx, fee, err := t.createAndCheckTx(req, f) + + switch { + case err == nil: + // The tx is valid, return the request ID. + requestID := t.storeRecord(tx, req, f, fee) + + log.Infof("Created tx %v for %v inputs: feerate=%v, "+ + "fee=%v, inputs=%v", tx.TxHash(), + len(req.Inputs), f.FeeRate(), fee, + inputTypeSummary(req.Inputs)) + + return requestID, nil + + // If the error indicates the fees paid is not enough, we will + // ask the fee function to increase the fee rate and retry. + case errors.Is(err, lnwallet.ErrMempoolFee): + // We should at least start with a feerate above the + // mempool min feerate, so if we get this error, it + // means something is wrong earlier in the pipeline. + log.Errorf("Current fee=%v, feerate=%v, %v", fee, + f.FeeRate(), err) + + fallthrough + + // We are not paying enough fees so we increase it. + case errors.Is(err, rpcclient.ErrInsufficientFee): + increased := false + + // Keep calling the fee function until the fee rate is + // increased or maxed out. + for !increased { + log.Debugf("Increasing fee for next round, "+ + "current fee=%v, feerate=%v", fee, + f.FeeRate()) + + // If the fee function tells us that we have + // used up the budget, we will return an error + // indicating this tx cannot be made. The + // sweeper should handle this error and try to + // cluster these inputs differetly. + increased, err = f.Increment() + if err != nil { + return 0, err + } + } + + // TODO(yy): suppose there's only one bad input, we can do a + // binary search to find out which input is causing this error + // by recreating a tx using half of the inputs and check its + // mempool acceptance. + default: + log.Debugf("Failed to create RBF-compliant tx: %v", err) + return 0, err + } + } +} + +// storeRecord stores the given record in the records map. +func (t *TxPublisher) storeRecord(tx *wire.MsgTx, req *BumpRequest, + f FeeFunction, fee btcutil.Amount) uint64 { + + // Increase the request counter. + // + // NOTE: this is the only place where we increase the + // counter. + requestID := t.requestCounter.Add(1) + + // Register the record. + t.records.Store(requestID, &monitorRecord{ + tx: tx, + req: req, + feeFunction: f, + fee: fee, + }) + + return requestID +} + +// createAndCheckTx creates a tx based on the given inputs, change output +// script, and the fee rate. In addition, it validates the tx's mempool +// acceptance before returning a tx that can be published directly, along with +// its fee. +func (t *TxPublisher) createAndCheckTx(req *BumpRequest, f FeeFunction) ( + *wire.MsgTx, btcutil.Amount, error) { + + // Create the sweep tx with max fee rate of 0 as the fee function + // guarantees the fee rate used here won't exceed the max fee rate. + // + // TODO(yy): refactor this function to not require a max fee rate. + tx, fee, err := createSweepTx( + req.Inputs, nil, req.DeliveryAddress, uint32(t.currentHeight), + f.FeeRate(), 0, t.cfg.Signer, + ) + if err != nil { + return nil, 0, fmt.Errorf("create sweep tx: %w", err) + } + + // Sanity check the budget still covers the fee. + if fee > req.Budget { + return nil, 0, fmt.Errorf("%w: budget=%v, fee=%v", + ErrNotEnoughBudget, req.Budget, fee) + } + + // Validate the tx's mempool acceptance. + err = t.cfg.Wallet.CheckMempoolAcceptance(tx) + + // Exit early if the tx is valid. + if err == nil { + return tx, fee, nil + } + + // Print an error log if the chain backend doesn't support the mempool + // acceptance test RPC. + if errors.Is(err, rpcclient.ErrBackendVersion) { + log.Errorf("TestMempoolAccept not supported by backend, " + + "consider upgrading it to a newer version") + return tx, fee, nil + } + + // We are running on a backend that doesn't implement the RPC + // testmempoolaccept, eg, neutrino, so we'll skip the check. + if errors.Is(err, chain.ErrUnimplemented) { + log.Debug("Skipped testmempoolaccept due to not implemented") + return tx, fee, nil + } + + return nil, 0, err +} + +// broadcast takes a monitored tx and publishes it to the network. Prior to the +// broadcast, it will subscribe the tx's confirmation notification and attach +// the event channel to the record. Any broadcast-related errors will not be +// returned here, instead, they will be put inside the `BumpResult` and +// returned to the caller. +func (t *TxPublisher) broadcast(requestID uint64) (*BumpResult, error) { + // Get the record being monitored. + record, ok := t.records.Load(requestID) + if !ok { + return nil, fmt.Errorf("tx record %v not found", requestID) + } + + txid := record.tx.TxHash() + + // Subscribe to its confirmation notification. + confEvent, err := t.cfg.Notifier.RegisterConfirmationsNtfn( + &txid, nil, 1, uint32(t.currentHeight), + ) + if err != nil { + return nil, fmt.Errorf("register confirmation ntfn: %w", err) + } + + // Attach the confirmation event channel to the record. + record.confEvent = confEvent + + tx := record.tx + log.Debugf("Publishing sweep tx %v, num_inputs=%v, height=%v", + txid, len(tx.TxIn), t.currentHeight) + + // Set the event, and change it to TxFailed if the wallet fails to + // publish it. + event := TxPublished + + // Publish the sweeping tx with customized label. If the publish fails, + // this error will be saved in the `BumpResult` and it will be removed + // from being monitored. + err = t.cfg.Wallet.PublishTransaction( + tx, labels.MakeLabel(labels.LabelTypeSweepTransaction, nil), + ) + if err != nil { + // NOTE: we decide to attach this error to the result instead + // of returning it here because by the time the tx reaches + // here, it should have passed the mempool acceptance check. If + // it still fails to be broadcast, it's likely a non-RBF + // related error happened. So we send this error back to the + // caller so that it can handle it properly. + // + // TODO(yy): find out which input is causing the failure. + log.Errorf("Failed to publish tx %v: %v", txid, err) + event = TxFailed + } + + result := &BumpResult{ + Event: event, + Tx: record.tx, + Fee: record.fee, + FeeRate: record.feeFunction.FeeRate(), + Err: err, + requestID: requestID, + } + + return result, nil +} + +// notifyResult sends the result to the resultChan specified by the requestID. +// This channel is expected to be read by the caller. +func (t *TxPublisher) notifyResult(result *BumpResult) { + id := result.requestID + subscriber, ok := t.subscriberChans.Load(id) + if !ok { + log.Errorf("Result chan for id=%v not found", id) + return + } + + log.Debugf("Sending result for requestID=%v, tx=%v", id, + result.Tx.TxHash()) + + select { + // Send the result to the subscriber. + // + // TODO(yy): Add timeout in case it's blocking? + case subscriber <- result: + case <-t.quit: + log.Debug("Fee bumper stopped") + } +} + +// removeResult removes the tracking of the result if the result contains a +// non-nil error, or the tx is confirmed, the record will be removed from the +// maps. +func (t *TxPublisher) removeResult(result *BumpResult) { + id := result.requestID + + // Remove the record from the maps if there's an error. This means this + // tx has failed its broadcast and cannot be retried. There are two + // cases, + // - when the budget cannot cover the fee. + // - when a non-RBF related error occurs. + switch result.Event { + case TxFailed: + log.Errorf("Removing monitor record=%v, tx=%v, due to err: %v", + id, result.Tx.TxHash(), result.Err) + + case TxConfirmed: + // Remove the record is the tx is confirmed. + log.Debugf("Removing confirmed monitor record=%v, tx=%v", id, + result.Tx.TxHash()) + + // Do nothing if it's neither failed or confirmed. + default: + log.Tracef("Skipping record removal for id=%v, event=%v", id, + result.Event) + + return + } + + t.records.Delete(id) + t.subscriberChans.Delete(id) +} + +// handleResult handles the result of a tx broadcast. It will notify the +// subscriber and remove the record if the tx is confirmed or failed to be +// broadcast. +func (t *TxPublisher) handleResult(result *BumpResult) { + // Notify the subscriber. + t.notifyResult(result) + + // Remove the record if it's failed or confirmed. + t.removeResult(result) +} + +// monitorRecord is used to keep track of the tx being monitored by the +// publisher internally. +type monitorRecord struct { + // tx is the tx being monitored. + tx *wire.MsgTx + + // req is the original request. + req *BumpRequest + + // confEvent is the subscription to the confirmation event of the tx. + confEvent *chainntnfs.ConfirmationEvent + + // feeFunction is the fee bumping algorithm used by the publisher. + feeFunction FeeFunction + + // fee is the fee paid by the tx. + fee btcutil.Amount +} + +// calcCurrentConfTarget calculates the current confirmation target based on +// the deadline height. The conf target is capped at 0 if the deadline has +// already been past. +func calcCurrentConfTarget(currentHeight, deadline int32) uint32 { + var confTarget uint32 + + // Calculate how many blocks left until the deadline. + deadlineDelta := deadline - currentHeight + + // If we are already past the deadline, we will set the conf target to + // be 1. + if deadlineDelta <= 0 { + log.Warnf("Deadline is %d blocks behind current height %v", + -deadlineDelta, currentHeight) + + confTarget = 1 + } else { + confTarget = uint32(deadlineDelta) + } + + return confTarget +} diff --git a/sweep/fee_bumper_test.go b/sweep/fee_bumper_test.go index 099e0aacd0..308a69a575 100644 --- a/sweep/fee_bumper_test.go +++ b/sweep/fee_bumper_test.go @@ -1,12 +1,18 @@ package sweep import ( + "fmt" "testing" + "time" "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/rpcclient" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) @@ -171,3 +177,865 @@ func TestBumpRequestMaxFeeRateAllowed(t *testing.T) { }) } } + +// TestCalcCurrentConfTarget checks that the current confirmation target is +// calculated correctly. +func TestCalcCurrentConfTarget(t *testing.T) { + t.Parallel() + + // When the current block height is 100 and deadline height is 200, the + // conf target should be 100. + conf := calcCurrentConfTarget(int32(100), int32(200)) + require.EqualValues(t, 100, conf) + + // When the current block height is 200 and deadline height is 100, the + // conf target should be 1 since the deadline has passed. + conf = calcCurrentConfTarget(int32(200), int32(100)) + require.EqualValues(t, 1, conf) +} + +// TestInitializeFeeFunction tests the initialization of the fee function. +func TestInitializeFeeFunction(t *testing.T) { + t.Parallel() + + // Create a test input. + inp := createTestInput(100, input.WitnessKeyHash) + + // Create a mock fee estimator. + estimator := &chainfee.MockEstimator{} + defer estimator.AssertExpectations(t) + + // Create a publisher using the mocks. + tp := NewTxPublisher(TxPublisherConfig{ + Estimator: estimator, + }) + + // Create a test feerate. + feerate := chainfee.SatPerKWeight(1000) + + // Create a testing bump request. + req := &BumpRequest{ + DeliveryAddress: changePkScript, + Inputs: []input.Input{&inp}, + Budget: btcutil.Amount(1000), + MaxFeeRate: feerate, + } + + // Mock the fee estimator to return an error. + // + // We are not testing `NewLinearFeeFunction` here, so the actual params + // used are irrelevant. + dummyErr := fmt.Errorf("dummy error") + estimator.On("EstimateFeePerKW", mock.Anything).Return( + chainfee.SatPerKWeight(0), dummyErr).Once() + + // Call the method under test and assert the error is returned. + f, err := tp.initializeFeeFunction(req) + require.ErrorIs(t, err, dummyErr) + require.Nil(t, f) + + // Mock the fee estimator to return the testing fee rate. + // + // We are not testing `NewLinearFeeFunction` here, so the actual params + // used are irrelevant. + estimator.On("EstimateFeePerKW", mock.Anything).Return( + feerate, nil).Once() + estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Once() + + // Call the method under test. + f, err = tp.initializeFeeFunction(req) + require.NoError(t, err) + require.Equal(t, feerate, f.FeeRate()) +} + +// TestStoreRecord correctly increases the request counter and saves the +// record. +func TestStoreRecord(t *testing.T) { + t.Parallel() + + // Create a test input. + inp := createTestInput(1000, input.WitnessKeyHash) + + // Create a bump request. + req := &BumpRequest{ + DeliveryAddress: changePkScript, + Inputs: []input.Input{&inp}, + Budget: btcutil.Amount(1000), + } + + // Create a naive fee function. + feeFunc := &LinearFeeFunction{} + + // Create a test fee and tx. + fee := btcutil.Amount(1000) + tx := &wire.MsgTx{} + + // Create a publisher using the mocks. + tp := NewTxPublisher(TxPublisherConfig{}) + + // Get the current counter and check it's increased later. + initialCounter := tp.requestCounter.Load() + + // Call the method under test. + requestID := tp.storeRecord(tx, req, feeFunc, fee) + + // Check the request ID is as expected. + require.Equal(t, initialCounter+1, requestID) + + // Read the saved record and compare. + record, ok := tp.records.Load(requestID) + require.True(t, ok) + require.Equal(t, tx, record.tx) + require.Equal(t, feeFunc, record.feeFunction) + require.Equal(t, fee, record.fee) + require.Equal(t, req, record.req) +} + +// mockers wraps a list of mocked interfaces used inside tx publisher. +type mockers struct { + signer *input.MockInputSigner + wallet *MockWallet + estimator *chainfee.MockEstimator + notifier *chainntnfs.MockChainNotifier + + feeFunc *MockFeeFunction +} + +// createTestPublisher creates a new tx publisher using the provided mockers. +func createTestPublisher(t *testing.T) (*TxPublisher, *mockers) { + // Create a mock fee estimator. + estimator := &chainfee.MockEstimator{} + + // Create a mock fee function. + feeFunc := &MockFeeFunction{} + + // Create a mock signer. + signer := &input.MockInputSigner{} + + // Create a mock wallet. + wallet := &MockWallet{} + + // Create a mock chain notifier. + notifier := &chainntnfs.MockChainNotifier{} + + t.Cleanup(func() { + estimator.AssertExpectations(t) + feeFunc.AssertExpectations(t) + signer.AssertExpectations(t) + wallet.AssertExpectations(t) + notifier.AssertExpectations(t) + }) + + m := &mockers{ + signer: signer, + wallet: wallet, + estimator: estimator, + notifier: notifier, + feeFunc: feeFunc, + } + + // Create a publisher using the mocks. + tp := NewTxPublisher(TxPublisherConfig{ + Estimator: m.estimator, + Signer: m.signer, + Wallet: m.wallet, + Notifier: m.notifier, + }) + + return tp, m +} + +// TestCreateAndCheckTx checks `createAndCheckTx` behaves as expected. +func TestCreateAndCheckTx(t *testing.T) { + t.Parallel() + + // Create a test request. + inp := createTestInput(1000, input.WitnessKeyHash) + + // Create a publisher using the mocks. + tp, m := createTestPublisher(t) + + // Create a test feerate and return it from the mock fee function. + feerate := chainfee.SatPerKWeight(1000) + m.feeFunc.On("FeeRate").Return(feerate) + + // Mock the wallet to fail on testmempoolaccept on the first call, and + // succeed on the second. + m.wallet.On("CheckMempoolAcceptance", + mock.Anything).Return(errDummy).Once() + m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once() + + // Mock the signer to always return a valid script. + // + // NOTE: we are not testing the utility of creating valid txes here, so + // this is fine to be mocked. This behaves essentially as skipping the + // Signer check and alaways assume the tx has a valid sig. + script := &input.Script{} + m.signer.On("ComputeInputScript", mock.Anything, + mock.Anything).Return(script, nil) + + testCases := []struct { + name string + req *BumpRequest + expectedErr error + }{ + { + // When the budget cannot cover the fee, an error + // should be returned. + name: "not enough budget", + req: &BumpRequest{ + DeliveryAddress: changePkScript, + Inputs: []input.Input{&inp}, + }, + expectedErr: ErrNotEnoughBudget, + }, + { + // When the mempool rejects the transaction, an error + // should be returned. + name: "testmempoolaccept fail", + req: &BumpRequest{ + DeliveryAddress: changePkScript, + Inputs: []input.Input{&inp}, + Budget: btcutil.Amount(1000), + }, + expectedErr: errDummy, + }, + { + // When the mempool accepts the transaction, no error + // should be returned. + name: "testmempoolaccept pass", + req: &BumpRequest{ + DeliveryAddress: changePkScript, + Inputs: []input.Input{&inp}, + Budget: btcutil.Amount(1000), + }, + expectedErr: nil, + }, + } + + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + // Call the method under test. + _, _, err := tp.createAndCheckTx(tc.req, m.feeFunc) + + // Check the result is as expected. + require.ErrorIs(t, err, tc.expectedErr) + }) + } +} + +// createTestBumpRequest creates a new bump request. +func createTestBumpRequest() *BumpRequest { + // Create a test input. + inp := createTestInput(1000, input.WitnessKeyHash) + + return &BumpRequest{ + DeliveryAddress: changePkScript, + Inputs: []input.Input{&inp}, + Budget: btcutil.Amount(1000), + } +} + +// TestCreateRBFCompliantTx checks that `createRBFCompliantTx` behaves as +// expected. +func TestCreateRBFCompliantTx(t *testing.T) { + t.Parallel() + + // Create a publisher using the mocks. + tp, m := createTestPublisher(t) + + // Create a test bump request. + req := createTestBumpRequest() + + // Create a test feerate and return it from the mock fee function. + feerate := chainfee.SatPerKWeight(1000) + m.feeFunc.On("FeeRate").Return(feerate) + + // Mock the signer to always return a valid script. + // + // NOTE: we are not testing the utility of creating valid txes here, so + // this is fine to be mocked. This behaves essentially as skipping the + // Signer check and alaways assume the tx has a valid sig. + script := &input.Script{} + m.signer.On("ComputeInputScript", mock.Anything, + mock.Anything).Return(script, nil) + + testCases := []struct { + name string + setupMock func() + expectedErr error + }{ + { + // When testmempoolaccept accepts the tx, no error + // should be returned. + name: "success case", + setupMock: func() { + // Mock the testmempoolaccept to pass. + m.wallet.On("CheckMempoolAcceptance", + mock.Anything).Return(nil).Once() + }, + expectedErr: nil, + }, + { + // When testmempoolaccept fails due to a non-fee + // related error, an error should be returned. + name: "non-fee related testmempoolaccept fail", + setupMock: func() { + // Mock the testmempoolaccept to fail. + m.wallet.On("CheckMempoolAcceptance", + mock.Anything).Return(errDummy).Once() + }, + expectedErr: errDummy, + }, + { + // When increase feerate gives an error, the error + // should be returned. + name: "fail on increase fee", + setupMock: func() { + // Mock the testmempoolaccept to fail on fee. + m.wallet.On("CheckMempoolAcceptance", + mock.Anything).Return( + lnwallet.ErrMempoolFee).Once() + + // Mock the fee function to return an error. + m.feeFunc.On("Increment").Return( + false, errDummy).Once() + }, + expectedErr: errDummy, + }, + { + // Test that after one round of increasing the feerate + // the tx passes testmempoolaccept. + name: "increase fee and success on min mempool fee", + setupMock: func() { + // Mock the testmempoolaccept to fail on fee + // for the first call. + m.wallet.On("CheckMempoolAcceptance", + mock.Anything).Return( + lnwallet.ErrMempoolFee).Once() + + // Mock the fee function to increase feerate. + m.feeFunc.On("Increment").Return( + true, nil).Once() + + // Mock the testmempoolaccept to pass on the + // second call. + m.wallet.On("CheckMempoolAcceptance", + mock.Anything).Return(nil).Once() + }, + expectedErr: nil, + }, + { + // Test that after one round of increasing the feerate + // the tx passes testmempoolaccept. + name: "increase fee and success on insufficient fee", + setupMock: func() { + // Mock the testmempoolaccept to fail on fee + // for the first call. + m.wallet.On("CheckMempoolAcceptance", + mock.Anything).Return( + rpcclient.ErrInsufficientFee).Once() + + // Mock the fee function to increase feerate. + m.feeFunc.On("Increment").Return( + true, nil).Once() + + // Mock the testmempoolaccept to pass on the + // second call. + m.wallet.On("CheckMempoolAcceptance", + mock.Anything).Return(nil).Once() + }, + expectedErr: nil, + }, + { + // Test that the fee function increases the fee rate + // after one round. + name: "increase fee on second round", + setupMock: func() { + // Mock the testmempoolaccept to fail on fee + // for the first call. + m.wallet.On("CheckMempoolAcceptance", + mock.Anything).Return( + rpcclient.ErrInsufficientFee).Once() + + // Mock the fee function to NOT increase + // feerate on the first round. + m.feeFunc.On("Increment").Return( + false, nil).Once() + + // Mock the fee function to increase feerate. + m.feeFunc.On("Increment").Return( + true, nil).Once() + + // Mock the testmempoolaccept to pass on the + // second call. + m.wallet.On("CheckMempoolAcceptance", + mock.Anything).Return(nil).Once() + }, + expectedErr: nil, + }, + } + + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + tc.setupMock() + + // Call the method under test. + id, err := tp.createRBFCompliantTx(req, m.feeFunc) + + // Check the result is as expected. + require.ErrorIs(t, err, tc.expectedErr) + + // If there's an error, expect the requestID to be + // empty. + if tc.expectedErr != nil { + require.Zero(t, id) + } + }) + } +} + +// TestTxPublisherBroadcast checks the internal `broadcast` method behaves as +// expected. +func TestTxPublisherBroadcast(t *testing.T) { + t.Parallel() + + // Create a publisher using the mocks. + tp, m := createTestPublisher(t) + + // Create a test bump request. + req := createTestBumpRequest() + + // Create a test tx. + tx := &wire.MsgTx{LockTime: 1} + txid := tx.TxHash() + + // Create a test feerate and return it from the mock fee function. + feerate := chainfee.SatPerKWeight(1000) + m.feeFunc.On("FeeRate").Return(feerate) + + // Create a test conf event. + confEvent := &chainntnfs.ConfirmationEvent{} + + // Create a testing record and put it in the map. + fee := btcutil.Amount(1000) + requestID := tp.storeRecord(tx, req, m.feeFunc, fee) + + // Quickly check when the requestID cannot be found, an error is + // returned. + result, err := tp.broadcast(uint64(1000)) + require.Error(t, err) + require.Nil(t, result) + + // Define params to be used in RegisterConfirmationsNtfn. Not important + // for this test. + var pkScript []byte + confs := uint32(1) + height := uint32(tp.currentHeight) + + testCases := []struct { + name string + setupMock func() + expectedErr error + expectedResult *BumpResult + }{ + { + // When the notifier cannot register this spend, an + // error should be returned + name: "fail to register nftn", + setupMock: func() { + // Mock the RegisterConfirmationsNtfn to fail. + m.notifier.On("RegisterConfirmationsNtfn", + &txid, pkScript, confs, height).Return( + nil, errDummy).Once() + }, + expectedErr: errDummy, + expectedResult: nil, + }, + { + // When the wallet cannot publish this tx, the error + // should be put inside the result. + name: "fail to publish", + setupMock: func() { + // Mock the RegisterConfirmationsNtfn to pass. + m.notifier.On("RegisterConfirmationsNtfn", + &txid, pkScript, confs, height).Return( + confEvent, nil).Once() + + // Mock the wallet to fail to publish. + m.wallet.On("PublishTransaction", + tx, mock.Anything).Return( + errDummy).Once() + }, + expectedErr: nil, + expectedResult: &BumpResult{ + Event: TxFailed, + Tx: tx, + Fee: fee, + FeeRate: feerate, + Err: errDummy, + requestID: requestID, + }, + }, + { + // When nothing goes wrong, the result is returned. + name: "publish success", + setupMock: func() { + // Mock the RegisterConfirmationsNtfn to pass. + m.notifier.On("RegisterConfirmationsNtfn", + &txid, pkScript, confs, height).Return( + confEvent, nil).Once() + + // Mock the wallet to publish successfully. + m.wallet.On("PublishTransaction", + tx, mock.Anything).Return(nil).Once() + }, + expectedErr: nil, + expectedResult: &BumpResult{ + Event: TxPublished, + Tx: tx, + Fee: fee, + FeeRate: feerate, + Err: nil, + requestID: requestID, + }, + }, + } + + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + tc.setupMock() + + // Call the method under test. + result, err := tp.broadcast(requestID) + + // Check the result is as expected. + require.ErrorIs(t, err, tc.expectedErr) + require.Equal(t, tc.expectedResult, result) + }) + } +} + +// TestRemoveResult checks the records and subscriptions are removed when a tx +// is confirmed or failed. +func TestRemoveResult(t *testing.T) { + t.Parallel() + + // Create a publisher using the mocks. + tp, m := createTestPublisher(t) + + // Create a test bump request. + req := createTestBumpRequest() + + // Create a test tx. + tx := &wire.MsgTx{LockTime: 1} + + // Create a testing record and put it in the map. + fee := btcutil.Amount(1000) + + testCases := []struct { + name string + setupRecord func() uint64 + result *BumpResult + removed bool + }{ + { + // When the tx is confirmed, the records will be + // removed. + name: "remove on TxConfirmed", + setupRecord: func() uint64 { + id := tp.storeRecord(tx, req, m.feeFunc, fee) + tp.subscriberChans.Store(id, nil) + + return id + }, + result: &BumpResult{ + Event: TxConfirmed, + Tx: tx, + }, + removed: true, + }, + { + // When the tx is failed, the records will be removed. + name: "remove on TxFailed", + setupRecord: func() uint64 { + id := tp.storeRecord(tx, req, m.feeFunc, fee) + tp.subscriberChans.Store(id, nil) + + return id + }, + result: &BumpResult{ + Event: TxFailed, + Err: errDummy, + Tx: tx, + }, + removed: true, + }, + { + // Noop when the tx is neither confirmed or failed. + name: "noop when tx is not confirmed or failed", + setupRecord: func() uint64 { + id := tp.storeRecord(tx, req, m.feeFunc, fee) + tp.subscriberChans.Store(id, nil) + + return id + }, + result: &BumpResult{ + Event: TxPublished, + Tx: tx, + }, + removed: false, + }, + } + + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + requestID := tc.setupRecord() + + // Attach the requestID from the setup. + tc.result.requestID = requestID + + // Remove the result. + tp.removeResult(tc.result) + + // Check if the record is removed. + _, found := tp.records.Load(requestID) + require.Equal(t, !tc.removed, found) + + _, found = tp.subscriberChans.Load(requestID) + require.Equal(t, !tc.removed, found) + }) + } +} + +// TestNotifyResult checks the subscribers are notified when a result is sent. +func TestNotifyResult(t *testing.T) { + t.Parallel() + + // Create a publisher using the mocks. + tp, m := createTestPublisher(t) + + // Create a test bump request. + req := createTestBumpRequest() + + // Create a test tx. + tx := &wire.MsgTx{LockTime: 1} + + // Create a testing record and put it in the map. + fee := btcutil.Amount(1000) + requestID := tp.storeRecord(tx, req, m.feeFunc, fee) + + // Create a subscription to the event. + subscriber := make(chan *BumpResult, 1) + tp.subscriberChans.Store(requestID, subscriber) + + // Create a test result. + result := &BumpResult{ + requestID: requestID, + Tx: tx, + } + + // Notify the result and expect the subscriber to receive it. + // + // NOTE: must be done inside a goroutine in case it blocks. + go tp.notifyResult(result) + + select { + case <-time.After(time.Second): + t.Fatal("timeout waiting for subscriber to receive result") + + case received := <-subscriber: + require.Equal(t, result, received) + } + + // Notify two results. This time it should block because the channel is + // full. We then shutdown TxPublisher to test the quit behavior. + done := make(chan struct{}) + go func() { + // Call notifyResult twice, which blocks at the second call. + tp.notifyResult(result) + tp.notifyResult(result) + + close(done) + }() + + // Shutdown the publisher and expect notifyResult to exit. + close(tp.quit) + + // We expect to done chan. + select { + case <-time.After(time.Second): + t.Fatal("timeout waiting for notifyResult to exit") + + case <-done: + } +} + +// TestBroadcastSuccess checks the public `Broadcast` method can successfully +// broadcast a tx based on the request. +func TestBroadcastSuccess(t *testing.T) { + t.Parallel() + + // Create a publisher using the mocks. + tp, m := createTestPublisher(t) + + // Create a test feerate. + feerate := chainfee.SatPerKWeight(1000) + + // Mock the fee estimator to return the testing fee rate. + // + // We are not testing `NewLinearFeeFunction` here, so the actual params + // used are irrelevant. + m.estimator.On("EstimateFeePerKW", mock.Anything).Return( + feerate, nil).Once() + m.estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Once() + + // Mock the signer to always return a valid script. + // + // NOTE: we are not testing the utility of creating valid txes here, so + // this is fine to be mocked. This behaves essentially as skipping the + // Signer check and alaways assume the tx has a valid sig. + script := &input.Script{} + m.signer.On("ComputeInputScript", mock.Anything, + mock.Anything).Return(script, nil) + + // Mock the testmempoolaccept to pass. + m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once() + + // Create a test conf event. + confEvent := &chainntnfs.ConfirmationEvent{} + + // Mock the RegisterConfirmationsNtfn to pass. + m.notifier.On("RegisterConfirmationsNtfn", + mock.Anything, mock.Anything, mock.Anything, mock.Anything, + ).Return(confEvent, nil).Once() + + // Mock the wallet to publish successfully. + m.wallet.On("PublishTransaction", + mock.Anything, mock.Anything).Return(nil).Once() + + // Create a test request. + inp := createTestInput(1000, input.WitnessKeyHash) + + // Create a testing bump request. + req := &BumpRequest{ + DeliveryAddress: changePkScript, + Inputs: []input.Input{&inp}, + Budget: btcutil.Amount(1000), + MaxFeeRate: feerate, + } + + // Send the req and expect no error. + resultChan, err := tp.Broadcast(req) + require.NoError(t, err) + + // Check the result is sent back. + select { + case <-time.After(time.Second): + t.Fatal("timeout waiting for subscriber to receive result") + + case result := <-resultChan: + // We expect the first result to be TxPublished. + require.Equal(t, TxPublished, result.Event) + } + + // Validate the record was stored. + require.Equal(t, 1, tp.records.Len()) + require.Equal(t, 1, tp.subscriberChans.Len()) +} + +// TestBroadcastFail checks the public `Broadcast` returns the error or a +// failed result when the broadcast fails. +func TestBroadcastFail(t *testing.T) { + t.Parallel() + + // Create a publisher using the mocks. + tp, m := createTestPublisher(t) + + // Create a test feerate. + feerate := chainfee.SatPerKWeight(1000) + + // Create a test request. + inp := createTestInput(1000, input.WitnessKeyHash) + + // Create a testing bump request. + req := &BumpRequest{ + DeliveryAddress: changePkScript, + Inputs: []input.Input{&inp}, + Budget: btcutil.Amount(1000), + MaxFeeRate: feerate, + } + + // Mock the fee estimator to return the testing fee rate. + // + // We are not testing `NewLinearFeeFunction` here, so the actual params + // used are irrelevant. + m.estimator.On("EstimateFeePerKW", mock.Anything).Return( + feerate, nil).Twice() + m.estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Twice() + + // Mock the signer to always return a valid script. + // + // NOTE: we are not testing the utility of creating valid txes here, so + // this is fine to be mocked. This behaves essentially as skipping the + // Signer check and alaways assume the tx has a valid sig. + script := &input.Script{} + m.signer.On("ComputeInputScript", mock.Anything, + mock.Anything).Return(script, nil) + + // Mock the testmempoolaccept to return an error. + m.wallet.On("CheckMempoolAcceptance", + mock.Anything).Return(errDummy).Once() + + // Send the req and expect an error returned. + resultChan, err := tp.Broadcast(req) + require.ErrorIs(t, err, errDummy) + require.Nil(t, resultChan) + + // Validate the record was NOT stored. + require.Equal(t, 0, tp.records.Len()) + require.Equal(t, 0, tp.subscriberChans.Len()) + + // Mock the testmempoolaccept again, this time it passes. + m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once() + + // Create a test conf event. + confEvent := &chainntnfs.ConfirmationEvent{} + + // Mock the RegisterConfirmationsNtfn to pass. + m.notifier.On("RegisterConfirmationsNtfn", + mock.Anything, mock.Anything, mock.Anything, mock.Anything, + ).Return(confEvent, nil).Once() + + // Mock the wallet to fail on publish. + m.wallet.On("PublishTransaction", + mock.Anything, mock.Anything).Return(errDummy).Once() + + // Send the req and expect no error returned. + resultChan, err = tp.Broadcast(req) + require.NoError(t, err) + + // Check the result is sent back. + select { + case <-time.After(time.Second): + t.Fatal("timeout waiting for subscriber to receive result") + + case result := <-resultChan: + // We expect the result to be TxFailed and the error is set in + // the result. + require.Equal(t, TxFailed, result.Event) + require.ErrorIs(t, result.Err, errDummy) + } + + // Validate the record was removed. + require.Equal(t, 0, tp.records.Len()) + require.Equal(t, 0, tp.subscriberChans.Len()) +} diff --git a/sweep/mock_test.go b/sweep/mock_test.go index 3688db72c3..86edbacefc 100644 --- a/sweep/mock_test.go +++ b/sweep/mock_test.go @@ -493,3 +493,32 @@ func (m *MockBumper) Broadcast(req *BumpRequest) (<-chan *BumpResult, error) { return args.Get(0).(chan *BumpResult), args.Error(1) } + +// MockFeeFunction is a mock implementation of the FeeFunction interface. +type MockFeeFunction struct { + mock.Mock +} + +// Compile-time constraint to ensure MockFeeFunction implements FeeFunction. +var _ FeeFunction = (*MockFeeFunction)(nil) + +// FeeRate returns the current fee rate calculated by the fee function. +func (m *MockFeeFunction) FeeRate() chainfee.SatPerKWeight { + args := m.Called() + + return args.Get(0).(chainfee.SatPerKWeight) +} + +// Increment adds one delta to the current fee rate. +func (m *MockFeeFunction) Increment() (bool, error) { + args := m.Called() + + return args.Bool(0), args.Error(1) +} + +// IncreaseFeeRate increases the fee rate by one step. +func (m *MockFeeFunction) IncreaseFeeRate(confTarget uint32) (bool, error) { + args := m.Called(confTarget) + + return args.Bool(0), args.Error(1) +} diff --git a/sweep/txgenerator.go b/sweep/txgenerator.go index 0cab9a6e22..57e0cce17b 100644 --- a/sweep/txgenerator.go +++ b/sweep/txgenerator.go @@ -205,15 +205,14 @@ func createSweepTx(inputs []input.Input, outputs []*wire.TxOut, } } - log.Infof("Creating sweep transaction %v for %v inputs (%s) "+ - "using %v sat/kw, tx_weight=%v, tx_fee=%v, parents_count=%v, "+ - "parents_fee=%v, parents_weight=%v", + log.Debugf("Creating sweep transaction %v for %v inputs (%s) "+ + "using %v, tx_weight=%v, tx_fee=%v, parents_count=%v, "+ + "parents_fee=%v, parents_weight=%v, current_height=%v", sweepTx.TxHash(), len(inputs), inputTypeSummary(inputs), feeRate, estimator.weight(), txFee, len(estimator.parents), estimator.parentsFee, - estimator.parentsWeight, - ) + estimator.parentsWeight, currentBlockHeight) return sweepTx, txFee, nil } From d5b5b7622dbb67ca5756a36bd34961f5200af9f9 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Thu, 29 Feb 2024 19:36:37 +0800 Subject: [PATCH 14/19] sweep: add monitor loop to `TxPublisher` This commit finishes the implementation of `TxPublisher` by adding the monitor process. Whenever a new block arrives, the publisher will check all its monitored records and attempt fee bumping them if necessary. --- contractcourt/utxonursery.go | 4 + sweep/fee_bumper.go | 294 ++++++++++++++++++++-- sweep/fee_bumper_test.go | 464 +++++++++++++++++++++++++++++++---- sweep/interface.go | 5 + sweep/mock_test.go | 22 ++ 5 files changed, 725 insertions(+), 64 deletions(-) diff --git a/contractcourt/utxonursery.go b/contractcourt/utxonursery.go index 6b8742255b..c2f0264d35 100644 --- a/contractcourt/utxonursery.go +++ b/contractcourt/utxonursery.go @@ -1406,6 +1406,10 @@ func (k *kidOutput) ConfHeight() uint32 { return k.confHeight } +func (k *kidOutput) RequiredLockTime() (uint32, bool) { + return k.absoluteMaturity, k.absoluteMaturity > 0 +} + // Encode converts a KidOutput struct into a form suitable for on-disk database // storage. Note that the signDescriptor struct field is included so that the // output's witness can be generated by createSweepTx() when the output becomes diff --git a/sweep/fee_bumper.go b/sweep/fee_bumper.go index e0db7eb18f..46e9323592 100644 --- a/sweep/fee_bumper.go +++ b/sweep/fee_bumper.go @@ -7,6 +7,7 @@ import ( "sync/atomic" "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/rpcclient" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcwallet/chain" @@ -512,17 +513,6 @@ func (t *TxPublisher) broadcast(requestID uint64) (*BumpResult, error) { txid := record.tx.TxHash() - // Subscribe to its confirmation notification. - confEvent, err := t.cfg.Notifier.RegisterConfirmationsNtfn( - &txid, nil, 1, uint32(t.currentHeight), - ) - if err != nil { - return nil, fmt.Errorf("register confirmation ntfn: %w", err) - } - - // Attach the confirmation event channel to the record. - record.confEvent = confEvent - tx := record.tx log.Debugf("Publishing sweep tx %v, num_inputs=%v, height=%v", txid, len(tx.TxIn), t.currentHeight) @@ -534,7 +524,7 @@ func (t *TxPublisher) broadcast(requestID uint64) (*BumpResult, error) { // Publish the sweeping tx with customized label. If the publish fails, // this error will be saved in the `BumpResult` and it will be removed // from being monitored. - err = t.cfg.Wallet.PublishTransaction( + err := t.cfg.Wallet.PublishTransaction( tx, labels.MakeLabel(labels.LabelTypeSweepTransaction, nil), ) if err != nil { @@ -638,9 +628,6 @@ type monitorRecord struct { // req is the original request. req *BumpRequest - // confEvent is the subscription to the confirmation event of the tx. - confEvent *chainntnfs.ConfirmationEvent - // feeFunction is the fee bumping algorithm used by the publisher. feeFunction FeeFunction @@ -648,6 +635,283 @@ type monitorRecord struct { fee btcutil.Amount } +// Start starts the publisher by subscribing to block epoch updates and kicking +// off the monitor loop. +func (t *TxPublisher) Start() error { + log.Info("TxPublisher starting...") + defer log.Debugf("TxPublisher started") + + blockEvent, err := t.cfg.Notifier.RegisterBlockEpochNtfn(nil) + if err != nil { + return fmt.Errorf("register block epoch ntfn: %w", err) + } + + t.wg.Add(1) + go t.monitor(blockEvent) + + return nil +} + +// Stop stops the publisher and waits for the monitor loop to exit. +func (t *TxPublisher) Stop() { + log.Info("TxPublisher stopping...") + defer log.Debugf("TxPublisher stopped") + + close(t.quit) + + t.wg.Wait() +} + +// monitor is the main loop driven by new blocks. Whevenr a new block arrives, +// it will examine all the txns being monitored, and check if any of them needs +// to be bumped. If so, it will attempt to bump the fee of the tx. +// +// NOTE: Must be run as a goroutine. +func (t *TxPublisher) monitor(blockEvent *chainntnfs.BlockEpochEvent) { + defer blockEvent.Cancel() + defer t.wg.Done() + + for { + select { + case epoch, ok := <-blockEvent.Epochs: + if !ok { + // We should stop the publisher before stopping + // the chain service. Otherwise it indicates an + // error. + log.Error("Block epoch channel closed, exit " + + "monitor") + + return + } + + log.Debugf("TxPublisher received new block: %v", + epoch.Height) + + // Update the best known height for the publisher. + t.currentHeight = epoch.Height + + // Check all monitored txns to see if any of them needs + // to be bumped. + t.processRecords() + + case <-t.quit: + log.Debug("Fee bumper stopped, exit monitor") + return + } + } +} + +// processRecords checks all the txns being monitored, and checks if any of +// them needs to be bumped. If so, it will attempt to bump the fee of the tx. +func (t *TxPublisher) processRecords() { + // confirmedRecords stores a map of the records which have been + // confirmed. + confirmedRecords := make(map[uint64]*monitorRecord) + + // feeBumpRecords stores a map of the records which need to be bumped. + feeBumpRecords := make(map[uint64]*monitorRecord) + + // visitor is a helper closure that visits each record and divides them + // into two groups. + visitor := func(requestID uint64, r *monitorRecord) error { + log.Tracef("Checking monitor recordID=%v for tx=%v", requestID, + r.tx.TxHash()) + + // If the tx is already confirmed, we can stop monitoring it. + if t.isConfirmed(r.tx.TxHash()) { + confirmedRecords[requestID] = r + + // Move to the next record. + return nil + } + + feeBumpRecords[requestID] = r + + // Return nil to move to the next record. + return nil + } + + // Iterate through all the records and divide them into two groups. + t.records.ForEach(visitor) + + // For records that are confirmed, we'll notify the caller about this + // result. + for requestID, r := range confirmedRecords { + rec := r + + log.Debugf("Tx=%v is confirmed", r.tx.TxHash()) + t.wg.Add(1) + go t.handleTxConfirmed(rec, requestID) + } + + // Get the current height to be used in the following goroutines. + currentHeight := t.currentHeight + + // For records that are not confirmed, we perform a fee bump if needed. + for requestID, r := range feeBumpRecords { + rec := r + + log.Debugf("Attempting to fee bump Tx=%v", r.tx.TxHash()) + t.wg.Add(1) + go t.handleFeeBumpTx(requestID, rec, currentHeight) + } +} + +// handleTxConfirmed is called when a monitored tx is confirmed. It will +// notify the subscriber then remove the record from the maps . +// +// NOTE: Must be run as a goroutine to avoid blocking on sending the result. +func (t *TxPublisher) handleTxConfirmed(r *monitorRecord, requestID uint64) { + defer t.wg.Done() + + // Create a result that will be sent to the resultChan which is + // listened by the caller. + result := &BumpResult{ + Event: TxConfirmed, + Tx: r.tx, + requestID: requestID, + Fee: r.fee, + FeeRate: r.feeFunction.FeeRate(), + } + + // Notify that this tx is confirmed and remove the record from the map. + t.handleResult(result) +} + +// handleFeeBumpTx checks if the tx needs to be bumped, and if so, it will +// attempt to bump the fee of the tx. +// +// NOTE: Must be run as a goroutine to avoid blocking on sending the result. +func (t *TxPublisher) handleFeeBumpTx(requestID uint64, r *monitorRecord, + currentHeight int32) { + + defer t.wg.Done() + + oldTxid := r.tx.TxHash() + + // Get the current conf target for this record. + confTarget := calcCurrentConfTarget(currentHeight, r.req.DeadlineHeight) + + // Ask the fee function whether a bump is needed. We expect the fee + // function to increase its returned fee rate after calling this + // method. + increased, err := r.feeFunction.IncreaseFeeRate(confTarget) + if err != nil { + // TODO(yy): send this error back to the sweeper so it can + // re-group the inputs? + log.Errorf("Failed to increase fee rate for tx %v at "+ + "height=%v: %v", oldTxid, t.currentHeight, err) + + return + } + + // If the fee rate was not increased, there's no need to bump the fee. + if !increased { + log.Tracef("Skip bumping tx %v at height=%v", oldTxid, + t.currentHeight) + + return + } + + // The fee function now has a new fee rate, we will use it to bump the + // fee of the tx. + result, err := t.createAndPublishTx(requestID, r) + if err != nil { + log.Errorf("Failed to bump tx %v: %v", oldTxid, err) + + return + } + + // Notify the new result. + t.handleResult(result) +} + +// createAndPublishTx creates a new tx with a higher fee rate and publishes it +// to the network. It will update the record with the new tx and fee rate if +// successfully created, and return the result when published successfully. +func (t *TxPublisher) createAndPublishTx(requestID uint64, + r *monitorRecord) (*BumpResult, error) { + + // Fetch the old tx. + oldTx := r.tx + + // Create a new tx with the new fee rate. + // + // NOTE: The fee function is expected to have increased its returned + // fee rate after calling the SkipFeeBump method. So we can use it + // directly here. + tx, fee, err := t.createAndCheckTx(r.req, r.feeFunction) + + // If the tx doesn't not have enought budget, we will return a result + // so the sweeper can handle it by re-clustering the utxos. + if errors.Is(err, ErrNotEnoughBudget) { + log.Warnf("Fail to fee bump tx %v: %v", oldTx.TxHash(), err) + + return &BumpResult{ + Event: TxFailed, + Tx: oldTx, + Err: err, + requestID: requestID, + }, nil + } + + // If the error is not budget related, we will return an error and let + // the fee bumper retry it at next block. + // + // NOTE: we can check the RBF error here and ask the fee function to + // recalculate the fee rate. However, this would defeat the purpose of + // using a deadline based fee function: + // - if the deadline is far away, there's no rush to RBF the tx. + // - if the deadline is close, we expect the fee function to give us a + // higher fee rate. If the fee rate cannot satisfy the RBF rules, it + // means the budget is not enough. + if err != nil { + log.Infof("Failed to bump tx %v: %v", oldTx.TxHash(), err) + return nil, err + } + + // Register a new record by overwriting the same requestID. + t.records.Store(requestID, &monitorRecord{ + tx: tx, + req: r.req, + feeFunction: r.feeFunction, + fee: fee, + }) + + // Attempt to broadcast this new tx. + result, err := t.broadcast(requestID) + if err != nil { + return nil, err + } + + // A successful replacement tx is created, attach the old tx. + result.ReplacedTx = oldTx + + // If the new tx failed to be published, we will return the result so + // the caller can handle it. + if result.Event == TxFailed { + return result, nil + } + + log.Infof("Replaced tx=%v with new tx=%v", oldTx.TxHash(), tx.TxHash()) + + // Otherwise, it's a successful RBF, set the event and return. + result.Event = TxReplaced + + return result, nil +} + +// isConfirmed checks the btcwallet to see whether the tx is confirmed. +func (t *TxPublisher) isConfirmed(txid chainhash.Hash) bool { + details, err := t.cfg.Wallet.GetTransactionDetails(&txid) + if err != nil { + log.Warnf("Failed to get tx details for %v: %v", txid, err) + return false + } + + return details.NumConfirmations > 0 +} + // calcCurrentConfTarget calculates the current confirmation target based on // the deadline height. The conf target is capped at 0 if the deadline has // already been past. diff --git a/sweep/fee_bumper_test.go b/sweep/fee_bumper_test.go index 308a69a575..f3b67f3bd9 100644 --- a/sweep/fee_bumper_test.go +++ b/sweep/fee_bumper_test.go @@ -612,15 +612,11 @@ func TestTxPublisherBroadcast(t *testing.T) { // Create a test tx. tx := &wire.MsgTx{LockTime: 1} - txid := tx.TxHash() // Create a test feerate and return it from the mock fee function. feerate := chainfee.SatPerKWeight(1000) m.feeFunc.On("FeeRate").Return(feerate) - // Create a test conf event. - confEvent := &chainntnfs.ConfirmationEvent{} - // Create a testing record and put it in the map. fee := btcutil.Amount(1000) requestID := tp.storeRecord(tx, req, m.feeFunc, fee) @@ -631,41 +627,17 @@ func TestTxPublisherBroadcast(t *testing.T) { require.Error(t, err) require.Nil(t, result) - // Define params to be used in RegisterConfirmationsNtfn. Not important - // for this test. - var pkScript []byte - confs := uint32(1) - height := uint32(tp.currentHeight) - testCases := []struct { name string setupMock func() expectedErr error expectedResult *BumpResult }{ - { - // When the notifier cannot register this spend, an - // error should be returned - name: "fail to register nftn", - setupMock: func() { - // Mock the RegisterConfirmationsNtfn to fail. - m.notifier.On("RegisterConfirmationsNtfn", - &txid, pkScript, confs, height).Return( - nil, errDummy).Once() - }, - expectedErr: errDummy, - expectedResult: nil, - }, { // When the wallet cannot publish this tx, the error // should be put inside the result. name: "fail to publish", setupMock: func() { - // Mock the RegisterConfirmationsNtfn to pass. - m.notifier.On("RegisterConfirmationsNtfn", - &txid, pkScript, confs, height).Return( - confEvent, nil).Once() - // Mock the wallet to fail to publish. m.wallet.On("PublishTransaction", tx, mock.Anything).Return( @@ -685,11 +657,6 @@ func TestTxPublisherBroadcast(t *testing.T) { // When nothing goes wrong, the result is returned. name: "publish success", setupMock: func() { - // Mock the RegisterConfirmationsNtfn to pass. - m.notifier.On("RegisterConfirmationsNtfn", - &txid, pkScript, confs, height).Return( - confEvent, nil).Once() - // Mock the wallet to publish successfully. m.wallet.On("PublishTransaction", tx, mock.Anything).Return(nil).Once() @@ -910,14 +877,6 @@ func TestBroadcastSuccess(t *testing.T) { // Mock the testmempoolaccept to pass. m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once() - // Create a test conf event. - confEvent := &chainntnfs.ConfirmationEvent{} - - // Mock the RegisterConfirmationsNtfn to pass. - m.notifier.On("RegisterConfirmationsNtfn", - mock.Anything, mock.Anything, mock.Anything, mock.Anything, - ).Return(confEvent, nil).Once() - // Mock the wallet to publish successfully. m.wallet.On("PublishTransaction", mock.Anything, mock.Anything).Return(nil).Once() @@ -1007,14 +966,6 @@ func TestBroadcastFail(t *testing.T) { // Mock the testmempoolaccept again, this time it passes. m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once() - // Create a test conf event. - confEvent := &chainntnfs.ConfirmationEvent{} - - // Mock the RegisterConfirmationsNtfn to pass. - m.notifier.On("RegisterConfirmationsNtfn", - mock.Anything, mock.Anything, mock.Anything, mock.Anything, - ).Return(confEvent, nil).Once() - // Mock the wallet to fail on publish. m.wallet.On("PublishTransaction", mock.Anything, mock.Anything).Return(errDummy).Once() @@ -1039,3 +990,418 @@ func TestBroadcastFail(t *testing.T) { require.Equal(t, 0, tp.records.Len()) require.Equal(t, 0, tp.subscriberChans.Len()) } + +// TestCreateAnPublishFail checks all the error cases are handled properly in +// the method createAndPublish. +func TestCreateAnPublishFail(t *testing.T) { + t.Parallel() + + // Create a publisher using the mocks. + tp, m := createTestPublisher(t) + + // Create a test requestID. + requestID := uint64(1) + + // Create a test feerate and return it from the mock fee function. + feerate := chainfee.SatPerKWeight(1000) + m.feeFunc.On("FeeRate").Return(feerate) + + // Create a testing monitor record. + req := createTestBumpRequest() + + // Overwrite the budget to make it smaller than the fee. + req.Budget = 100 + record := &monitorRecord{ + req: req, + feeFunction: m.feeFunc, + tx: &wire.MsgTx{}, + } + + // Mock the signer to always return a valid script. + // + // NOTE: we are not testing the utility of creating valid txes here, so + // this is fine to be mocked. This behaves essentially as skipping the + // Signer check and alaways assume the tx has a valid sig. + script := &input.Script{} + m.signer.On("ComputeInputScript", mock.Anything, + mock.Anything).Return(script, nil) + + // Call the createAndPublish method. + result, err := tp.createAndPublishTx(requestID, record) + require.NoError(t, err) + + // We expect the result to be TxFailed and the error is set in the + // result. + require.Equal(t, TxFailed, result.Event) + require.ErrorIs(t, result.Err, ErrNotEnoughBudget) + require.Equal(t, requestID, result.requestID) + + // Increase the budget and call it again. This time we will mock an + // error to be returned from CheckMempoolAcceptance. + req.Budget = 1000 + + // Mock the testmempoolaccept to return an error. + m.wallet.On("CheckMempoolAcceptance", + mock.Anything).Return(lnwallet.ErrMempoolFee).Once() + + // Call the createAndPublish method and expect an error. + result, err = tp.createAndPublishTx(requestID, record) + require.ErrorIs(t, err, lnwallet.ErrMempoolFee) + require.Nil(t, result) +} + +// TestCreateAnPublishSuccess checks the expected result is returned from the +// method createAndPublish. +func TestCreateAnPublishSuccess(t *testing.T) { + t.Parallel() + + // Create a publisher using the mocks. + tp, m := createTestPublisher(t) + + // Create a test requestID. + requestID := uint64(1) + + // Create a test feerate and return it from the mock fee function. + feerate := chainfee.SatPerKWeight(1000) + m.feeFunc.On("FeeRate").Return(feerate) + + // Create a testing monitor record. + req := createTestBumpRequest() + record := &monitorRecord{ + req: req, + feeFunction: m.feeFunc, + tx: &wire.MsgTx{}, + } + + // Mock the signer to always return a valid script. + // + // NOTE: we are not testing the utility of creating valid txes here, so + // this is fine to be mocked. This behaves essentially as skipping the + // Signer check and alaways assume the tx has a valid sig. + script := &input.Script{} + m.signer.On("ComputeInputScript", mock.Anything, + mock.Anything).Return(script, nil) + + // Mock the testmempoolaccept to return nil. + m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil) + + // Mock the wallet to publish and return an error. + m.wallet.On("PublishTransaction", + mock.Anything, mock.Anything).Return(errDummy).Once() + + // Call the createAndPublish method and expect a failure result. + result, err := tp.createAndPublishTx(requestID, record) + require.NoError(t, err) + + // We expect the result to be TxFailed and the error is set. + require.Equal(t, TxFailed, result.Event) + require.ErrorIs(t, result.Err, errDummy) + + // Although the replacement tx was failed to be published, the record + // should be stored. + require.NotNil(t, result.Tx) + require.NotNil(t, result.ReplacedTx) + _, found := tp.records.Load(requestID) + require.True(t, found) + + // We now check a successful RBF. + // + // Mock the wallet to publish successfully. + m.wallet.On("PublishTransaction", + mock.Anything, mock.Anything).Return(nil).Once() + + // Call the createAndPublish method and expect a success result. + result, err = tp.createAndPublishTx(requestID, record) + require.NoError(t, err) + + // We expect the result to be TxReplaced and the error is nil. + require.Equal(t, TxReplaced, result.Event) + require.Nil(t, result.Err) + + // Check the Tx and ReplacedTx are set. + require.NotNil(t, result.Tx) + require.NotNil(t, result.ReplacedTx) + + // Check the record is stored. + _, found = tp.records.Load(requestID) + require.True(t, found) +} + +// TestHandleTxConfirmed checks the expected result is returned from the method +// handleTxConfirmed. +func TestHandleTxConfirmed(t *testing.T) { + t.Parallel() + + // Create a publisher using the mocks. + tp, m := createTestPublisher(t) + + // Create a test bump request. + req := createTestBumpRequest() + + // Create a test tx. + tx := &wire.MsgTx{LockTime: 1} + + // Create a testing record and put it in the map. + fee := btcutil.Amount(1000) + requestID := tp.storeRecord(tx, req, m.feeFunc, fee) + record, ok := tp.records.Load(requestID) + require.True(t, ok) + + // Create a subscription to the event. + subscriber := make(chan *BumpResult, 1) + tp.subscriberChans.Store(requestID, subscriber) + + // Mock the fee function to return a fee rate. + feerate := chainfee.SatPerKWeight(1000) + m.feeFunc.On("FeeRate").Return(feerate).Once() + + // Call the method and expect a result to be received. + // + // NOTE: must be called in a goroutine in case it blocks. + tp.wg.Add(1) + go tp.handleTxConfirmed(record, requestID) + + select { + case <-time.After(time.Second): + t.Fatal("timeout waiting for subscriber to receive result") + + case result := <-subscriber: + // We expect the result to be TxConfirmed and the tx is set. + require.Equal(t, TxConfirmed, result.Event) + require.Equal(t, tx, result.Tx) + require.Nil(t, result.Err) + require.Equal(t, requestID, result.requestID) + require.Equal(t, record.fee, result.Fee) + require.Equal(t, feerate, result.FeeRate) + } + + // We expect the record to be removed from the maps. + _, found := tp.records.Load(requestID) + require.False(t, found) + _, found = tp.subscriberChans.Load(requestID) + require.False(t, found) +} + +// TestHandleFeeBumpTx validates handleFeeBumpTx behaves as expected. +func TestHandleFeeBumpTx(t *testing.T) { + t.Parallel() + + // Create a publisher using the mocks. + tp, m := createTestPublisher(t) + + // Create a test tx. + tx := &wire.MsgTx{LockTime: 1} + + // Create a test current height. + testHeight := int32(800000) + + // Create a testing monitor record. + req := createTestBumpRequest() + record := &monitorRecord{ + req: req, + feeFunction: m.feeFunc, + tx: tx, + } + + // Create a testing record and put it in the map. + fee := btcutil.Amount(1000) + requestID := tp.storeRecord(tx, req, m.feeFunc, fee) + + // Create a subscription to the event. + subscriber := make(chan *BumpResult, 1) + tp.subscriberChans.Store(requestID, subscriber) + + // Create a test feerate and return it from the mock fee function. + feerate := chainfee.SatPerKWeight(1000) + m.feeFunc.On("FeeRate").Return(feerate) + + // Mock the fee function to skip the bump due to error. + m.feeFunc.On("IncreaseFeeRate", mock.Anything).Return( + false, errDummy).Once() + + // Call the method and expect no result received. + tp.wg.Add(1) + go tp.handleFeeBumpTx(requestID, record, testHeight) + + // Check there's no result sent back. + select { + case <-time.After(time.Second): + case result := <-subscriber: + t.Fatalf("unexpected result received: %v", result) + } + + // Mock the fee function to skip the bump. + m.feeFunc.On("IncreaseFeeRate", mock.Anything).Return(false, nil).Once() + + // Call the method and expect no result received. + tp.wg.Add(1) + go tp.handleFeeBumpTx(requestID, record, testHeight) + + // Check there's no result sent back. + select { + case <-time.After(time.Second): + case result := <-subscriber: + t.Fatalf("unexpected result received: %v", result) + } + + // Mock the fee function to perform the fee bump. + m.feeFunc.On("IncreaseFeeRate", mock.Anything).Return(true, nil) + + // Mock the signer to always return a valid script. + // + // NOTE: we are not testing the utility of creating valid txes here, so + // this is fine to be mocked. This behaves essentially as skipping the + // Signer check and alaways assume the tx has a valid sig. + script := &input.Script{} + m.signer.On("ComputeInputScript", mock.Anything, + mock.Anything).Return(script, nil) + + // Mock the testmempoolaccept to return nil. + m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil) + + // Mock the wallet to publish successfully. + m.wallet.On("PublishTransaction", + mock.Anything, mock.Anything).Return(nil).Once() + + // Call the method and expect a result to be received. + // + // NOTE: must be called in a goroutine in case it blocks. + tp.wg.Add(1) + go tp.handleFeeBumpTx(requestID, record, testHeight) + + select { + case <-time.After(time.Second): + t.Fatal("timeout waiting for subscriber to receive result") + + case result := <-subscriber: + // We expect the result to be TxReplaced. + require.Equal(t, TxReplaced, result.Event) + + // The new tx and old tx should be properly set. + require.NotEqual(t, tx, result.Tx) + require.Equal(t, tx, result.ReplacedTx) + + // No error should be set. + require.Nil(t, result.Err) + require.Equal(t, requestID, result.requestID) + } + + // We expect the record to NOT be removed from the maps. + _, found := tp.records.Load(requestID) + require.True(t, found) + _, found = tp.subscriberChans.Load(requestID) + require.True(t, found) +} + +// TestProcessRecords validates processRecords behaves as expected. +func TestProcessRecords(t *testing.T) { + t.Parallel() + + // Create a publisher using the mocks. + tp, m := createTestPublisher(t) + + // Create testing objects. + requestID1 := uint64(1) + req1 := createTestBumpRequest() + tx1 := &wire.MsgTx{LockTime: 1} + txid1 := tx1.TxHash() + + requestID2 := uint64(2) + req2 := createTestBumpRequest() + tx2 := &wire.MsgTx{LockTime: 2} + txid2 := tx2.TxHash() + + // Create a monitor record that's confirmed. + recordConfirmed := &monitorRecord{ + req: req1, + feeFunction: m.feeFunc, + tx: tx1, + } + m.wallet.On("GetTransactionDetails", &txid1).Return( + &lnwallet.TransactionDetail{ + NumConfirmations: 1, + }, nil, + ).Once() + + // Create a monitor record that's not confirmed. We know it's not + // confirmed because the num of confirms is zero. + recordFeeBump := &monitorRecord{ + req: req2, + feeFunction: m.feeFunc, + tx: tx2, + } + m.wallet.On("GetTransactionDetails", &txid2).Return( + &lnwallet.TransactionDetail{ + NumConfirmations: 0, + }, nil, + ).Once() + + // Setup the initial publisher state by adding the records to the maps. + subscriberConfirmed := make(chan *BumpResult, 1) + tp.subscriberChans.Store(requestID1, subscriberConfirmed) + tp.records.Store(requestID1, recordConfirmed) + + subscriberReplaced := make(chan *BumpResult, 1) + tp.subscriberChans.Store(requestID2, subscriberReplaced) + tp.records.Store(requestID2, recordFeeBump) + + // Create a test feerate and return it from the mock fee function. + feerate := chainfee.SatPerKWeight(1000) + m.feeFunc.On("FeeRate").Return(feerate) + + // The following methods should only be called once when creating the + // replacement tx. + // + // Mock the fee function to NOT skip the fee bump. + m.feeFunc.On("IncreaseFeeRate", mock.Anything).Return(true, nil).Once() + + // Mock the signer to always return a valid script. + m.signer.On("ComputeInputScript", mock.Anything, + mock.Anything).Return(&input.Script{}, nil).Once() + + // Mock the testmempoolaccept to return nil. + m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once() + + // Mock the wallet to publish successfully. + m.wallet.On("PublishTransaction", + mock.Anything, mock.Anything).Return(nil).Once() + + // Call processRecords and expect the results are notified back. + tp.processRecords() + + // We expect two results to be received. One for the confirmed tx and + // one for the replaced tx. + // + // Check the confirmed tx result. + select { + case <-time.After(time.Second): + t.Fatal("timeout waiting for subscriberConfirmed") + + case result := <-subscriberConfirmed: + // We expect the result to be TxConfirmed. + require.Equal(t, TxConfirmed, result.Event) + require.Equal(t, tx1, result.Tx) + + // No error should be set. + require.Nil(t, result.Err) + require.Equal(t, requestID1, result.requestID) + } + + // Now check the replaced tx result. + select { + case <-time.After(time.Second): + t.Fatal("timeout waiting for subscriberReplaced") + + case result := <-subscriberReplaced: + // We expect the result to be TxReplaced. + require.Equal(t, TxReplaced, result.Event) + + // The new tx and old tx should be properly set. + require.NotEqual(t, tx2, result.Tx) + require.Equal(t, tx2, result.ReplacedTx) + + // No error should be set. + require.Nil(t, result.Err) + require.Equal(t, requestID2, result.requestID) + } +} diff --git a/sweep/interface.go b/sweep/interface.go index e58cc8507c..a6e5d21537 100644 --- a/sweep/interface.go +++ b/sweep/interface.go @@ -46,4 +46,9 @@ type Wallet interface { // policies and returns an error if it cannot be accepted into the // mempool. CheckMempoolAcceptance(tx *wire.MsgTx) error + + // GetTransactionDetails returns a detailed description of a tx given + // its transaction hash. + GetTransactionDetails(txHash *chainhash.Hash) ( + *lnwallet.TransactionDetail, error) } diff --git a/sweep/mock_test.go b/sweep/mock_test.go index 86edbacefc..6b23953c3a 100644 --- a/sweep/mock_test.go +++ b/sweep/mock_test.go @@ -175,6 +175,14 @@ func (b *mockBackend) FetchTx(chainhash.Hash) (*wire.MsgTx, error) { func (b *mockBackend) CancelRebroadcast(tx chainhash.Hash) { } +// GetTransactionDetails returns a detailed description of a tx given its +// transaction hash. +func (b *mockBackend) GetTransactionDetails(txHash *chainhash.Hash) ( + *lnwallet.TransactionDetail, error) { + + return nil, nil +} + // mockFeeEstimator implements a mock fee estimator. It closely resembles // lnwallet.StaticFeeEstimator with the addition that fees can be changed for // testing purposes in a thread safe manner. @@ -418,6 +426,20 @@ func (m *MockWallet) CancelRebroadcast(tx chainhash.Hash) { m.Called(tx) } +// GetTransactionDetails returns a detailed description of a tx given its +// transaction hash. +func (m *MockWallet) GetTransactionDetails(txHash *chainhash.Hash) ( + *lnwallet.TransactionDetail, error) { + + args := m.Called(txHash) + + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).(*lnwallet.TransactionDetail), args.Error(1) +} + // MockInputSet is a mock implementation of the InputSet interface. type MockInputSet struct { mock.Mock From 336533349adca6d873796fd2315e9fe29bc829df Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Sun, 17 Mar 2024 14:11:04 +0800 Subject: [PATCH 15/19] lnd: init publisher when creating new server --- server.go | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/server.go b/server.go index 25db38b664..29a82cdc6b 100644 --- a/server.go +++ b/server.go @@ -326,6 +326,9 @@ type server struct { customMessageServer *subscribe.Server + // txPublisher is a publisher with fee-bumping capability. + txPublisher *sweep.TxPublisher + quit chan struct{} wg sync.WaitGroup @@ -1065,6 +1068,13 @@ func newServer(cfg *Config, listenAddrs []net.Addr, sweep.DefaultMaxInputsPerTx, ) + s.txPublisher = sweep.NewTxPublisher(sweep.TxPublisherConfig{ + Signer: cc.Wallet.Cfg.Signer, + Wallet: cc.Wallet, + Estimator: cc.FeeEstimator, + Notifier: cc.ChainNotifier, + }) + s.sweeper = sweep.New(&sweep.UtxoSweeperConfig{ FeeEstimator: cc.FeeEstimator, GenSweepScript: newSweepPkScriptGen(cc.Wallet), @@ -1077,6 +1087,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, MaxSweepAttempts: sweep.DefaultMaxSweepAttempts, MaxFeeRate: cfg.Sweeper.MaxFeeRate, Aggregator: aggregator, + Publisher: s.txPublisher, }) s.utxoNursery = contractcourt.NewUtxoNursery(&contractcourt.NurseryConfig{ @@ -1931,6 +1942,15 @@ func (s *server) Start() error { cleanup = cleanup.add(s.towerClientMgr.Stop) } + if err := s.txPublisher.Start(); err != nil { + startErr = err + return + } + cleanup = cleanup.add(func() error { + s.txPublisher.Stop() + return nil + }) + if err := s.sweeper.Start(); err != nil { startErr = err return @@ -2264,6 +2284,9 @@ func (s *server) Stop() error { if err := s.sweeper.Stop(); err != nil { srvrLog.Warnf("failed to stop sweeper: %v", err) } + + s.txPublisher.Stop() + if err := s.channelNotifier.Stop(); err != nil { srvrLog.Warnf("failed to stop channelNotifier: %v", err) } From 01033458cfef7583fd5bb92435cf784f78ee36a8 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Mon, 18 Mar 2024 11:35:55 +0800 Subject: [PATCH 16/19] sweep: increase delta fee rate precision in fee function This commit adds a private type `mSatPerKWeight` that expresses a given fee rate in millisatoshi per kw. This is needed to increase the precision of the fee function. When sweeping anchor inputs, if using a deadline delta of over 1000, it's likely the delta will be 0 sat/kw due to precision. --- sweep/fee_function.go | 35 +++++++++++++++++++++++++++++------ sweep/fee_function_test.go | 6 +++--- 2 files changed, 32 insertions(+), 9 deletions(-) diff --git a/sweep/fee_function.go b/sweep/fee_function.go index db3fb7851c..955ca43a6c 100644 --- a/sweep/fee_function.go +++ b/sweep/fee_function.go @@ -6,6 +6,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/lightningnetwork/lnd/lnwallet/chainfee" + "github.com/lightningnetwork/lnd/lnwire" ) var ( @@ -14,6 +15,17 @@ var ( ErrMaxPosition = errors.New("position already at max") ) +// mSatPerKWeight represents a fee rate in msat/kw. +// +// TODO(yy): unify all the units to be virtual bytes. +type mSatPerKWeight lnwire.MilliSatoshi + +// String returns a human-readable string of the fee rate. +func (m mSatPerKWeight) String() string { + s := lnwire.MilliSatoshi(m) + return fmt.Sprintf("%v/kw", s) +} + // FeeFunction defines an interface that is used to calculate fee rates for // transactions. It's expected the implementations use three params, the // starting fee rate, the ending fee rate, and number of blocks till deadline @@ -80,8 +92,10 @@ type LinearFeeFunction struct { // and the current block height. position uint32 - // deltaFeeRate is the fee rate increase per block. - deltaFeeRate chainfee.SatPerKWeight + // deltaFeeRate is the fee rate (msat/kw) increase per block. + // + // NOTE: this is used to increase precision. + deltaFeeRate mSatPerKWeight // estimator is the fee estimator used to estimate the fee rate. We use // it to get the initial fee rate and, use it as a benchmark to decide @@ -121,21 +135,28 @@ func NewLinearFeeFunction(maxFeeRate chainfee.SatPerKWeight, confTarget uint32, // Calculate how much fee rate should be increased per block. end := l.endingFeeRate - delta := btcutil.Amount(end - start).MulF64(1 / float64(confTarget)) + + // The starting and ending fee rates are in sat/kw, so we need to + // convert them to msat/kw by multiplying by 1000. + delta := btcutil.Amount(end - start).MulF64(1000 / float64(confTarget)) + l.deltaFeeRate = mSatPerKWeight(delta) // We only allow the delta to be zero if the width is one - when the // delta is zero, it means the starting and ending fee rates are the // same, which means there's nothing to increase, so any width greater // than 1 doesn't provide any utility. This could happen when the // sweeper is offered to sweep an input that has passed its deadline. - if delta == 0 && l.width != 1 { + if l.deltaFeeRate == 0 && l.width != 1 { + log.Errorf("Failed to init fee function: startingFeeRate=%v, "+ + "endingFeeRate=%v, width=%v, delta=%v", start, end, + confTarget, l.deltaFeeRate) + return nil, fmt.Errorf("fee rate delta is zero") } // Attach the calculated values to the fee function. l.startingFeeRate = start l.currentFeeRate = start - l.deltaFeeRate = chainfee.SatPerKWeight(delta) log.Debugf("Linear fee function initialized with startingFeeRate=%v, "+ "endingFeeRate=%v, width=%v, delta=%v", start, end, @@ -234,7 +255,9 @@ func (l *LinearFeeFunction) feeRateAtPosition(p uint32) chainfee.SatPerKWeight { return l.endingFeeRate } - feeRateDelta := btcutil.Amount(l.deltaFeeRate).MulF64(float64(p)) + // deltaFeeRate is in msat/kw, so we need to divide by 1000 to get the + // fee rate in sat/kw. + feeRateDelta := btcutil.Amount(l.deltaFeeRate).MulF64(float64(p) / 1000) feeRate := l.startingFeeRate + chainfee.SatPerKWeight(feeRateDelta) if feeRate > l.endingFeeRate { diff --git a/sweep/fee_function_test.go b/sweep/fee_function_test.go index e7f80819aa..e549d8d643 100644 --- a/sweep/fee_function_test.go +++ b/sweep/fee_function_test.go @@ -54,8 +54,8 @@ func TestLinearFeeFunctionNew(t *testing.T) { // // Mock the fee estimator to return the fee rate. estimator.On("EstimateFeePerKW", confTarget).Return( - // The starting fee rate is 1 sat/kw less than the max fee rate. - maxFeeRate-1, nil).Once() + // The starting fee rate is the max fee rate. + maxFeeRate, nil).Once() estimator.On("RelayFeePerKW").Return(estimatedFeeRate).Once() f, err = NewLinearFeeFunction(maxFeeRate, confTarget, estimator) @@ -96,7 +96,7 @@ func TestLinearFeeFunctionFeeRateAtPosition(t *testing.T) { startingFeeRate: 1000, endingFeeRate: 3000, position: 0, - deltaFeeRate: 1000, + deltaFeeRate: 1_000_000, width: 3, } From 83729e25be622056e40846f2e1cc6da3ce4ae89c Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Sun, 17 Mar 2024 18:19:06 +0800 Subject: [PATCH 17/19] itest: fix existing itests --- itest/lnd_channel_force_close_test.go | 3 ++- lntest/fee_service.go | 13 +++++++++++++ lntest/harness.go | 2 +- 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/itest/lnd_channel_force_close_test.go b/itest/lnd_channel_force_close_test.go index 3f73c17a87..cc2e62970c 100644 --- a/itest/lnd_channel_force_close_test.go +++ b/itest/lnd_channel_force_close_test.go @@ -451,7 +451,8 @@ func channelForceClosureTest(ht *lntest.HarnessTest, // Allow some deviation because weight estimates during tx generation // are estimates. - require.InEpsilon(ht, expectedFeeRate, feeRate, 0.005) + require.InEpsilonf(ht, expectedFeeRate, feeRate, 0.005, "fee rate not "+ + "match: want %v, got %v", expectedFeeRate, feeRate) // Find alice's commit sweep and anchor sweep (if present) in the // mempool. diff --git a/lntest/fee_service.go b/lntest/fee_service.go index 49bd953ac2..d96bd75889 100644 --- a/lntest/fee_service.go +++ b/lntest/fee_service.go @@ -32,6 +32,9 @@ type WebFeeService interface { // SetFeeRate sets the estimated fee rate for a given confirmation // target. SetFeeRate(feeRate chainfee.SatPerKWeight, conf uint32) + + // Reset resets the fee rate map to the default value. + Reset() } const ( @@ -140,6 +143,16 @@ func (f *FeeService) SetFeeRate(fee chainfee.SatPerKWeight, conf uint32) { f.feeRateMap[conf] = uint32(fee.FeePerKVByte()) } +// Reset resets the fee rate map to the default value. +func (f *FeeService) Reset() { + f.lock.Lock() + f.feeRateMap = make(map[uint32]uint32) + f.lock.Unlock() + + // Initialize default fee estimate. + f.SetFeeRate(DefaultFeeRateSatPerKw, 1) +} + // URL returns the service endpoint. func (f *FeeService) URL() string { return f.url diff --git a/lntest/harness.go b/lntest/harness.go index 6960d06e17..78a2258822 100644 --- a/lntest/harness.go +++ b/lntest/harness.go @@ -396,7 +396,7 @@ func (h *HarnessTest) Subtest(t *testing.T) *HarnessTest { st.resetStandbyNodes(t) // Reset fee estimator. - st.SetFeeEstimate(DefaultFeeRateSatPerKw) + st.feeService.Reset() // Record block height. _, startHeight := h.Miner.GetBestBlock() From 1bd2589fe4efb05bee44948116d03483f97ed73f Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Sun, 17 Mar 2024 13:56:27 +0800 Subject: [PATCH 18/19] docs: update release notes for fee bumper --- docs/release-notes/release-notes-0.18.0.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/release-notes/release-notes-0.18.0.md b/docs/release-notes/release-notes-0.18.0.md index 9aacfaffef..502e797f10 100644 --- a/docs/release-notes/release-notes-0.18.0.md +++ b/docs/release-notes/release-notes-0.18.0.md @@ -189,6 +189,9 @@ bitcoin peers' feefilter values into account](https://github.com/lightningnetwor * [Preparatory work](https://github.com/lightningnetwork/lnd/pull/8159) for forwarding of blinded routes was added. +* Introduced [fee bumper](https://github.com/lightningnetwork/lnd/pull/8424) to + handle bumping the fees of sweeping transactions properly. + ## RPC Additions * [Deprecated](https://github.com/lightningnetwork/lnd/pull/7175) From 4bf98e7d5ab3132d8ba319df94a4fbecff716ca3 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Wed, 27 Mar 2024 03:45:07 +0800 Subject: [PATCH 19/19] sweep: make sure non-fee related errors are notified So these inputs can be retried by the sweeper. --- sweep/fee_bumper.go | 77 ++++++++++++++++++++++++---------------- sweep/fee_bumper_test.go | 32 +++++++++++------ 2 files changed, 68 insertions(+), 41 deletions(-) diff --git a/sweep/fee_bumper.go b/sweep/fee_bumper.go index 46e9323592..58a7f8b454 100644 --- a/sweep/fee_bumper.go +++ b/sweep/fee_bumper.go @@ -13,6 +13,7 @@ import ( "github.com/btcsuite/btcwallet/chain" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/labels" "github.com/lightningnetwork/lnd/lnutils" @@ -815,22 +816,20 @@ func (t *TxPublisher) handleFeeBumpTx(requestID uint64, r *monitorRecord, // The fee function now has a new fee rate, we will use it to bump the // fee of the tx. - result, err := t.createAndPublishTx(requestID, r) - if err != nil { - log.Errorf("Failed to bump tx %v: %v", oldTxid, err) - - return - } + resultOpt := t.createAndPublishTx(requestID, r) - // Notify the new result. - t.handleResult(result) + // If there's a result, we will notify the caller about the result. + resultOpt.WhenSome(func(result BumpResult) { + // Notify the new result. + t.handleResult(&result) + }) } // createAndPublishTx creates a new tx with a higher fee rate and publishes it // to the network. It will update the record with the new tx and fee rate if // successfully created, and return the result when published successfully. func (t *TxPublisher) createAndPublishTx(requestID uint64, - r *monitorRecord) (*BumpResult, error) { + r *monitorRecord) fn.Option[BumpResult] { // Fetch the old tx. oldTx := r.tx @@ -842,21 +841,8 @@ func (t *TxPublisher) createAndPublishTx(requestID uint64, // directly here. tx, fee, err := t.createAndCheckTx(r.req, r.feeFunction) - // If the tx doesn't not have enought budget, we will return a result - // so the sweeper can handle it by re-clustering the utxos. - if errors.Is(err, ErrNotEnoughBudget) { - log.Warnf("Fail to fee bump tx %v: %v", oldTx.TxHash(), err) - - return &BumpResult{ - Event: TxFailed, - Tx: oldTx, - Err: err, - requestID: requestID, - }, nil - } - - // If the error is not budget related, we will return an error and let - // the fee bumper retry it at next block. + // If the error is fee related, we will return an error and let the fee + // bumper retry it at next block. // // NOTE: we can check the RBF error here and ask the fee function to // recalculate the fee rate. However, this would defeat the purpose of @@ -865,12 +851,40 @@ func (t *TxPublisher) createAndPublishTx(requestID uint64, // - if the deadline is close, we expect the fee function to give us a // higher fee rate. If the fee rate cannot satisfy the RBF rules, it // means the budget is not enough. + if errors.Is(err, rpcclient.ErrInsufficientFee) || + errors.Is(err, lnwallet.ErrMempoolFee) { + + log.Debugf("Failed to bump tx %v: %v", oldTx.TxHash(), err) + return fn.None[BumpResult]() + } + + // If the error is not fee related, we will return a `TxFailed` event + // so this input can be retried. if err != nil { - log.Infof("Failed to bump tx %v: %v", oldTx.TxHash(), err) - return nil, err + // If the tx doesn't not have enought budget, we will return a + // result so the sweeper can handle it by re-clustering the + // utxos. + if errors.Is(err, ErrNotEnoughBudget) { + log.Warnf("Fail to fee bump tx %v: %v", oldTx.TxHash(), + err) + } else { + // Otherwise, an unexpected error occurred, we will + // fail the tx and let the sweeper retry the whole + // process. + log.Errorf("Failed to bump tx %v: %v", oldTx.TxHash(), + err) + } + + return fn.Some(BumpResult{ + Event: TxFailed, + Tx: oldTx, + Err: err, + requestID: requestID, + }) } - // Register a new record by overwriting the same requestID. + // The tx has been created without any errors, we now register a new + // record by overwriting the same requestID. t.records.Store(requestID, &monitorRecord{ tx: tx, req: r.req, @@ -881,7 +895,10 @@ func (t *TxPublisher) createAndPublishTx(requestID uint64, // Attempt to broadcast this new tx. result, err := t.broadcast(requestID) if err != nil { - return nil, err + log.Infof("Failed to broadcast replacement tx %v: %v", + tx.TxHash(), err) + + return fn.None[BumpResult]() } // A successful replacement tx is created, attach the old tx. @@ -890,7 +907,7 @@ func (t *TxPublisher) createAndPublishTx(requestID uint64, // If the new tx failed to be published, we will return the result so // the caller can handle it. if result.Event == TxFailed { - return result, nil + return fn.Some(*result) } log.Infof("Replaced tx=%v with new tx=%v", oldTx.TxHash(), tx.TxHash()) @@ -898,7 +915,7 @@ func (t *TxPublisher) createAndPublishTx(requestID uint64, // Otherwise, it's a successful RBF, set the event and return. result.Event = TxReplaced - return result, nil + return fn.Some(*result) } // isConfirmed checks the btcwallet to see whether the tx is confirmed. diff --git a/sweep/fee_bumper_test.go b/sweep/fee_bumper_test.go index f3b67f3bd9..5f031a9bff 100644 --- a/sweep/fee_bumper_test.go +++ b/sweep/fee_bumper_test.go @@ -1027,8 +1027,8 @@ func TestCreateAnPublishFail(t *testing.T) { mock.Anything).Return(script, nil) // Call the createAndPublish method. - result, err := tp.createAndPublishTx(requestID, record) - require.NoError(t, err) + resultOpt := tp.createAndPublishTx(requestID, record) + result := resultOpt.UnwrapOrFail(t) // We expect the result to be TxFailed and the error is set in the // result. @@ -1040,14 +1040,23 @@ func TestCreateAnPublishFail(t *testing.T) { // error to be returned from CheckMempoolAcceptance. req.Budget = 1000 - // Mock the testmempoolaccept to return an error. + // Mock the testmempoolaccept to return a fee related error that should + // be ignored. m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(lnwallet.ErrMempoolFee).Once() - // Call the createAndPublish method and expect an error. - result, err = tp.createAndPublishTx(requestID, record) - require.ErrorIs(t, err, lnwallet.ErrMempoolFee) - require.Nil(t, result) + // Call the createAndPublish method and expect a none option. + resultOpt = tp.createAndPublishTx(requestID, record) + require.True(t, resultOpt.IsNone()) + + // Mock the testmempoolaccept to return a fee related error that should + // be ignored. + m.wallet.On("CheckMempoolAcceptance", + mock.Anything).Return(rpcclient.ErrInsufficientFee).Once() + + // Call the createAndPublish method and expect a none option. + resultOpt = tp.createAndPublishTx(requestID, record) + require.True(t, resultOpt.IsNone()) } // TestCreateAnPublishSuccess checks the expected result is returned from the @@ -1090,8 +1099,8 @@ func TestCreateAnPublishSuccess(t *testing.T) { mock.Anything, mock.Anything).Return(errDummy).Once() // Call the createAndPublish method and expect a failure result. - result, err := tp.createAndPublishTx(requestID, record) - require.NoError(t, err) + resultOpt := tp.createAndPublishTx(requestID, record) + result := resultOpt.UnwrapOrFail(t) // We expect the result to be TxFailed and the error is set. require.Equal(t, TxFailed, result.Event) @@ -1111,8 +1120,9 @@ func TestCreateAnPublishSuccess(t *testing.T) { mock.Anything, mock.Anything).Return(nil).Once() // Call the createAndPublish method and expect a success result. - result, err = tp.createAndPublishTx(requestID, record) - require.NoError(t, err) + resultOpt = tp.createAndPublishTx(requestID, record) + result = resultOpt.UnwrapOrFail(t) + require.True(t, resultOpt.IsSome()) // We expect the result to be TxReplaced and the error is nil. require.Equal(t, TxReplaced, result.Event)