diff --git a/chainntnfs/mocks.go b/chainntnfs/mocks.go index 31b75d46f2..d9ab9928d0 100644 --- a/chainntnfs/mocks.go +++ b/chainntnfs/mocks.go @@ -1,6 +1,7 @@ package chainntnfs import ( + "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/fn" "github.com/stretchr/testify/mock" @@ -50,3 +51,73 @@ func (m *MockMempoolWatcher) LookupInputMempoolSpend( return args.Get(0).(fn.Option[wire.MsgTx]) } + +// MockNotifier is a mock implementation of the ChainNotifier interface. +type MockChainNotifier struct { + mock.Mock +} + +// Compile-time check to ensure MockChainNotifier implements ChainNotifier. +var _ ChainNotifier = (*MockChainNotifier)(nil) + +// RegisterConfirmationsNtfn registers an intent to be notified once txid +// reaches numConfs confirmations. +func (m *MockChainNotifier) RegisterConfirmationsNtfn(txid *chainhash.Hash, + pkScript []byte, numConfs, heightHint uint32, + opts ...NotifierOption) (*ConfirmationEvent, error) { + + args := m.Called(txid, pkScript, numConfs, heightHint) + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).(*ConfirmationEvent), args.Error(1) +} + +// RegisterSpendNtfn registers an intent to be notified once the target +// outpoint is successfully spent within a transaction. +func (m *MockChainNotifier) RegisterSpendNtfn(outpoint *wire.OutPoint, + pkScript []byte, heightHint uint32) (*SpendEvent, error) { + + args := m.Called(outpoint, pkScript, heightHint) + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).(*SpendEvent), args.Error(1) +} + +// RegisterBlockEpochNtfn registers an intent to be notified of each new block +// connected to the tip of the main chain. +func (m *MockChainNotifier) RegisterBlockEpochNtfn(epoch *BlockEpoch) ( + *BlockEpochEvent, error) { + + args := m.Called(epoch) + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).(*BlockEpochEvent), args.Error(1) +} + +// Start the ChainNotifier. Once started, the implementation should be ready, +// and able to receive notification registrations from clients. +func (m *MockChainNotifier) Start() error { + args := m.Called() + + return args.Error(0) +} + +// Started returns true if this instance has been started, and false otherwise. +func (m *MockChainNotifier) Started() bool { + args := m.Called() + + return args.Bool(0) +} + +// Stops the concrete ChainNotifier. +func (m *MockChainNotifier) Stop() error { + args := m.Called() + + return args.Error(0) +} diff --git a/contractcourt/utxonursery.go b/contractcourt/utxonursery.go index 6b8742255b..c2f0264d35 100644 --- a/contractcourt/utxonursery.go +++ b/contractcourt/utxonursery.go @@ -1406,6 +1406,10 @@ func (k *kidOutput) ConfHeight() uint32 { return k.confHeight } +func (k *kidOutput) RequiredLockTime() (uint32, bool) { + return k.absoluteMaturity, k.absoluteMaturity > 0 +} + // Encode converts a KidOutput struct into a form suitable for on-disk database // storage. Note that the signDescriptor struct field is included so that the // output's witness can be generated by createSweepTx() when the output becomes diff --git a/docs/release-notes/release-notes-0.18.0.md b/docs/release-notes/release-notes-0.18.0.md index 9aacfaffef..502e797f10 100644 --- a/docs/release-notes/release-notes-0.18.0.md +++ b/docs/release-notes/release-notes-0.18.0.md @@ -189,6 +189,9 @@ bitcoin peers' feefilter values into account](https://github.com/lightningnetwor * [Preparatory work](https://github.com/lightningnetwork/lnd/pull/8159) for forwarding of blinded routes was added. +* Introduced [fee bumper](https://github.com/lightningnetwork/lnd/pull/8424) to + handle bumping the fees of sweeping transactions properly. + ## RPC Additions * [Deprecated](https://github.com/lightningnetwork/lnd/pull/7175) diff --git a/input/mocks.go b/input/mocks.go index 965489effb..23ce6930ec 100644 --- a/input/mocks.go +++ b/input/mocks.go @@ -1,8 +1,14 @@ package input import ( + "crypto/sha256" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcec/v2/schnorr" + "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/keychain" "github.com/stretchr/testify/mock" ) @@ -123,3 +129,145 @@ func (m *MockInput) UnconfParent() *TxInfo { return info.(*TxInfo) } + +// MockWitnessType implements the `WitnessType` interface and is used by other +// packages for mock testing. +type MockWitnessType struct { + mock.Mock +} + +// Compile time assertion that MockWitnessType implements WitnessType. +var _ WitnessType = (*MockWitnessType)(nil) + +// String returns a human readable version of the WitnessType. +func (m *MockWitnessType) String() string { + args := m.Called() + + return args.String(0) +} + +// WitnessGenerator will return a WitnessGenerator function that an output uses +// to generate the witness and optionally the sigScript for a sweep +// transaction. +func (m *MockWitnessType) WitnessGenerator(signer Signer, + descriptor *SignDescriptor) WitnessGenerator { + + args := m.Called() + + return args.Get(0).(WitnessGenerator) +} + +// SizeUpperBound returns the maximum length of the witness of this WitnessType +// if it would be included in a tx. It also returns if the output itself is a +// nested p2sh output, if so then we need to take into account the extra +// sigScript data size. +func (m *MockWitnessType) SizeUpperBound() (int, bool, error) { + args := m.Called() + + return args.Int(0), args.Bool(1), args.Error(2) +} + +// AddWeightEstimation adds the estimated size of the witness in bytes to the +// given weight estimator. +func (m *MockWitnessType) AddWeightEstimation(e *TxWeightEstimator) error { + args := m.Called() + + return args.Error(0) +} + +// MockInputSigner is a mock implementation of the Signer interface. +type MockInputSigner struct { + mock.Mock +} + +// Compile-time constraint to ensure MockInputSigner implements Signer. +var _ Signer = (*MockInputSigner)(nil) + +// SignOutputRaw generates a signature for the passed transaction according to +// the data within the passed SignDescriptor. +func (m *MockInputSigner) SignOutputRaw(tx *wire.MsgTx, + signDesc *SignDescriptor) (Signature, error) { + + args := m.Called(tx, signDesc) + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).(Signature), args.Error(1) +} + +// ComputeInputScript generates a complete InputIndex for the passed +// transaction with the signature as defined within the passed SignDescriptor. +func (m *MockInputSigner) ComputeInputScript(tx *wire.MsgTx, + signDesc *SignDescriptor) (*Script, error) { + + args := m.Called(tx, signDesc) + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).(*Script), args.Error(1) +} + +// MuSig2CreateSession creates a new MuSig2 signing session using the local key +// identified by the key locator. +func (m *MockInputSigner) MuSig2CreateSession(version MuSig2Version, + locator keychain.KeyLocator, pubkey []*btcec.PublicKey, + tweak *MuSig2Tweaks, pubNonces [][musig2.PubNonceSize]byte, + nonces *musig2.Nonces) (*MuSig2SessionInfo, error) { + + args := m.Called(version, locator, pubkey, tweak, pubNonces, nonces) + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).(*MuSig2SessionInfo), args.Error(1) +} + +// MuSig2RegisterNonces registers one or more public nonces of other signing +// participants for a session identified by its ID. +func (m *MockInputSigner) MuSig2RegisterNonces(versio MuSig2SessionID, + pubNonces [][musig2.PubNonceSize]byte) (bool, error) { + + args := m.Called(versio, pubNonces) + if args.Get(0) == nil { + return false, args.Error(1) + } + + return args.Bool(0), args.Error(1) +} + +// MuSig2Sign creates a partial signature using the local signing key that was +// specified when the session was created. +func (m *MockInputSigner) MuSig2Sign(sessionID MuSig2SessionID, + msg [sha256.Size]byte, withSortedKeys bool) ( + *musig2.PartialSignature, error) { + + args := m.Called(sessionID, msg, withSortedKeys) + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).(*musig2.PartialSignature), args.Error(1) +} + +// MuSig2CombineSig combines the given partial signature(s) with the local one, +// if it already exists. +func (m *MockInputSigner) MuSig2CombineSig(sessionID MuSig2SessionID, + partialSig []*musig2.PartialSignature) ( + *schnorr.Signature, bool, error) { + + args := m.Called(sessionID, partialSig) + if args.Get(0) == nil { + return nil, false, args.Error(2) + } + + return args.Get(0).(*schnorr.Signature), args.Bool(1), args.Error(2) +} + +// MuSig2Cleanup removes a session from memory to free up resources. +func (m *MockInputSigner) MuSig2Cleanup(sessionID MuSig2SessionID) error { + args := m.Called(sessionID) + + return args.Error(0) +} diff --git a/itest/lnd_channel_force_close_test.go b/itest/lnd_channel_force_close_test.go index 3f73c17a87..cc2e62970c 100644 --- a/itest/lnd_channel_force_close_test.go +++ b/itest/lnd_channel_force_close_test.go @@ -451,7 +451,8 @@ func channelForceClosureTest(ht *lntest.HarnessTest, // Allow some deviation because weight estimates during tx generation // are estimates. - require.InEpsilon(ht, expectedFeeRate, feeRate, 0.005) + require.InEpsilonf(ht, expectedFeeRate, feeRate, 0.005, "fee rate not "+ + "match: want %v, got %v", expectedFeeRate, feeRate) // Find alice's commit sweep and anchor sweep (if present) in the // mempool. diff --git a/lnmock/chain.go b/lnmock/chain.go new file mode 100644 index 0000000000..dd208c33e2 --- /dev/null +++ b/lnmock/chain.go @@ -0,0 +1,159 @@ +package lnmock + +import ( + "github.com/btcsuite/btcd/btcjson" + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btcwallet/chain" + "github.com/btcsuite/btcwallet/waddrmgr" + "github.com/stretchr/testify/mock" +) + +// MockChain is a mock implementation of the Chain interface. +type MockChain struct { + mock.Mock +} + +// Compile-time constraint to ensure MockChain implements the Chain interface. +var _ chain.Interface = (*MockChain)(nil) + +func (m *MockChain) Start() error { + args := m.Called() + + return args.Error(0) +} + +func (m *MockChain) Stop() { + m.Called() +} + +func (m *MockChain) WaitForShutdown() { + m.Called() +} + +func (m *MockChain) GetBestBlock() (*chainhash.Hash, int32, error) { + args := m.Called() + + if args.Get(0) == nil { + return nil, args.Get(1).(int32), args.Error(2) + } + + return args.Get(0).(*chainhash.Hash), args.Get(1).(int32), args.Error(2) +} + +func (m *MockChain) GetBlock(hash *chainhash.Hash) (*wire.MsgBlock, error) { + args := m.Called(hash) + + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).(*wire.MsgBlock), args.Error(1) +} + +func (m *MockChain) GetBlockHash(height int64) (*chainhash.Hash, error) { + args := m.Called(height) + + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).(*chainhash.Hash), args.Error(1) +} + +func (m *MockChain) GetBlockHeader(hash *chainhash.Hash) ( + *wire.BlockHeader, error) { + + args := m.Called(hash) + + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).(*wire.BlockHeader), args.Error(1) +} + +func (m *MockChain) IsCurrent() bool { + args := m.Called() + + return args.Bool(0) +} + +func (m *MockChain) FilterBlocks(req *chain.FilterBlocksRequest) ( + *chain.FilterBlocksResponse, error) { + + args := m.Called(req) + + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).(*chain.FilterBlocksResponse), args.Error(1) +} + +func (m *MockChain) BlockStamp() (*waddrmgr.BlockStamp, error) { + args := m.Called() + + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).(*waddrmgr.BlockStamp), args.Error(1) +} + +func (m *MockChain) SendRawTransaction(tx *wire.MsgTx, allowHighFees bool) ( + *chainhash.Hash, error) { + + args := m.Called(tx, allowHighFees) + + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).(*chainhash.Hash), args.Error(1) +} + +func (m *MockChain) Rescan(startHash *chainhash.Hash, addrs []btcutil.Address, + outPoints map[wire.OutPoint]btcutil.Address) error { + + args := m.Called(startHash, addrs, outPoints) + + return args.Error(0) +} + +func (m *MockChain) NotifyReceived(addrs []btcutil.Address) error { + args := m.Called(addrs) + + return args.Error(0) +} + +func (m *MockChain) NotifyBlocks() error { + args := m.Called() + + return args.Error(0) +} + +func (m *MockChain) Notifications() <-chan interface{} { + args := m.Called() + + return args.Get(0).(<-chan interface{}) +} + +func (m *MockChain) BackEnd() string { + args := m.Called() + + return args.String(0) +} + +func (m *MockChain) TestMempoolAccept(txns []*wire.MsgTx, maxFeeRate float64) ( + []*btcjson.TestMempoolAcceptResult, error) { + + args := m.Called(txns, maxFeeRate) + + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).([]*btcjson.TestMempoolAcceptResult), args.Error(1) +} diff --git a/lntest/fee_service.go b/lntest/fee_service.go index 49bd953ac2..d96bd75889 100644 --- a/lntest/fee_service.go +++ b/lntest/fee_service.go @@ -32,6 +32,9 @@ type WebFeeService interface { // SetFeeRate sets the estimated fee rate for a given confirmation // target. SetFeeRate(feeRate chainfee.SatPerKWeight, conf uint32) + + // Reset resets the fee rate map to the default value. + Reset() } const ( @@ -140,6 +143,16 @@ func (f *FeeService) SetFeeRate(fee chainfee.SatPerKWeight, conf uint32) { f.feeRateMap[conf] = uint32(fee.FeePerKVByte()) } +// Reset resets the fee rate map to the default value. +func (f *FeeService) Reset() { + f.lock.Lock() + f.feeRateMap = make(map[uint32]uint32) + f.lock.Unlock() + + // Initialize default fee estimate. + f.SetFeeRate(DefaultFeeRateSatPerKw, 1) +} + // URL returns the service endpoint. func (f *FeeService) URL() string { return f.url diff --git a/lntest/harness.go b/lntest/harness.go index 6960d06e17..78a2258822 100644 --- a/lntest/harness.go +++ b/lntest/harness.go @@ -396,7 +396,7 @@ func (h *HarnessTest) Subtest(t *testing.T) *HarnessTest { st.resetStandbyNodes(t) // Reset fee estimator. - st.SetFeeEstimate(DefaultFeeRateSatPerKw) + st.feeService.Reset() // Record block height. _, startHeight := h.Miner.GetBestBlock() diff --git a/lntest/mock/walletcontroller.go b/lntest/mock/walletcontroller.go index 6d09acd54f..21d78add37 100644 --- a/lntest/mock/walletcontroller.go +++ b/lntest/mock/walletcontroller.go @@ -282,3 +282,7 @@ func (w *WalletController) FetchTx(chainhash.Hash) (*wire.MsgTx, error) { func (w *WalletController) RemoveDescendants(*wire.MsgTx) error { return nil } + +func (w *WalletController) CheckMempoolAcceptance(tx *wire.MsgTx) error { + return nil +} diff --git a/lnwallet/btcwallet/btcwallet.go b/lnwallet/btcwallet/btcwallet.go index ec4bc5d9b5..ebca031c54 100644 --- a/lnwallet/btcwallet/btcwallet.go +++ b/lnwallet/btcwallet/btcwallet.go @@ -1898,3 +1898,34 @@ func (b *BtcWallet) RemoveDescendants(tx *wire.MsgTx) error { return b.wallet.TxStore.RemoveUnminedTx(wtxmgrNs, txRecord) }) } + +// CheckMempoolAcceptance is a wrapper around `TestMempoolAccept` which checks +// the mempool acceptance of a transaction. +func (b *BtcWallet) CheckMempoolAcceptance(tx *wire.MsgTx) error { + // Use a max feerate of 0 means the default value will be used when + // testing mempool acceptance. The default max feerate is 0.10 BTC/kvb, + // or 10,000 sat/vb. + results, err := b.chain.TestMempoolAccept([]*wire.MsgTx{tx}, 0) + if err != nil { + return err + } + + // Sanity check that the expected single result is returned. + if len(results) != 1 { + return fmt.Errorf("expected 1 result from TestMempoolAccept, "+ + "instead got %v", len(results)) + } + + result := results[0] + log.Debugf("TestMempoolAccept result: %s", spew.Sdump(result)) + + // Mempool check failed, we now map the reject reason to a proper RPC + // error and return it. + if !result.Allowed { + err := rpcclient.MapRPCErr(errors.New(result.RejectReason)) + + return fmt.Errorf("mempool rejection: %w", err) + } + + return nil +} diff --git a/lnwallet/btcwallet/btcwallet_test.go b/lnwallet/btcwallet/btcwallet_test.go index 28b783acc5..892ec25fdf 100644 --- a/lnwallet/btcwallet/btcwallet_test.go +++ b/lnwallet/btcwallet/btcwallet_test.go @@ -3,8 +3,12 @@ package btcwallet import ( "testing" + "github.com/btcsuite/btcd/btcjson" + "github.com/btcsuite/btcd/rpcclient" "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btcwallet/chain" "github.com/btcsuite/btcwallet/wallet" + "github.com/lightningnetwork/lnd/lnmock" "github.com/lightningnetwork/lnd/lnwallet" "github.com/stretchr/testify/require" ) @@ -132,3 +136,89 @@ func TestPreviousOutpoints(t *testing.T) { }) } } + +// TestCheckMempoolAcceptance asserts the CheckMempoolAcceptance behaves as +// expected. +func TestCheckMempoolAcceptance(t *testing.T) { + t.Parallel() + + rt := require.New(t) + + // Create a mock chain.Interface. + mockChain := &lnmock.MockChain{} + defer mockChain.AssertExpectations(t) + + // Create a test tx and a test max feerate. + tx := wire.NewMsgTx(2) + maxFeeRate := float64(0) + + // Create a test wallet. + wallet := &BtcWallet{ + chain: mockChain, + } + + // Assert that when the chain backend doesn't support + // `TestMempoolAccept`, an error is returned. + // + // Mock the chain backend to not support `TestMempoolAccept`. + mockChain.On("TestMempoolAccept", []*wire.MsgTx{tx}, maxFeeRate).Return( + nil, rpcclient.ErrBackendVersion).Once() + + err := wallet.CheckMempoolAcceptance(tx) + rt.ErrorIs(err, rpcclient.ErrBackendVersion) + + // Assert that when the chain backend doesn't implement + // `TestMempoolAccept`, an error is returned. + // + // Mock the chain backend to not support `TestMempoolAccept`. + mockChain.On("TestMempoolAccept", []*wire.MsgTx{tx}, maxFeeRate).Return( + nil, chain.ErrUnimplemented).Once() + + // Now call the method under test. + err = wallet.CheckMempoolAcceptance(tx) + rt.ErrorIs(err, chain.ErrUnimplemented) + + // Assert that when the returned results are not as expected, an error + // is returned. + // + // Mock the chain backend to return more than one result. + results := []*btcjson.TestMempoolAcceptResult{ + {Txid: "txid1", Allowed: true}, + {Txid: "txid2", Allowed: false}, + } + mockChain.On("TestMempoolAccept", []*wire.MsgTx{tx}, maxFeeRate).Return( + results, nil).Once() + + // Now call the method under test. + err = wallet.CheckMempoolAcceptance(tx) + rt.ErrorContains(err, "expected 1 result from TestMempoolAccept") + + // Assert that when the tx is rejected, the reason is converted to an + // RPC error and returned. + // + // Mock the chain backend to return one result. + results = []*btcjson.TestMempoolAcceptResult{{ + Txid: tx.TxHash().String(), + Allowed: false, + RejectReason: "insufficient fee", + }} + mockChain.On("TestMempoolAccept", []*wire.MsgTx{tx}, maxFeeRate).Return( + results, nil).Once() + + // Now call the method under test. + err = wallet.CheckMempoolAcceptance(tx) + rt.ErrorIs(err, rpcclient.ErrInsufficientFee) + + // Assert that when the tx is accepted, no error is returned. + // + // Mock the chain backend to return one result. + results = []*btcjson.TestMempoolAcceptResult{ + {Txid: tx.TxHash().String(), Allowed: true}, + } + mockChain.On("TestMempoolAccept", []*wire.MsgTx{tx}, maxFeeRate).Return( + results, nil).Once() + + // Now call the method under test. + err = wallet.CheckMempoolAcceptance(tx) + rt.NoError(err) +} diff --git a/lnwallet/chainfee/rates.go b/lnwallet/chainfee/rates.go index 6496b39c0d..98cefc13b5 100644 --- a/lnwallet/chainfee/rates.go +++ b/lnwallet/chainfee/rates.go @@ -58,6 +58,11 @@ func (s SatPerKVByte) String() string { // SatPerKWeight represents a fee rate in sat/kw. type SatPerKWeight btcutil.Amount +// NewSatPerKWeight creates a new fee rate in sat/kw. +func NewSatPerKWeight(fee btcutil.Amount, weight uint64) SatPerKWeight { + return SatPerKWeight(fee.MulF64(1000 / float64(weight))) +} + // FeeForWeight calculates the fee resulting from this fee rate and the given // weight in weight units (wu). func (s SatPerKWeight) FeeForWeight(wu int64) btcutil.Amount { diff --git a/lnwallet/interface.go b/lnwallet/interface.go index e26f4f2910..59e6f5aab0 100644 --- a/lnwallet/interface.go +++ b/lnwallet/interface.go @@ -536,6 +536,11 @@ type WalletController interface { // which could be e.g. btcd, bitcoind, neutrino, or another consensus // service. BackEnd() string + + // CheckMempoolAcceptance checks whether a transaction follows mempool + // policies and returns an error if it cannot be accepted into the + // mempool. + CheckMempoolAcceptance(tx *wire.MsgTx) error } // BlockChainIO is a dedicated source which will be used to obtain queries diff --git a/lnwallet/mock.go b/lnwallet/mock.go index f0f257ef0c..0146df57ea 100644 --- a/lnwallet/mock.go +++ b/lnwallet/mock.go @@ -294,6 +294,10 @@ func (w *mockWalletController) RemoveDescendants(*wire.MsgTx) error { return nil } +func (w *mockWalletController) CheckMempoolAcceptance(tx *wire.MsgTx) error { + return nil +} + // mockChainNotifier is a mock implementation of the ChainNotifier interface. type mockChainNotifier struct { SpendChan chan *chainntnfs.SpendDetail diff --git a/server.go b/server.go index 25db38b664..29a82cdc6b 100644 --- a/server.go +++ b/server.go @@ -326,6 +326,9 @@ type server struct { customMessageServer *subscribe.Server + // txPublisher is a publisher with fee-bumping capability. + txPublisher *sweep.TxPublisher + quit chan struct{} wg sync.WaitGroup @@ -1065,6 +1068,13 @@ func newServer(cfg *Config, listenAddrs []net.Addr, sweep.DefaultMaxInputsPerTx, ) + s.txPublisher = sweep.NewTxPublisher(sweep.TxPublisherConfig{ + Signer: cc.Wallet.Cfg.Signer, + Wallet: cc.Wallet, + Estimator: cc.FeeEstimator, + Notifier: cc.ChainNotifier, + }) + s.sweeper = sweep.New(&sweep.UtxoSweeperConfig{ FeeEstimator: cc.FeeEstimator, GenSweepScript: newSweepPkScriptGen(cc.Wallet), @@ -1077,6 +1087,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, MaxSweepAttempts: sweep.DefaultMaxSweepAttempts, MaxFeeRate: cfg.Sweeper.MaxFeeRate, Aggregator: aggregator, + Publisher: s.txPublisher, }) s.utxoNursery = contractcourt.NewUtxoNursery(&contractcourt.NurseryConfig{ @@ -1931,6 +1942,15 @@ func (s *server) Start() error { cleanup = cleanup.add(s.towerClientMgr.Stop) } + if err := s.txPublisher.Start(); err != nil { + startErr = err + return + } + cleanup = cleanup.add(func() error { + s.txPublisher.Stop() + return nil + }) + if err := s.sweeper.Start(); err != nil { startErr = err return @@ -2264,6 +2284,9 @@ func (s *server) Stop() error { if err := s.sweeper.Stop(); err != nil { srvrLog.Warnf("failed to stop sweeper: %v", err) } + + s.txPublisher.Stop() + if err := s.channelNotifier.Stop(); err != nil { srvrLog.Warnf("failed to stop channelNotifier: %v", err) } diff --git a/sweep/aggregator.go b/sweep/aggregator.go index 379ff98296..58ac511320 100644 --- a/sweep/aggregator.go +++ b/sweep/aggregator.go @@ -3,7 +3,10 @@ package sweep import ( "sort" + "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" ) @@ -461,3 +464,232 @@ func zipClusters(as, bs []inputCluster) []inputCluster { return finalClusters } + +// BudgetAggregator is a budget-based aggregator that creates clusters based on +// deadlines and budgets of inputs. +type BudgetAggregator struct { + // estimator is used when crafting sweep transactions to estimate the + // necessary fee relative to the expected size of the sweep + // transaction. + estimator chainfee.Estimator + + // maxInputs specifies the maximum number of inputs allowed in a single + // sweep tx. + maxInputs uint32 +} + +// Compile-time constraint to ensure BudgetAggregator implements UtxoAggregator. +var _ UtxoAggregator = (*BudgetAggregator)(nil) + +// NewBudgetAggregator creates a new instance of a BudgetAggregator. +func NewBudgetAggregator(estimator chainfee.Estimator, + maxInputs uint32) *BudgetAggregator { + + return &BudgetAggregator{ + estimator: estimator, + maxInputs: maxInputs, + } +} + +// clusterGroup defines an alias for a set of inputs that are to be grouped. +type clusterGroup map[fn.Option[int32]][]pendingInput + +// ClusterInputs creates a list of input sets from pending inputs. +// 1. filter out inputs whose budget cannot cover min relay fee. +// 2. group the inputs into clusters based on their deadline height. +// 3. sort the inputs in each cluster by their budget. +// 4. optionally split a cluster if it exceeds the max input limit. +// 5. create input sets from each of the clusters. +func (b *BudgetAggregator) ClusterInputs(inputs pendingInputs) []InputSet { + // Filter out inputs that have a budget below min relay fee. + filteredInputs := b.filterInputs(inputs) + + // Create clusters to group inputs based on their deadline height. + clusters := make(clusterGroup, len(filteredInputs)) + + // Iterate all the inputs and group them based on their specified + // deadline heights. + for _, input := range filteredInputs { + height := input.params.DeadlineHeight + cluster, ok := clusters[height] + if !ok { + cluster = make([]pendingInput, 0) + } + + cluster = append(cluster, *input) + clusters[height] = cluster + } + + // Now that we have the clusters, we can create the input sets. + // + // NOTE: cannot pre-allocate the slice since we don't know the number + // of input sets in advance. + inputSets := make([]InputSet, 0) + for _, cluster := range clusters { + // Sort the inputs by their economical value. + sortedInputs := b.sortInputs(cluster) + + // Create input sets from the cluster. + sets := b.createInputSets(sortedInputs) + inputSets = append(inputSets, sets...) + } + + return inputSets +} + +// createInputSet takes a set of inputs which share the same deadline height +// and turns them into a list of `InputSet`, each set is then used to create a +// sweep transaction. +func (b *BudgetAggregator) createInputSets(inputs []pendingInput) []InputSet { + // sets holds the InputSets that we will return. + sets := make([]InputSet, 0) + + // Copy the inputs to a new slice so we can modify it. + remainingInputs := make([]pendingInput, len(inputs)) + copy(remainingInputs, inputs) + + // If the number of inputs is greater than the max inputs allowed, we + // will split them into smaller clusters. + for uint32(len(remainingInputs)) > b.maxInputs { + log.Tracef("Cluster has %v inputs, max is %v, dividing...", + len(inputs), b.maxInputs) + + // Copy the inputs to be put into the new set, and update the + // remaining inputs by removing currentInputs. + currentInputs := make([]pendingInput, b.maxInputs) + copy(currentInputs, remainingInputs[:b.maxInputs]) + remainingInputs = remainingInputs[b.maxInputs:] + + // Create an InputSet using the max allowed number of inputs. + set, err := NewBudgetInputSet(currentInputs) + if err != nil { + log.Errorf("unable to create input set: %v", err) + + continue + } + + sets = append(sets, set) + } + + // Create an InputSet from the remaining inputs. + if len(remainingInputs) > 0 { + set, err := NewBudgetInputSet(remainingInputs) + if err != nil { + log.Errorf("unable to create input set: %v", err) + return nil + } + + sets = append(sets, set) + } + + return sets +} + +// filterInputs filters out inputs that have a budget below the min relay fee +// or have a required output that's below the dust. +func (b *BudgetAggregator) filterInputs(inputs pendingInputs) pendingInputs { + // Get the current min relay fee for this round. + minFeeRate := b.estimator.RelayFeePerKW() + + // filterInputs stores a map of inputs that has a budget that at least + // can pay the minimal fee. + filteredInputs := make(pendingInputs, len(inputs)) + + // Iterate all the inputs and filter out the ones whose budget cannot + // cover the min fee. + for _, pi := range inputs { + op := pi.OutPoint() + + // Get the size and skip if there's an error. + size, _, err := pi.WitnessType().SizeUpperBound() + if err != nil { + log.Warnf("Skipped input=%v: cannot get its size: %v", + op, err) + + continue + } + + // Skip inputs that has too little budget. + minFee := minFeeRate.FeeForWeight(int64(size)) + if pi.params.Budget < minFee { + log.Warnf("Skipped input=%v: has budget=%v, but the "+ + "min fee requires %v", op, pi.params.Budget, + minFee) + + continue + } + + // If the input comes with a required tx out that is below + // dust, we won't add it. + // + // NOTE: only HtlcSecondLevelAnchorInput returns non-nil + // RequiredTxOut. + reqOut := pi.RequiredTxOut() + if reqOut != nil { + if isDustOutput(reqOut) { + log.Errorf("Rejected input=%v due to dust "+ + "required output=%v", op, reqOut.Value) + + continue + } + } + + filteredInputs[*op] = pi + } + + return filteredInputs +} + +// sortInputs sorts the inputs based on their economical value. +// +// NOTE: besides the forced inputs, the sorting won't make any difference +// because all the inputs are added to the same set. The exception is when the +// number of inputs exceeds the maxInputs limit, it requires us to split them +// into smaller clusters. In that case, the sorting will make a difference as +// the budgets of the clusters will be different. +func (b *BudgetAggregator) sortInputs(inputs []pendingInput) []pendingInput { + // sortedInputs is the final list of inputs sorted by their economical + // value. + sortedInputs := make([]pendingInput, 0, len(inputs)) + + // Copy the inputs. + sortedInputs = append(sortedInputs, inputs...) + + // Sort the inputs based on their budgets. + // + // NOTE: We can implement more sophisticated algorithm as the budget + // left is a function f(minFeeRate, size) = b1 - s1 * r > b2 - s2 * r, + // where b1 and b2 are budgets, s1 and s2 are sizes of the inputs. + sort.Slice(sortedInputs, func(i, j int) bool { + left := sortedInputs[i].params.Budget + right := sortedInputs[j].params.Budget + + // Make sure forced inputs are always put in the front. + leftForce := sortedInputs[i].params.Force + rightForce := sortedInputs[j].params.Force + + // If both are forced inputs, we return the one with the higher + // budget. If neither are forced inputs, we also return the one + // with the higher budget. + if leftForce == rightForce { + return left > right + } + + // Otherwise, it's either the left or the right is forced. We + // can simply return `leftForce` here as, if it's true, the + // left is forced and should be put in the front. Otherwise, + // the right is forced and should be put in the front. + return leftForce + }) + + return sortedInputs +} + +// isDustOutput checks if the given output is considered as dust. +func isDustOutput(output *wire.TxOut) bool { + // Fetch the dust limit for this output. + dustLimit := lnwallet.DustLimitForSize(len(output.PkScript)) + + // If the output is below the dust limit, we consider it dust. + return btcutil.Amount(output.Value) < dustLimit +} diff --git a/sweep/aggregator_test.go b/sweep/aggregator_test.go index 2058464ad2..bee1db5293 100644 --- a/sweep/aggregator_test.go +++ b/sweep/aggregator_test.go @@ -1,14 +1,17 @@ package sweep import ( + "bytes" "errors" "reflect" "sort" "testing" + "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/stretchr/testify/require" @@ -421,3 +424,510 @@ func TestClusterByLockTime(t *testing.T) { }) } } + +// TestBudgetAggregatorFilterInputs checks that inputs with low budget are +// filtered out. +func TestBudgetAggregatorFilterInputs(t *testing.T) { + t.Parallel() + + // Create a mock fee estimator. + estimator := &chainfee.MockEstimator{} + defer estimator.AssertExpectations(t) + + // Create a mock WitnessType that always return an error when trying to + // get its size upper bound. + wtErr := &input.MockWitnessType{} + defer wtErr.AssertExpectations(t) + + // Mock the `SizeUpperBound` method to return an error exactly once. + dummyErr := errors.New("dummy error") + wtErr.On("SizeUpperBound").Return(0, false, dummyErr).Once() + + // Create a mock WitnessType that gives the size. + wt := &input.MockWitnessType{} + defer wt.AssertExpectations(t) + + // Mock the `SizeUpperBound` method to return the size four times. + const wtSize = 100 + wt.On("SizeUpperBound").Return(wtSize, true, nil).Times(4) + + // Create a mock input that will be filtered out due to error. + inpErr := &input.MockInput{} + defer inpErr.AssertExpectations(t) + + // Mock the `WitnessType` method to return the erroring witness type. + inpErr.On("WitnessType").Return(wtErr).Once() + + // Mock the `OutPoint` method to return a unique outpoint. + opErr := wire.OutPoint{Hash: chainhash.Hash{1}} + inpErr.On("OutPoint").Return(&opErr).Once() + + // Mock the estimator to return a constant fee rate. + const minFeeRate = chainfee.SatPerKWeight(1000) + estimator.On("RelayFeePerKW").Return(minFeeRate).Once() + + var ( + // Define three budget values, one below the min fee rate, one + // above and one equal to it. + budgetLow = minFeeRate.FeeForWeight(wtSize) - 1 + budgetEqual = minFeeRate.FeeForWeight(wtSize) + budgetHigh = minFeeRate.FeeForWeight(wtSize) + 1 + + // Define three outpoints with different budget values. + opLow = wire.OutPoint{Hash: chainhash.Hash{2}} + opEqual = wire.OutPoint{Hash: chainhash.Hash{3}} + opHigh = wire.OutPoint{Hash: chainhash.Hash{4}} + + // Define an outpoint that has a dust required output. + opDust = wire.OutPoint{Hash: chainhash.Hash{5}} + ) + + // Create three mock inputs. + inpLow := &input.MockInput{} + defer inpLow.AssertExpectations(t) + + inpEqual := &input.MockInput{} + defer inpEqual.AssertExpectations(t) + + inpHigh := &input.MockInput{} + defer inpHigh.AssertExpectations(t) + + inpDust := &input.MockInput{} + defer inpDust.AssertExpectations(t) + + // Mock the `WitnessType` method to return the witness type. + inpLow.On("WitnessType").Return(wt) + inpEqual.On("WitnessType").Return(wt) + inpHigh.On("WitnessType").Return(wt) + inpDust.On("WitnessType").Return(wt) + + // Mock the `OutPoint` method to return the unique outpoint. + inpLow.On("OutPoint").Return(&opLow) + inpEqual.On("OutPoint").Return(&opEqual) + inpHigh.On("OutPoint").Return(&opHigh) + inpDust.On("OutPoint").Return(&opDust) + + // Mock the `RequiredTxOut` to return nil. + inpEqual.On("RequiredTxOut").Return(nil) + inpHigh.On("RequiredTxOut").Return(nil) + + // Mock the dust required output. + inpDust.On("RequiredTxOut").Return(&wire.TxOut{ + Value: 0, + PkScript: bytes.Repeat([]byte{0}, input.P2WSHSize), + }) + + // Create testing pending inputs. + inputs := pendingInputs{ + // The first input will be filtered out due to the error. + opErr: &pendingInput{ + Input: inpErr, + }, + + // The second input will be filtered out due to the budget. + opLow: &pendingInput{ + Input: inpLow, + params: Params{Budget: budgetLow}, + }, + + // The third input will be included. + opEqual: &pendingInput{ + Input: inpEqual, + params: Params{Budget: budgetEqual}, + }, + + // The fourth input will be included. + opHigh: &pendingInput{ + Input: inpHigh, + params: Params{Budget: budgetHigh}, + }, + + // The fifth input will be filtered out due to the dust + // required. + opDust: &pendingInput{ + Input: inpDust, + params: Params{Budget: budgetHigh}, + }, + } + + // Init the budget aggregator with the mocked estimator and zero max + // num of inputs. + b := NewBudgetAggregator(estimator, 0) + + // Call the method under test. + result := b.filterInputs(inputs) + + // Validate the expected inputs are returned. + require.Len(t, result, 2) + + // We expect only the inputs with budget equal or above the min fee to + // be included. + require.Contains(t, result, opEqual) + require.Contains(t, result, opHigh) +} + +// TestBudgetAggregatorSortInputs checks that inputs are sorted by based on +// their budgets and force flag. +func TestBudgetAggregatorSortInputs(t *testing.T) { + t.Parallel() + + var ( + // Create two budgets. + budgetLow = btcutil.Amount(1000) + budgetHight = budgetLow + btcutil.Amount(1000) + ) + + // Create an input with the low budget but forced. + inputLowForce := pendingInput{ + params: Params{ + Budget: budgetLow, + Force: true, + }, + } + + // Create an input with the low budget. + inputLow := pendingInput{ + params: Params{ + Budget: budgetLow, + }, + } + + // Create an input with the high budget and forced. + inputHighForce := pendingInput{ + params: Params{ + Budget: budgetHight, + Force: true, + }, + } + + // Create an input with the high budget. + inputHigh := pendingInput{ + params: Params{ + Budget: budgetHight, + }, + } + + // Create a testing pending inputs. + inputs := []pendingInput{ + inputLowForce, + inputLow, + inputHighForce, + inputHigh, + } + + // Init the budget aggregator with zero max num of inputs. + b := NewBudgetAggregator(nil, 0) + + // Call the method under test. + result := b.sortInputs(inputs) + require.Len(t, result, 4) + + // The first input should be the forced input with the high budget. + require.Equal(t, inputHighForce, result[0]) + + // The second input should be the forced input with the low budget. + require.Equal(t, inputLowForce, result[1]) + + // The third input should be the input with the high budget. + require.Equal(t, inputHigh, result[2]) + + // The fourth input should be the input with the low budget. + require.Equal(t, inputLow, result[3]) +} + +// TestBudgetAggregatorCreateInputSets checks that the budget aggregator +// creates input sets when the number of inputs exceeds the max number +// configed. +func TestBudgetAggregatorCreateInputSets(t *testing.T) { + t.Parallel() + + // Create mocks input that doesn't have required outputs. + mockInput1 := &input.MockInput{} + defer mockInput1.AssertExpectations(t) + mockInput2 := &input.MockInput{} + defer mockInput2.AssertExpectations(t) + mockInput3 := &input.MockInput{} + defer mockInput3.AssertExpectations(t) + mockInput4 := &input.MockInput{} + defer mockInput4.AssertExpectations(t) + + // Create testing pending inputs. + pi1 := pendingInput{ + Input: mockInput1, + params: Params{ + DeadlineHeight: fn.Some(int32(1)), + }, + } + pi2 := pendingInput{ + Input: mockInput2, + params: Params{ + DeadlineHeight: fn.Some(int32(1)), + }, + } + pi3 := pendingInput{ + Input: mockInput3, + params: Params{ + DeadlineHeight: fn.Some(int32(1)), + }, + } + pi4 := pendingInput{ + Input: mockInput4, + params: Params{ + // This input has a deadline height that is different + // from the other inputs. When grouped with other + // inputs, it will cause an error to be returned. + DeadlineHeight: fn.Some(int32(2)), + }, + } + + // Create a budget aggregator with max number of inputs set to 2. + b := NewBudgetAggregator(nil, 2) + + // Create test cases. + testCases := []struct { + name string + inputs []pendingInput + setupMock func() + expectedNumSets int + }{ + { + // When the number of inputs is below the max, a single + // input set is returned. + name: "num inputs below max", + inputs: []pendingInput{pi1}, + setupMock: func() { + // Mock methods used in loggings. + mockInput1.On("WitnessType").Return( + input.CommitmentAnchor) + mockInput1.On("OutPoint").Return( + &wire.OutPoint{Hash: chainhash.Hash{1}}) + }, + expectedNumSets: 1, + }, + { + // When the number of inputs is equal to the max, a + // single input set is returned. + name: "num inputs equal to max", + inputs: []pendingInput{pi1, pi2}, + setupMock: func() { + // Mock methods used in loggings. + mockInput1.On("WitnessType").Return( + input.CommitmentAnchor) + mockInput2.On("WitnessType").Return( + input.CommitmentAnchor) + + mockInput1.On("OutPoint").Return( + &wire.OutPoint{Hash: chainhash.Hash{1}}) + mockInput2.On("OutPoint").Return( + &wire.OutPoint{Hash: chainhash.Hash{2}}) + }, + expectedNumSets: 1, + }, + { + // When the number of inputs is above the max, multiple + // input sets are returned. + name: "num inputs above max", + inputs: []pendingInput{pi1, pi2, pi3}, + setupMock: func() { + // Mock methods used in loggings. + mockInput1.On("WitnessType").Return( + input.CommitmentAnchor) + mockInput2.On("WitnessType").Return( + input.CommitmentAnchor) + mockInput3.On("WitnessType").Return( + input.CommitmentAnchor) + + mockInput1.On("OutPoint").Return( + &wire.OutPoint{Hash: chainhash.Hash{1}}) + mockInput2.On("OutPoint").Return( + &wire.OutPoint{Hash: chainhash.Hash{2}}) + mockInput3.On("OutPoint").Return( + &wire.OutPoint{Hash: chainhash.Hash{3}}) + }, + expectedNumSets: 2, + }, + { + // When the number of inputs is above the max, but an + // error is returned from creating the first set, it + // shouldn't affect the remaining inputs. + name: "num inputs above max with error", + inputs: []pendingInput{pi1, pi4, pi3}, + setupMock: func() { + // Mock methods used in loggings. + mockInput1.On("WitnessType").Return( + input.CommitmentAnchor) + mockInput3.On("WitnessType").Return( + input.CommitmentAnchor) + + mockInput1.On("OutPoint").Return( + &wire.OutPoint{Hash: chainhash.Hash{1}}) + mockInput3.On("OutPoint").Return( + &wire.OutPoint{Hash: chainhash.Hash{3}}) + mockInput4.On("OutPoint").Return( + &wire.OutPoint{Hash: chainhash.Hash{2}}) + }, + expectedNumSets: 1, + }, + } + + // Iterate over the test cases. + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + // Setup the mocks. + tc.setupMock() + + // Call the method under test. + result := b.createInputSets(tc.inputs) + + // Validate the expected number of input sets are + // returned. + require.Len(t, result, tc.expectedNumSets) + }) + } +} + +// TestBudgetInputSetClusterInputs checks that the budget aggregator clusters +// inputs into input sets based on their deadline heights. +func TestBudgetInputSetClusterInputs(t *testing.T) { + t.Parallel() + + // Create a mock fee estimator. + estimator := &chainfee.MockEstimator{} + defer estimator.AssertExpectations(t) + + // Create a mock WitnessType that gives the size. + wt := &input.MockWitnessType{} + defer wt.AssertExpectations(t) + + // Mock the `SizeUpperBound` method to return the size six times since + // we are using nine inputs. + const wtSize = 100 + wt.On("SizeUpperBound").Return(wtSize, true, nil).Times(9) + wt.On("String").Return("mock witness type") + + // Mock the estimator to return a constant fee rate. + const minFeeRate = chainfee.SatPerKWeight(1000) + estimator.On("RelayFeePerKW").Return(minFeeRate).Once() + + var ( + // Define two budget values, one below the min fee rate and one + // above it. + budgetLow = minFeeRate.FeeForWeight(wtSize) - 1 + budgetHigh = minFeeRate.FeeForWeight(wtSize) + 1 + + // Create three deadline heights, which means there are three + // groups of inputs to be expected. + deadlineNone = fn.None[int32]() + deadline1 = fn.Some(int32(1)) + deadline2 = fn.Some(int32(2)) + ) + + // Create testing pending inputs. + inputs := make(pendingInputs) + + // For each deadline height, create two inputs with different budgets, + // one below the min fee rate and one above it. We should see the lower + // one being filtered out. + for i, deadline := range []fn.Option[int32]{ + deadlineNone, deadline1, deadline2, + } { + // Define three outpoints. + opLow := wire.OutPoint{ + Hash: chainhash.Hash{byte(i)}, + Index: uint32(i), + } + opHigh1 := wire.OutPoint{ + Hash: chainhash.Hash{byte(i + 1000)}, + Index: uint32(i + 1000), + } + opHigh2 := wire.OutPoint{ + Hash: chainhash.Hash{byte(i + 2000)}, + Index: uint32(i + 2000), + } + + // Create mock inputs. + inpLow := &input.MockInput{} + defer inpLow.AssertExpectations(t) + + inpHigh1 := &input.MockInput{} + defer inpHigh1.AssertExpectations(t) + + inpHigh2 := &input.MockInput{} + defer inpHigh2.AssertExpectations(t) + + // Mock the `OutPoint` method to return the unique outpoint. + // + // We expect the low budget input to call this method once in + // `filterInputs`. + inpLow.On("OutPoint").Return(&opLow).Once() + + // We expect the high budget input to call this method three + // times, one in `filterInputs` and one in `createInputSet`, + // and one in `NewBudgetInputSet`. + inpHigh1.On("OutPoint").Return(&opHigh1).Times(3) + inpHigh2.On("OutPoint").Return(&opHigh2).Times(3) + + // Mock the `WitnessType` method to return the witness type. + inpLow.On("WitnessType").Return(wt) + inpHigh1.On("WitnessType").Return(wt) + inpHigh2.On("WitnessType").Return(wt) + + // Mock the `RequiredTxOut` to return nil. + inpHigh1.On("RequiredTxOut").Return(nil) + inpHigh2.On("RequiredTxOut").Return(nil) + + // Add the low input, which should be filtered out. + inputs[opLow] = &pendingInput{ + Input: inpLow, + params: Params{ + Budget: budgetLow, + DeadlineHeight: deadline, + }, + } + + // Add the high inputs, which should be included. + inputs[opHigh1] = &pendingInput{ + Input: inpHigh1, + params: Params{ + Budget: budgetHigh, + DeadlineHeight: deadline, + }, + } + inputs[opHigh2] = &pendingInput{ + Input: inpHigh2, + params: Params{ + Budget: budgetHigh, + DeadlineHeight: deadline, + }, + } + } + + // Create a budget aggregator with a max number of inputs set to 100. + b := NewBudgetAggregator(estimator, DefaultMaxInputsPerTx) + + // Call the method under test. + result := b.ClusterInputs(inputs) + + // We expect three input sets to be returned, one for each deadline. + require.Len(t, result, 3) + + // Check each input set has exactly two inputs. + deadlines := make(map[fn.Option[int32]]struct{}) + for _, set := range result { + // We expect two inputs in each set. + require.Len(t, set.Inputs(), 2) + + // We expect each set to have the expected budget. + require.Equal(t, budgetHigh*2, set.Budget()) + + // Save the deadlines. + deadlines[set.DeadlineHeight()] = struct{}{} + } + + // We expect to see all three deadlines. + require.Contains(t, deadlines, deadlineNone) + require.Contains(t, deadlines, deadline1) + require.Contains(t, deadlines, deadline2) +} diff --git a/sweep/fee_bumper.go b/sweep/fee_bumper.go new file mode 100644 index 0000000000..58a7f8b454 --- /dev/null +++ b/sweep/fee_bumper.go @@ -0,0 +1,953 @@ +package sweep + +import ( + "errors" + "fmt" + "sync" + "sync/atomic" + + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/rpcclient" + "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btcwallet/chain" + "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/labels" + "github.com/lightningnetwork/lnd/lnutils" + "github.com/lightningnetwork/lnd/lnwallet" + "github.com/lightningnetwork/lnd/lnwallet/chainfee" +) + +var ( + // ErrInvalidBumpResult is returned when the bump result is invalid. + ErrInvalidBumpResult = errors.New("invalid bump result") + + // ErrNotEnoughBudget is returned when the fee bumper decides the + // current budget cannot cover the fee. + ErrNotEnoughBudget = errors.New("not enough budget") +) + +// Bumper defines an interface that can be used by other subsystems for fee +// bumping. +type Bumper interface { + // Broadcast is used to publish the tx created from the given inputs + // specified in the request. It handles the tx creation, broadcasts it, + // and monitors its confirmation status for potential fee bumping. It + // returns a chan that the caller can use to receive updates about the + // broadcast result and potential RBF attempts. + Broadcast(req *BumpRequest) (<-chan *BumpResult, error) +} + +// BumpEvent represents the event of a fee bumping attempt. +type BumpEvent uint8 + +const ( + // TxPublished is sent when the broadcast attempt is finished. + TxPublished BumpEvent = iota + + // TxFailed is sent when the broadcast attempt fails. + TxFailed + + // TxReplaced is sent when the original tx is replaced by a new one. + TxReplaced + + // TxConfirmed is sent when the tx is confirmed. + TxConfirmed + + // sentinalEvent is used to check if an event is unknown. + sentinalEvent +) + +// String returns a human-readable string for the event. +func (e BumpEvent) String() string { + switch e { + case TxPublished: + return "Published" + case TxFailed: + return "Failed" + case TxReplaced: + return "Replaced" + case TxConfirmed: + return "Confirmed" + default: + return "Unknown" + } +} + +// Unknown returns true if the event is unknown. +func (e BumpEvent) Unknown() bool { + return e >= sentinalEvent +} + +// BumpRequest is used by the caller to give the Bumper the necessary info to +// create and manage potential fee bumps for a set of inputs. +type BumpRequest struct { + // Budget givens the total amount that can be used as fees by these + // inputs. + Budget btcutil.Amount + + // Inputs is the set of inputs to sweep. + Inputs []input.Input + + // DeadlineHeight is the block height at which the tx should be + // confirmed. + DeadlineHeight int32 + + // DeliveryAddress is the script to send the change output to. + DeliveryAddress []byte + + // MaxFeeRate is the maximum fee rate that can be used for fee bumping. + MaxFeeRate chainfee.SatPerKWeight +} + +// MaxFeeRateAllowed returns the maximum fee rate allowed for the given +// request. It calculates the feerate using the supplied budget and the weight, +// compares it with the specified MaxFeeRate, and returns the smaller of the +// two. +func (r *BumpRequest) MaxFeeRateAllowed() (chainfee.SatPerKWeight, error) { + // Get the size of the sweep tx, which will be used to calculate the + // budget fee rate. + size, err := calcSweepTxWeight(r.Inputs, r.DeliveryAddress) + if err != nil { + return 0, err + } + + // Use the budget and MaxFeeRate to decide the max allowed fee rate. + // This is needed as, when the input has a large value and the user + // sets the budget to be proportional to the input value, the fee rate + // can be very high and we need to make sure it doesn't exceed the max + // fee rate. + maxFeeRateAllowed := chainfee.NewSatPerKWeight(r.Budget, size) + if maxFeeRateAllowed > r.MaxFeeRate { + log.Debugf("Budget feerate %v exceeds MaxFeeRate %v, use "+ + "MaxFeeRate instead", maxFeeRateAllowed, r.MaxFeeRate) + + return r.MaxFeeRate, nil + } + + log.Debugf("Budget feerate %v below MaxFeeRate %v, use budget feerate "+ + "instead", maxFeeRateAllowed, r.MaxFeeRate) + + return maxFeeRateAllowed, nil +} + +// calcSweepTxWeight calculates the weight of the sweep tx. It assumes a +// sweeping tx always has a single output(change). +func calcSweepTxWeight(inputs []input.Input, + outputPkScript []byte) (uint64, error) { + + // Use a const fee rate as we only use the weight estimator to + // calculate the size. + const feeRate = 1 + + // Initialize the tx weight estimator with, + // - nil outputs as we only have one single change output. + // - const fee rate as we don't care about the fees here. + // - 0 maxfeerate as we don't care about fees here. + // + // TODO(yy): we should refactor the weight estimator to not require a + // fee rate and max fee rate and make it a pure tx weight calculator. + _, estimator, err := getWeightEstimate( + inputs, nil, feeRate, 0, outputPkScript, + ) + if err != nil { + return 0, err + } + + return uint64(estimator.weight()), nil +} + +// BumpResult is used by the Bumper to send updates about the tx being +// broadcast. +type BumpResult struct { + // Event is the type of event that the result is for. + Event BumpEvent + + // Tx is the tx being broadcast. + Tx *wire.MsgTx + + // ReplacedTx is the old, replaced tx if a fee bump is attempted. + ReplacedTx *wire.MsgTx + + // FeeRate is the fee rate used for the new tx. + FeeRate chainfee.SatPerKWeight + + // Fee is the fee paid by the new tx. + Fee btcutil.Amount + + // Err is the error that occurred during the broadcast. + Err error + + // requestID is the ID of the request that created this record. + requestID uint64 +} + +// Validate validates the BumpResult so it's safe to use. +func (b *BumpResult) Validate() error { + // Every result must have a tx. + if b.Tx == nil { + return fmt.Errorf("%w: nil tx", ErrInvalidBumpResult) + } + + // Every result must have a known event. + if b.Event.Unknown() { + return fmt.Errorf("%w: unknown event", ErrInvalidBumpResult) + } + + // If it's a replacing event, it must have a replaced tx. + if b.Event == TxReplaced && b.ReplacedTx == nil { + return fmt.Errorf("%w: nil replacing tx", ErrInvalidBumpResult) + } + + // If it's a failed event, it must have an error. + if b.Event == TxFailed && b.Err == nil { + return fmt.Errorf("%w: nil error", ErrInvalidBumpResult) + } + + // If it's a confirmed event, it must have a fee rate and fee. + if b.Event == TxConfirmed && (b.FeeRate == 0 || b.Fee == 0) { + return fmt.Errorf("%w: missing fee rate or fee", + ErrInvalidBumpResult) + } + + return nil +} + +// TxPublisherConfig is the config used to create a new TxPublisher. +type TxPublisherConfig struct { + // Signer is used to create the tx signature. + Signer input.Signer + + // Wallet is used primarily to publish the tx. + Wallet Wallet + + // Estimator is used to estimate the fee rate for the new tx based on + // its deadline conf target. + Estimator chainfee.Estimator + + // Notifier is used to monitor the confirmation status of the tx. + Notifier chainntnfs.ChainNotifier +} + +// TxPublisher is an implementation of the Bumper interface. It utilizes the +// `testmempoolaccept` RPC to bump the fee of txns it created based on +// different fee function selected or configed by the caller. Its purpose is to +// take a list of inputs specified, and create a tx that spends them to a +// specified output. It will then monitor the confirmation status of the tx, +// and if it's not confirmed within a certain time frame, it will attempt to +// bump the fee of the tx by creating a new tx that spends the same inputs to +// the same output, but with a higher fee rate. It will continue to do this +// until the tx is confirmed or the fee rate reaches the maximum fee rate +// specified by the caller. +type TxPublisher struct { + wg sync.WaitGroup + + // cfg specifies the configuration of the TxPublisher. + cfg *TxPublisherConfig + + // currentHeight is the current block height. + currentHeight int32 + + // records is a map keyed by the requestCounter and the value is the tx + // being monitored. + records lnutils.SyncMap[uint64, *monitorRecord] + + // requestCounter is a monotonically increasing counter used to keep + // track of how many requests have been made. + requestCounter atomic.Uint64 + + // subscriberChans is a map keyed by the requestCounter, each item is + // the chan that the publisher sends the fee bump result to. + subscriberChans lnutils.SyncMap[uint64, chan *BumpResult] + + // quit is used to signal the publisher to stop. + quit chan struct{} +} + +// Compile-time constraint to ensure TxPublisher implements Bumper. +var _ Bumper = (*TxPublisher)(nil) + +// NewTxPublisher creates a new TxPublisher. +func NewTxPublisher(cfg TxPublisherConfig) *TxPublisher { + return &TxPublisher{ + cfg: &cfg, + records: lnutils.SyncMap[uint64, *monitorRecord]{}, + subscriberChans: lnutils.SyncMap[uint64, chan *BumpResult]{}, + quit: make(chan struct{}), + } +} + +// Broadcast is used to publish the tx created from the given inputs. It will, +// 1. init a fee function based on the given strategy. +// 2. create an RBF-compliant tx and monitor it for confirmation. +// 3. notify the initial broadcast result back to the caller. +// The initial broadcast is guaranteed to be RBF-compliant unless the budget +// specified cannot cover the fee. +// +// NOTE: part of the Bumper interface. +func (t *TxPublisher) Broadcast(req *BumpRequest) (<-chan *BumpResult, error) { + log.Tracef("Received broadcast request: %s", newLogClosure( + func() string { + return spew.Sdump(req) + })()) + + // Attempt an initial broadcast which is guaranteed to comply with the + // RBF rules. + result, err := t.initialBroadcast(req) + if err != nil { + log.Errorf("Initial broadcast failed: %v", err) + + return nil, err + } + + // Create a chan to send the result to the caller. + subscriber := make(chan *BumpResult, 1) + t.subscriberChans.Store(result.requestID, subscriber) + + // Send the initial broadcast result to the caller. + t.handleResult(result) + + return subscriber, nil +} + +// initialBroadcast initializes a fee function, creates an RBF-compliant tx and +// broadcasts it. +func (t *TxPublisher) initialBroadcast(req *BumpRequest) (*BumpResult, error) { + // Create a fee bumping algorithm to be used for future RBF. + feeAlgo, err := t.initializeFeeFunction(req) + if err != nil { + return nil, fmt.Errorf("init fee function: %w", err) + } + + // Create the initial tx to be broadcasted. This tx is guaranteed to + // comply with the RBF restrictions. + requestID, err := t.createRBFCompliantTx(req, feeAlgo) + if err != nil { + return nil, fmt.Errorf("create RBF-compliant tx: %w", err) + } + + // Broadcast the tx and return the monitored record. + result, err := t.broadcast(requestID) + if err != nil { + return nil, fmt.Errorf("broadcast sweep tx: %w", err) + } + + return result, nil +} + +// initializeFeeFunction initializes a fee function to be used for this request +// for future fee bumping. +func (t *TxPublisher) initializeFeeFunction( + req *BumpRequest) (FeeFunction, error) { + + // Get the max allowed feerate. + maxFeeRateAllowed, err := req.MaxFeeRateAllowed() + if err != nil { + return nil, err + } + + // Get the initial conf target. + confTarget := calcCurrentConfTarget(t.currentHeight, req.DeadlineHeight) + + // Initialize the fee function and return it. + // + // TODO(yy): return based on differet req.Strategy? + return NewLinearFeeFunction( + maxFeeRateAllowed, confTarget, t.cfg.Estimator, + ) +} + +// createRBFCompliantTx creates a tx that is compliant with RBF rules. It does +// so by creating a tx, validate it using `TestMempoolAccept`, and bump its fee +// and redo the process until the tx is valid, or return an error when non-RBF +// related errors occur or the budget has been used up. +func (t *TxPublisher) createRBFCompliantTx(req *BumpRequest, + f FeeFunction) (uint64, error) { + + for { + // Create a new tx with the given fee rate and check its + // mempool acceptance. + tx, fee, err := t.createAndCheckTx(req, f) + + switch { + case err == nil: + // The tx is valid, return the request ID. + requestID := t.storeRecord(tx, req, f, fee) + + log.Infof("Created tx %v for %v inputs: feerate=%v, "+ + "fee=%v, inputs=%v", tx.TxHash(), + len(req.Inputs), f.FeeRate(), fee, + inputTypeSummary(req.Inputs)) + + return requestID, nil + + // If the error indicates the fees paid is not enough, we will + // ask the fee function to increase the fee rate and retry. + case errors.Is(err, lnwallet.ErrMempoolFee): + // We should at least start with a feerate above the + // mempool min feerate, so if we get this error, it + // means something is wrong earlier in the pipeline. + log.Errorf("Current fee=%v, feerate=%v, %v", fee, + f.FeeRate(), err) + + fallthrough + + // We are not paying enough fees so we increase it. + case errors.Is(err, rpcclient.ErrInsufficientFee): + increased := false + + // Keep calling the fee function until the fee rate is + // increased or maxed out. + for !increased { + log.Debugf("Increasing fee for next round, "+ + "current fee=%v, feerate=%v", fee, + f.FeeRate()) + + // If the fee function tells us that we have + // used up the budget, we will return an error + // indicating this tx cannot be made. The + // sweeper should handle this error and try to + // cluster these inputs differetly. + increased, err = f.Increment() + if err != nil { + return 0, err + } + } + + // TODO(yy): suppose there's only one bad input, we can do a + // binary search to find out which input is causing this error + // by recreating a tx using half of the inputs and check its + // mempool acceptance. + default: + log.Debugf("Failed to create RBF-compliant tx: %v", err) + return 0, err + } + } +} + +// storeRecord stores the given record in the records map. +func (t *TxPublisher) storeRecord(tx *wire.MsgTx, req *BumpRequest, + f FeeFunction, fee btcutil.Amount) uint64 { + + // Increase the request counter. + // + // NOTE: this is the only place where we increase the + // counter. + requestID := t.requestCounter.Add(1) + + // Register the record. + t.records.Store(requestID, &monitorRecord{ + tx: tx, + req: req, + feeFunction: f, + fee: fee, + }) + + return requestID +} + +// createAndCheckTx creates a tx based on the given inputs, change output +// script, and the fee rate. In addition, it validates the tx's mempool +// acceptance before returning a tx that can be published directly, along with +// its fee. +func (t *TxPublisher) createAndCheckTx(req *BumpRequest, f FeeFunction) ( + *wire.MsgTx, btcutil.Amount, error) { + + // Create the sweep tx with max fee rate of 0 as the fee function + // guarantees the fee rate used here won't exceed the max fee rate. + // + // TODO(yy): refactor this function to not require a max fee rate. + tx, fee, err := createSweepTx( + req.Inputs, nil, req.DeliveryAddress, uint32(t.currentHeight), + f.FeeRate(), 0, t.cfg.Signer, + ) + if err != nil { + return nil, 0, fmt.Errorf("create sweep tx: %w", err) + } + + // Sanity check the budget still covers the fee. + if fee > req.Budget { + return nil, 0, fmt.Errorf("%w: budget=%v, fee=%v", + ErrNotEnoughBudget, req.Budget, fee) + } + + // Validate the tx's mempool acceptance. + err = t.cfg.Wallet.CheckMempoolAcceptance(tx) + + // Exit early if the tx is valid. + if err == nil { + return tx, fee, nil + } + + // Print an error log if the chain backend doesn't support the mempool + // acceptance test RPC. + if errors.Is(err, rpcclient.ErrBackendVersion) { + log.Errorf("TestMempoolAccept not supported by backend, " + + "consider upgrading it to a newer version") + return tx, fee, nil + } + + // We are running on a backend that doesn't implement the RPC + // testmempoolaccept, eg, neutrino, so we'll skip the check. + if errors.Is(err, chain.ErrUnimplemented) { + log.Debug("Skipped testmempoolaccept due to not implemented") + return tx, fee, nil + } + + return nil, 0, err +} + +// broadcast takes a monitored tx and publishes it to the network. Prior to the +// broadcast, it will subscribe the tx's confirmation notification and attach +// the event channel to the record. Any broadcast-related errors will not be +// returned here, instead, they will be put inside the `BumpResult` and +// returned to the caller. +func (t *TxPublisher) broadcast(requestID uint64) (*BumpResult, error) { + // Get the record being monitored. + record, ok := t.records.Load(requestID) + if !ok { + return nil, fmt.Errorf("tx record %v not found", requestID) + } + + txid := record.tx.TxHash() + + tx := record.tx + log.Debugf("Publishing sweep tx %v, num_inputs=%v, height=%v", + txid, len(tx.TxIn), t.currentHeight) + + // Set the event, and change it to TxFailed if the wallet fails to + // publish it. + event := TxPublished + + // Publish the sweeping tx with customized label. If the publish fails, + // this error will be saved in the `BumpResult` and it will be removed + // from being monitored. + err := t.cfg.Wallet.PublishTransaction( + tx, labels.MakeLabel(labels.LabelTypeSweepTransaction, nil), + ) + if err != nil { + // NOTE: we decide to attach this error to the result instead + // of returning it here because by the time the tx reaches + // here, it should have passed the mempool acceptance check. If + // it still fails to be broadcast, it's likely a non-RBF + // related error happened. So we send this error back to the + // caller so that it can handle it properly. + // + // TODO(yy): find out which input is causing the failure. + log.Errorf("Failed to publish tx %v: %v", txid, err) + event = TxFailed + } + + result := &BumpResult{ + Event: event, + Tx: record.tx, + Fee: record.fee, + FeeRate: record.feeFunction.FeeRate(), + Err: err, + requestID: requestID, + } + + return result, nil +} + +// notifyResult sends the result to the resultChan specified by the requestID. +// This channel is expected to be read by the caller. +func (t *TxPublisher) notifyResult(result *BumpResult) { + id := result.requestID + subscriber, ok := t.subscriberChans.Load(id) + if !ok { + log.Errorf("Result chan for id=%v not found", id) + return + } + + log.Debugf("Sending result for requestID=%v, tx=%v", id, + result.Tx.TxHash()) + + select { + // Send the result to the subscriber. + // + // TODO(yy): Add timeout in case it's blocking? + case subscriber <- result: + case <-t.quit: + log.Debug("Fee bumper stopped") + } +} + +// removeResult removes the tracking of the result if the result contains a +// non-nil error, or the tx is confirmed, the record will be removed from the +// maps. +func (t *TxPublisher) removeResult(result *BumpResult) { + id := result.requestID + + // Remove the record from the maps if there's an error. This means this + // tx has failed its broadcast and cannot be retried. There are two + // cases, + // - when the budget cannot cover the fee. + // - when a non-RBF related error occurs. + switch result.Event { + case TxFailed: + log.Errorf("Removing monitor record=%v, tx=%v, due to err: %v", + id, result.Tx.TxHash(), result.Err) + + case TxConfirmed: + // Remove the record is the tx is confirmed. + log.Debugf("Removing confirmed monitor record=%v, tx=%v", id, + result.Tx.TxHash()) + + // Do nothing if it's neither failed or confirmed. + default: + log.Tracef("Skipping record removal for id=%v, event=%v", id, + result.Event) + + return + } + + t.records.Delete(id) + t.subscriberChans.Delete(id) +} + +// handleResult handles the result of a tx broadcast. It will notify the +// subscriber and remove the record if the tx is confirmed or failed to be +// broadcast. +func (t *TxPublisher) handleResult(result *BumpResult) { + // Notify the subscriber. + t.notifyResult(result) + + // Remove the record if it's failed or confirmed. + t.removeResult(result) +} + +// monitorRecord is used to keep track of the tx being monitored by the +// publisher internally. +type monitorRecord struct { + // tx is the tx being monitored. + tx *wire.MsgTx + + // req is the original request. + req *BumpRequest + + // feeFunction is the fee bumping algorithm used by the publisher. + feeFunction FeeFunction + + // fee is the fee paid by the tx. + fee btcutil.Amount +} + +// Start starts the publisher by subscribing to block epoch updates and kicking +// off the monitor loop. +func (t *TxPublisher) Start() error { + log.Info("TxPublisher starting...") + defer log.Debugf("TxPublisher started") + + blockEvent, err := t.cfg.Notifier.RegisterBlockEpochNtfn(nil) + if err != nil { + return fmt.Errorf("register block epoch ntfn: %w", err) + } + + t.wg.Add(1) + go t.monitor(blockEvent) + + return nil +} + +// Stop stops the publisher and waits for the monitor loop to exit. +func (t *TxPublisher) Stop() { + log.Info("TxPublisher stopping...") + defer log.Debugf("TxPublisher stopped") + + close(t.quit) + + t.wg.Wait() +} + +// monitor is the main loop driven by new blocks. Whevenr a new block arrives, +// it will examine all the txns being monitored, and check if any of them needs +// to be bumped. If so, it will attempt to bump the fee of the tx. +// +// NOTE: Must be run as a goroutine. +func (t *TxPublisher) monitor(blockEvent *chainntnfs.BlockEpochEvent) { + defer blockEvent.Cancel() + defer t.wg.Done() + + for { + select { + case epoch, ok := <-blockEvent.Epochs: + if !ok { + // We should stop the publisher before stopping + // the chain service. Otherwise it indicates an + // error. + log.Error("Block epoch channel closed, exit " + + "monitor") + + return + } + + log.Debugf("TxPublisher received new block: %v", + epoch.Height) + + // Update the best known height for the publisher. + t.currentHeight = epoch.Height + + // Check all monitored txns to see if any of them needs + // to be bumped. + t.processRecords() + + case <-t.quit: + log.Debug("Fee bumper stopped, exit monitor") + return + } + } +} + +// processRecords checks all the txns being monitored, and checks if any of +// them needs to be bumped. If so, it will attempt to bump the fee of the tx. +func (t *TxPublisher) processRecords() { + // confirmedRecords stores a map of the records which have been + // confirmed. + confirmedRecords := make(map[uint64]*monitorRecord) + + // feeBumpRecords stores a map of the records which need to be bumped. + feeBumpRecords := make(map[uint64]*monitorRecord) + + // visitor is a helper closure that visits each record and divides them + // into two groups. + visitor := func(requestID uint64, r *monitorRecord) error { + log.Tracef("Checking monitor recordID=%v for tx=%v", requestID, + r.tx.TxHash()) + + // If the tx is already confirmed, we can stop monitoring it. + if t.isConfirmed(r.tx.TxHash()) { + confirmedRecords[requestID] = r + + // Move to the next record. + return nil + } + + feeBumpRecords[requestID] = r + + // Return nil to move to the next record. + return nil + } + + // Iterate through all the records and divide them into two groups. + t.records.ForEach(visitor) + + // For records that are confirmed, we'll notify the caller about this + // result. + for requestID, r := range confirmedRecords { + rec := r + + log.Debugf("Tx=%v is confirmed", r.tx.TxHash()) + t.wg.Add(1) + go t.handleTxConfirmed(rec, requestID) + } + + // Get the current height to be used in the following goroutines. + currentHeight := t.currentHeight + + // For records that are not confirmed, we perform a fee bump if needed. + for requestID, r := range feeBumpRecords { + rec := r + + log.Debugf("Attempting to fee bump Tx=%v", r.tx.TxHash()) + t.wg.Add(1) + go t.handleFeeBumpTx(requestID, rec, currentHeight) + } +} + +// handleTxConfirmed is called when a monitored tx is confirmed. It will +// notify the subscriber then remove the record from the maps . +// +// NOTE: Must be run as a goroutine to avoid blocking on sending the result. +func (t *TxPublisher) handleTxConfirmed(r *monitorRecord, requestID uint64) { + defer t.wg.Done() + + // Create a result that will be sent to the resultChan which is + // listened by the caller. + result := &BumpResult{ + Event: TxConfirmed, + Tx: r.tx, + requestID: requestID, + Fee: r.fee, + FeeRate: r.feeFunction.FeeRate(), + } + + // Notify that this tx is confirmed and remove the record from the map. + t.handleResult(result) +} + +// handleFeeBumpTx checks if the tx needs to be bumped, and if so, it will +// attempt to bump the fee of the tx. +// +// NOTE: Must be run as a goroutine to avoid blocking on sending the result. +func (t *TxPublisher) handleFeeBumpTx(requestID uint64, r *monitorRecord, + currentHeight int32) { + + defer t.wg.Done() + + oldTxid := r.tx.TxHash() + + // Get the current conf target for this record. + confTarget := calcCurrentConfTarget(currentHeight, r.req.DeadlineHeight) + + // Ask the fee function whether a bump is needed. We expect the fee + // function to increase its returned fee rate after calling this + // method. + increased, err := r.feeFunction.IncreaseFeeRate(confTarget) + if err != nil { + // TODO(yy): send this error back to the sweeper so it can + // re-group the inputs? + log.Errorf("Failed to increase fee rate for tx %v at "+ + "height=%v: %v", oldTxid, t.currentHeight, err) + + return + } + + // If the fee rate was not increased, there's no need to bump the fee. + if !increased { + log.Tracef("Skip bumping tx %v at height=%v", oldTxid, + t.currentHeight) + + return + } + + // The fee function now has a new fee rate, we will use it to bump the + // fee of the tx. + resultOpt := t.createAndPublishTx(requestID, r) + + // If there's a result, we will notify the caller about the result. + resultOpt.WhenSome(func(result BumpResult) { + // Notify the new result. + t.handleResult(&result) + }) +} + +// createAndPublishTx creates a new tx with a higher fee rate and publishes it +// to the network. It will update the record with the new tx and fee rate if +// successfully created, and return the result when published successfully. +func (t *TxPublisher) createAndPublishTx(requestID uint64, + r *monitorRecord) fn.Option[BumpResult] { + + // Fetch the old tx. + oldTx := r.tx + + // Create a new tx with the new fee rate. + // + // NOTE: The fee function is expected to have increased its returned + // fee rate after calling the SkipFeeBump method. So we can use it + // directly here. + tx, fee, err := t.createAndCheckTx(r.req, r.feeFunction) + + // If the error is fee related, we will return an error and let the fee + // bumper retry it at next block. + // + // NOTE: we can check the RBF error here and ask the fee function to + // recalculate the fee rate. However, this would defeat the purpose of + // using a deadline based fee function: + // - if the deadline is far away, there's no rush to RBF the tx. + // - if the deadline is close, we expect the fee function to give us a + // higher fee rate. If the fee rate cannot satisfy the RBF rules, it + // means the budget is not enough. + if errors.Is(err, rpcclient.ErrInsufficientFee) || + errors.Is(err, lnwallet.ErrMempoolFee) { + + log.Debugf("Failed to bump tx %v: %v", oldTx.TxHash(), err) + return fn.None[BumpResult]() + } + + // If the error is not fee related, we will return a `TxFailed` event + // so this input can be retried. + if err != nil { + // If the tx doesn't not have enought budget, we will return a + // result so the sweeper can handle it by re-clustering the + // utxos. + if errors.Is(err, ErrNotEnoughBudget) { + log.Warnf("Fail to fee bump tx %v: %v", oldTx.TxHash(), + err) + } else { + // Otherwise, an unexpected error occurred, we will + // fail the tx and let the sweeper retry the whole + // process. + log.Errorf("Failed to bump tx %v: %v", oldTx.TxHash(), + err) + } + + return fn.Some(BumpResult{ + Event: TxFailed, + Tx: oldTx, + Err: err, + requestID: requestID, + }) + } + + // The tx has been created without any errors, we now register a new + // record by overwriting the same requestID. + t.records.Store(requestID, &monitorRecord{ + tx: tx, + req: r.req, + feeFunction: r.feeFunction, + fee: fee, + }) + + // Attempt to broadcast this new tx. + result, err := t.broadcast(requestID) + if err != nil { + log.Infof("Failed to broadcast replacement tx %v: %v", + tx.TxHash(), err) + + return fn.None[BumpResult]() + } + + // A successful replacement tx is created, attach the old tx. + result.ReplacedTx = oldTx + + // If the new tx failed to be published, we will return the result so + // the caller can handle it. + if result.Event == TxFailed { + return fn.Some(*result) + } + + log.Infof("Replaced tx=%v with new tx=%v", oldTx.TxHash(), tx.TxHash()) + + // Otherwise, it's a successful RBF, set the event and return. + result.Event = TxReplaced + + return fn.Some(*result) +} + +// isConfirmed checks the btcwallet to see whether the tx is confirmed. +func (t *TxPublisher) isConfirmed(txid chainhash.Hash) bool { + details, err := t.cfg.Wallet.GetTransactionDetails(&txid) + if err != nil { + log.Warnf("Failed to get tx details for %v: %v", txid, err) + return false + } + + return details.NumConfirmations > 0 +} + +// calcCurrentConfTarget calculates the current confirmation target based on +// the deadline height. The conf target is capped at 0 if the deadline has +// already been past. +func calcCurrentConfTarget(currentHeight, deadline int32) uint32 { + var confTarget uint32 + + // Calculate how many blocks left until the deadline. + deadlineDelta := deadline - currentHeight + + // If we are already past the deadline, we will set the conf target to + // be 1. + if deadlineDelta <= 0 { + log.Warnf("Deadline is %d blocks behind current height %v", + -deadlineDelta, currentHeight) + + confTarget = 1 + } else { + confTarget = uint32(deadlineDelta) + } + + return confTarget +} diff --git a/sweep/fee_bumper_test.go b/sweep/fee_bumper_test.go new file mode 100644 index 0000000000..5f031a9bff --- /dev/null +++ b/sweep/fee_bumper_test.go @@ -0,0 +1,1417 @@ +package sweep + +import ( + "fmt" + "testing" + "time" + + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/rpcclient" + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/lnwallet" + "github.com/lightningnetwork/lnd/lnwallet/chainfee" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +var ( + // Create a taproot change script. + changePkScript = []byte{ + 0x51, 0x20, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + } +) + +// TestBumpResultValidate tests the validate method of the BumpResult struct. +func TestBumpResultValidate(t *testing.T) { + t.Parallel() + + // An empty result will give an error. + b := BumpResult{} + require.ErrorIs(t, b.Validate(), ErrInvalidBumpResult) + + // Unknown event type will give an error. + b = BumpResult{ + Tx: &wire.MsgTx{}, + Event: sentinalEvent, + } + require.ErrorIs(t, b.Validate(), ErrInvalidBumpResult) + + // A replacing event without a new tx will give an error. + b = BumpResult{ + Tx: &wire.MsgTx{}, + Event: TxReplaced, + } + require.ErrorIs(t, b.Validate(), ErrInvalidBumpResult) + + // A failed event without a failure reason will give an error. + b = BumpResult{ + Tx: &wire.MsgTx{}, + Event: TxFailed, + } + require.ErrorIs(t, b.Validate(), ErrInvalidBumpResult) + + // A confirmed event without fee info will give an error. + b = BumpResult{ + Tx: &wire.MsgTx{}, + Event: TxConfirmed, + } + require.ErrorIs(t, b.Validate(), ErrInvalidBumpResult) + + // Test a valid result. + b = BumpResult{ + Tx: &wire.MsgTx{}, + Event: TxPublished, + } + require.NoError(t, b.Validate()) +} + +// TestCalcSweepTxWeight checks that the weight of the sweep tx is calculated +// correctly. +func TestCalcSweepTxWeight(t *testing.T) { + t.Parallel() + + // Create an input. + inp := createTestInput(100, input.WitnessKeyHash) + + // Use a wrong change script to test the error case. + weight, err := calcSweepTxWeight([]input.Input{&inp}, []byte{0}) + require.Error(t, err) + require.Zero(t, weight) + + // Use a correct change script to test the success case. + weight, err = calcSweepTxWeight([]input.Input{&inp}, changePkScript) + require.NoError(t, err) + + // BaseTxSize 8 bytes + // InputSize 1+41 bytes + // One P2TROutputSize 1+43 bytes + // One P2WKHWitnessSize 2+109 bytes + // Total weight = (8+42+44) * 4 + 111 = 487 + require.EqualValuesf(t, 487, weight, "unexpected weight %v", weight) +} + +// TestBumpRequestMaxFeeRateAllowed tests the max fee rate allowed for a bump +// request. +func TestBumpRequestMaxFeeRateAllowed(t *testing.T) { + t.Parallel() + + // Create a test input. + inp := createTestInput(100, input.WitnessKeyHash) + + // The weight is 487. + weight, err := calcSweepTxWeight([]input.Input{&inp}, changePkScript) + require.NoError(t, err) + + // Define a test budget and calculates its fee rate. + budget := btcutil.Amount(1000) + budgetFeeRate := chainfee.NewSatPerKWeight(budget, weight) + + testCases := []struct { + name string + req *BumpRequest + expectedMaxFeeRate chainfee.SatPerKWeight + expectedErr bool + }{ + { + // Use a wrong change script to test the error case. + name: "error calc weight", + req: &BumpRequest{ + DeliveryAddress: []byte{1}, + }, + expectedMaxFeeRate: 0, + expectedErr: true, + }, + { + // When the budget cannot give a fee rate that matches + // the supplied MaxFeeRate, the max allowed feerate is + // capped by the budget. + name: "use budget as max fee rate", + req: &BumpRequest{ + DeliveryAddress: changePkScript, + Inputs: []input.Input{&inp}, + Budget: budget, + MaxFeeRate: budgetFeeRate + 1, + }, + expectedMaxFeeRate: budgetFeeRate, + }, + { + // When the budget can give a fee rate that matches the + // supplied MaxFeeRate, the max allowed feerate is + // capped by the MaxFeeRate. + name: "use config as max fee rate", + req: &BumpRequest{ + DeliveryAddress: changePkScript, + Inputs: []input.Input{&inp}, + Budget: budget, + MaxFeeRate: budgetFeeRate - 1, + }, + expectedMaxFeeRate: budgetFeeRate - 1, + }, + } + + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + // Check the method under test. + maxFeeRate, err := tc.req.MaxFeeRateAllowed() + + // If we expect an error, check the error is returned + // and the feerate is empty. + if tc.expectedErr { + require.Error(t, err) + require.Zero(t, maxFeeRate) + + return + } + + // Otherwise, check the max fee rate is as expected. + require.NoError(t, err) + require.Equal(t, tc.expectedMaxFeeRate, maxFeeRate) + }) + } +} + +// TestCalcCurrentConfTarget checks that the current confirmation target is +// calculated correctly. +func TestCalcCurrentConfTarget(t *testing.T) { + t.Parallel() + + // When the current block height is 100 and deadline height is 200, the + // conf target should be 100. + conf := calcCurrentConfTarget(int32(100), int32(200)) + require.EqualValues(t, 100, conf) + + // When the current block height is 200 and deadline height is 100, the + // conf target should be 1 since the deadline has passed. + conf = calcCurrentConfTarget(int32(200), int32(100)) + require.EqualValues(t, 1, conf) +} + +// TestInitializeFeeFunction tests the initialization of the fee function. +func TestInitializeFeeFunction(t *testing.T) { + t.Parallel() + + // Create a test input. + inp := createTestInput(100, input.WitnessKeyHash) + + // Create a mock fee estimator. + estimator := &chainfee.MockEstimator{} + defer estimator.AssertExpectations(t) + + // Create a publisher using the mocks. + tp := NewTxPublisher(TxPublisherConfig{ + Estimator: estimator, + }) + + // Create a test feerate. + feerate := chainfee.SatPerKWeight(1000) + + // Create a testing bump request. + req := &BumpRequest{ + DeliveryAddress: changePkScript, + Inputs: []input.Input{&inp}, + Budget: btcutil.Amount(1000), + MaxFeeRate: feerate, + } + + // Mock the fee estimator to return an error. + // + // We are not testing `NewLinearFeeFunction` here, so the actual params + // used are irrelevant. + dummyErr := fmt.Errorf("dummy error") + estimator.On("EstimateFeePerKW", mock.Anything).Return( + chainfee.SatPerKWeight(0), dummyErr).Once() + + // Call the method under test and assert the error is returned. + f, err := tp.initializeFeeFunction(req) + require.ErrorIs(t, err, dummyErr) + require.Nil(t, f) + + // Mock the fee estimator to return the testing fee rate. + // + // We are not testing `NewLinearFeeFunction` here, so the actual params + // used are irrelevant. + estimator.On("EstimateFeePerKW", mock.Anything).Return( + feerate, nil).Once() + estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Once() + + // Call the method under test. + f, err = tp.initializeFeeFunction(req) + require.NoError(t, err) + require.Equal(t, feerate, f.FeeRate()) +} + +// TestStoreRecord correctly increases the request counter and saves the +// record. +func TestStoreRecord(t *testing.T) { + t.Parallel() + + // Create a test input. + inp := createTestInput(1000, input.WitnessKeyHash) + + // Create a bump request. + req := &BumpRequest{ + DeliveryAddress: changePkScript, + Inputs: []input.Input{&inp}, + Budget: btcutil.Amount(1000), + } + + // Create a naive fee function. + feeFunc := &LinearFeeFunction{} + + // Create a test fee and tx. + fee := btcutil.Amount(1000) + tx := &wire.MsgTx{} + + // Create a publisher using the mocks. + tp := NewTxPublisher(TxPublisherConfig{}) + + // Get the current counter and check it's increased later. + initialCounter := tp.requestCounter.Load() + + // Call the method under test. + requestID := tp.storeRecord(tx, req, feeFunc, fee) + + // Check the request ID is as expected. + require.Equal(t, initialCounter+1, requestID) + + // Read the saved record and compare. + record, ok := tp.records.Load(requestID) + require.True(t, ok) + require.Equal(t, tx, record.tx) + require.Equal(t, feeFunc, record.feeFunction) + require.Equal(t, fee, record.fee) + require.Equal(t, req, record.req) +} + +// mockers wraps a list of mocked interfaces used inside tx publisher. +type mockers struct { + signer *input.MockInputSigner + wallet *MockWallet + estimator *chainfee.MockEstimator + notifier *chainntnfs.MockChainNotifier + + feeFunc *MockFeeFunction +} + +// createTestPublisher creates a new tx publisher using the provided mockers. +func createTestPublisher(t *testing.T) (*TxPublisher, *mockers) { + // Create a mock fee estimator. + estimator := &chainfee.MockEstimator{} + + // Create a mock fee function. + feeFunc := &MockFeeFunction{} + + // Create a mock signer. + signer := &input.MockInputSigner{} + + // Create a mock wallet. + wallet := &MockWallet{} + + // Create a mock chain notifier. + notifier := &chainntnfs.MockChainNotifier{} + + t.Cleanup(func() { + estimator.AssertExpectations(t) + feeFunc.AssertExpectations(t) + signer.AssertExpectations(t) + wallet.AssertExpectations(t) + notifier.AssertExpectations(t) + }) + + m := &mockers{ + signer: signer, + wallet: wallet, + estimator: estimator, + notifier: notifier, + feeFunc: feeFunc, + } + + // Create a publisher using the mocks. + tp := NewTxPublisher(TxPublisherConfig{ + Estimator: m.estimator, + Signer: m.signer, + Wallet: m.wallet, + Notifier: m.notifier, + }) + + return tp, m +} + +// TestCreateAndCheckTx checks `createAndCheckTx` behaves as expected. +func TestCreateAndCheckTx(t *testing.T) { + t.Parallel() + + // Create a test request. + inp := createTestInput(1000, input.WitnessKeyHash) + + // Create a publisher using the mocks. + tp, m := createTestPublisher(t) + + // Create a test feerate and return it from the mock fee function. + feerate := chainfee.SatPerKWeight(1000) + m.feeFunc.On("FeeRate").Return(feerate) + + // Mock the wallet to fail on testmempoolaccept on the first call, and + // succeed on the second. + m.wallet.On("CheckMempoolAcceptance", + mock.Anything).Return(errDummy).Once() + m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once() + + // Mock the signer to always return a valid script. + // + // NOTE: we are not testing the utility of creating valid txes here, so + // this is fine to be mocked. This behaves essentially as skipping the + // Signer check and alaways assume the tx has a valid sig. + script := &input.Script{} + m.signer.On("ComputeInputScript", mock.Anything, + mock.Anything).Return(script, nil) + + testCases := []struct { + name string + req *BumpRequest + expectedErr error + }{ + { + // When the budget cannot cover the fee, an error + // should be returned. + name: "not enough budget", + req: &BumpRequest{ + DeliveryAddress: changePkScript, + Inputs: []input.Input{&inp}, + }, + expectedErr: ErrNotEnoughBudget, + }, + { + // When the mempool rejects the transaction, an error + // should be returned. + name: "testmempoolaccept fail", + req: &BumpRequest{ + DeliveryAddress: changePkScript, + Inputs: []input.Input{&inp}, + Budget: btcutil.Amount(1000), + }, + expectedErr: errDummy, + }, + { + // When the mempool accepts the transaction, no error + // should be returned. + name: "testmempoolaccept pass", + req: &BumpRequest{ + DeliveryAddress: changePkScript, + Inputs: []input.Input{&inp}, + Budget: btcutil.Amount(1000), + }, + expectedErr: nil, + }, + } + + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + // Call the method under test. + _, _, err := tp.createAndCheckTx(tc.req, m.feeFunc) + + // Check the result is as expected. + require.ErrorIs(t, err, tc.expectedErr) + }) + } +} + +// createTestBumpRequest creates a new bump request. +func createTestBumpRequest() *BumpRequest { + // Create a test input. + inp := createTestInput(1000, input.WitnessKeyHash) + + return &BumpRequest{ + DeliveryAddress: changePkScript, + Inputs: []input.Input{&inp}, + Budget: btcutil.Amount(1000), + } +} + +// TestCreateRBFCompliantTx checks that `createRBFCompliantTx` behaves as +// expected. +func TestCreateRBFCompliantTx(t *testing.T) { + t.Parallel() + + // Create a publisher using the mocks. + tp, m := createTestPublisher(t) + + // Create a test bump request. + req := createTestBumpRequest() + + // Create a test feerate and return it from the mock fee function. + feerate := chainfee.SatPerKWeight(1000) + m.feeFunc.On("FeeRate").Return(feerate) + + // Mock the signer to always return a valid script. + // + // NOTE: we are not testing the utility of creating valid txes here, so + // this is fine to be mocked. This behaves essentially as skipping the + // Signer check and alaways assume the tx has a valid sig. + script := &input.Script{} + m.signer.On("ComputeInputScript", mock.Anything, + mock.Anything).Return(script, nil) + + testCases := []struct { + name string + setupMock func() + expectedErr error + }{ + { + // When testmempoolaccept accepts the tx, no error + // should be returned. + name: "success case", + setupMock: func() { + // Mock the testmempoolaccept to pass. + m.wallet.On("CheckMempoolAcceptance", + mock.Anything).Return(nil).Once() + }, + expectedErr: nil, + }, + { + // When testmempoolaccept fails due to a non-fee + // related error, an error should be returned. + name: "non-fee related testmempoolaccept fail", + setupMock: func() { + // Mock the testmempoolaccept to fail. + m.wallet.On("CheckMempoolAcceptance", + mock.Anything).Return(errDummy).Once() + }, + expectedErr: errDummy, + }, + { + // When increase feerate gives an error, the error + // should be returned. + name: "fail on increase fee", + setupMock: func() { + // Mock the testmempoolaccept to fail on fee. + m.wallet.On("CheckMempoolAcceptance", + mock.Anything).Return( + lnwallet.ErrMempoolFee).Once() + + // Mock the fee function to return an error. + m.feeFunc.On("Increment").Return( + false, errDummy).Once() + }, + expectedErr: errDummy, + }, + { + // Test that after one round of increasing the feerate + // the tx passes testmempoolaccept. + name: "increase fee and success on min mempool fee", + setupMock: func() { + // Mock the testmempoolaccept to fail on fee + // for the first call. + m.wallet.On("CheckMempoolAcceptance", + mock.Anything).Return( + lnwallet.ErrMempoolFee).Once() + + // Mock the fee function to increase feerate. + m.feeFunc.On("Increment").Return( + true, nil).Once() + + // Mock the testmempoolaccept to pass on the + // second call. + m.wallet.On("CheckMempoolAcceptance", + mock.Anything).Return(nil).Once() + }, + expectedErr: nil, + }, + { + // Test that after one round of increasing the feerate + // the tx passes testmempoolaccept. + name: "increase fee and success on insufficient fee", + setupMock: func() { + // Mock the testmempoolaccept to fail on fee + // for the first call. + m.wallet.On("CheckMempoolAcceptance", + mock.Anything).Return( + rpcclient.ErrInsufficientFee).Once() + + // Mock the fee function to increase feerate. + m.feeFunc.On("Increment").Return( + true, nil).Once() + + // Mock the testmempoolaccept to pass on the + // second call. + m.wallet.On("CheckMempoolAcceptance", + mock.Anything).Return(nil).Once() + }, + expectedErr: nil, + }, + { + // Test that the fee function increases the fee rate + // after one round. + name: "increase fee on second round", + setupMock: func() { + // Mock the testmempoolaccept to fail on fee + // for the first call. + m.wallet.On("CheckMempoolAcceptance", + mock.Anything).Return( + rpcclient.ErrInsufficientFee).Once() + + // Mock the fee function to NOT increase + // feerate on the first round. + m.feeFunc.On("Increment").Return( + false, nil).Once() + + // Mock the fee function to increase feerate. + m.feeFunc.On("Increment").Return( + true, nil).Once() + + // Mock the testmempoolaccept to pass on the + // second call. + m.wallet.On("CheckMempoolAcceptance", + mock.Anything).Return(nil).Once() + }, + expectedErr: nil, + }, + } + + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + tc.setupMock() + + // Call the method under test. + id, err := tp.createRBFCompliantTx(req, m.feeFunc) + + // Check the result is as expected. + require.ErrorIs(t, err, tc.expectedErr) + + // If there's an error, expect the requestID to be + // empty. + if tc.expectedErr != nil { + require.Zero(t, id) + } + }) + } +} + +// TestTxPublisherBroadcast checks the internal `broadcast` method behaves as +// expected. +func TestTxPublisherBroadcast(t *testing.T) { + t.Parallel() + + // Create a publisher using the mocks. + tp, m := createTestPublisher(t) + + // Create a test bump request. + req := createTestBumpRequest() + + // Create a test tx. + tx := &wire.MsgTx{LockTime: 1} + + // Create a test feerate and return it from the mock fee function. + feerate := chainfee.SatPerKWeight(1000) + m.feeFunc.On("FeeRate").Return(feerate) + + // Create a testing record and put it in the map. + fee := btcutil.Amount(1000) + requestID := tp.storeRecord(tx, req, m.feeFunc, fee) + + // Quickly check when the requestID cannot be found, an error is + // returned. + result, err := tp.broadcast(uint64(1000)) + require.Error(t, err) + require.Nil(t, result) + + testCases := []struct { + name string + setupMock func() + expectedErr error + expectedResult *BumpResult + }{ + { + // When the wallet cannot publish this tx, the error + // should be put inside the result. + name: "fail to publish", + setupMock: func() { + // Mock the wallet to fail to publish. + m.wallet.On("PublishTransaction", + tx, mock.Anything).Return( + errDummy).Once() + }, + expectedErr: nil, + expectedResult: &BumpResult{ + Event: TxFailed, + Tx: tx, + Fee: fee, + FeeRate: feerate, + Err: errDummy, + requestID: requestID, + }, + }, + { + // When nothing goes wrong, the result is returned. + name: "publish success", + setupMock: func() { + // Mock the wallet to publish successfully. + m.wallet.On("PublishTransaction", + tx, mock.Anything).Return(nil).Once() + }, + expectedErr: nil, + expectedResult: &BumpResult{ + Event: TxPublished, + Tx: tx, + Fee: fee, + FeeRate: feerate, + Err: nil, + requestID: requestID, + }, + }, + } + + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + tc.setupMock() + + // Call the method under test. + result, err := tp.broadcast(requestID) + + // Check the result is as expected. + require.ErrorIs(t, err, tc.expectedErr) + require.Equal(t, tc.expectedResult, result) + }) + } +} + +// TestRemoveResult checks the records and subscriptions are removed when a tx +// is confirmed or failed. +func TestRemoveResult(t *testing.T) { + t.Parallel() + + // Create a publisher using the mocks. + tp, m := createTestPublisher(t) + + // Create a test bump request. + req := createTestBumpRequest() + + // Create a test tx. + tx := &wire.MsgTx{LockTime: 1} + + // Create a testing record and put it in the map. + fee := btcutil.Amount(1000) + + testCases := []struct { + name string + setupRecord func() uint64 + result *BumpResult + removed bool + }{ + { + // When the tx is confirmed, the records will be + // removed. + name: "remove on TxConfirmed", + setupRecord: func() uint64 { + id := tp.storeRecord(tx, req, m.feeFunc, fee) + tp.subscriberChans.Store(id, nil) + + return id + }, + result: &BumpResult{ + Event: TxConfirmed, + Tx: tx, + }, + removed: true, + }, + { + // When the tx is failed, the records will be removed. + name: "remove on TxFailed", + setupRecord: func() uint64 { + id := tp.storeRecord(tx, req, m.feeFunc, fee) + tp.subscriberChans.Store(id, nil) + + return id + }, + result: &BumpResult{ + Event: TxFailed, + Err: errDummy, + Tx: tx, + }, + removed: true, + }, + { + // Noop when the tx is neither confirmed or failed. + name: "noop when tx is not confirmed or failed", + setupRecord: func() uint64 { + id := tp.storeRecord(tx, req, m.feeFunc, fee) + tp.subscriberChans.Store(id, nil) + + return id + }, + result: &BumpResult{ + Event: TxPublished, + Tx: tx, + }, + removed: false, + }, + } + + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + requestID := tc.setupRecord() + + // Attach the requestID from the setup. + tc.result.requestID = requestID + + // Remove the result. + tp.removeResult(tc.result) + + // Check if the record is removed. + _, found := tp.records.Load(requestID) + require.Equal(t, !tc.removed, found) + + _, found = tp.subscriberChans.Load(requestID) + require.Equal(t, !tc.removed, found) + }) + } +} + +// TestNotifyResult checks the subscribers are notified when a result is sent. +func TestNotifyResult(t *testing.T) { + t.Parallel() + + // Create a publisher using the mocks. + tp, m := createTestPublisher(t) + + // Create a test bump request. + req := createTestBumpRequest() + + // Create a test tx. + tx := &wire.MsgTx{LockTime: 1} + + // Create a testing record and put it in the map. + fee := btcutil.Amount(1000) + requestID := tp.storeRecord(tx, req, m.feeFunc, fee) + + // Create a subscription to the event. + subscriber := make(chan *BumpResult, 1) + tp.subscriberChans.Store(requestID, subscriber) + + // Create a test result. + result := &BumpResult{ + requestID: requestID, + Tx: tx, + } + + // Notify the result and expect the subscriber to receive it. + // + // NOTE: must be done inside a goroutine in case it blocks. + go tp.notifyResult(result) + + select { + case <-time.After(time.Second): + t.Fatal("timeout waiting for subscriber to receive result") + + case received := <-subscriber: + require.Equal(t, result, received) + } + + // Notify two results. This time it should block because the channel is + // full. We then shutdown TxPublisher to test the quit behavior. + done := make(chan struct{}) + go func() { + // Call notifyResult twice, which blocks at the second call. + tp.notifyResult(result) + tp.notifyResult(result) + + close(done) + }() + + // Shutdown the publisher and expect notifyResult to exit. + close(tp.quit) + + // We expect to done chan. + select { + case <-time.After(time.Second): + t.Fatal("timeout waiting for notifyResult to exit") + + case <-done: + } +} + +// TestBroadcastSuccess checks the public `Broadcast` method can successfully +// broadcast a tx based on the request. +func TestBroadcastSuccess(t *testing.T) { + t.Parallel() + + // Create a publisher using the mocks. + tp, m := createTestPublisher(t) + + // Create a test feerate. + feerate := chainfee.SatPerKWeight(1000) + + // Mock the fee estimator to return the testing fee rate. + // + // We are not testing `NewLinearFeeFunction` here, so the actual params + // used are irrelevant. + m.estimator.On("EstimateFeePerKW", mock.Anything).Return( + feerate, nil).Once() + m.estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Once() + + // Mock the signer to always return a valid script. + // + // NOTE: we are not testing the utility of creating valid txes here, so + // this is fine to be mocked. This behaves essentially as skipping the + // Signer check and alaways assume the tx has a valid sig. + script := &input.Script{} + m.signer.On("ComputeInputScript", mock.Anything, + mock.Anything).Return(script, nil) + + // Mock the testmempoolaccept to pass. + m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once() + + // Mock the wallet to publish successfully. + m.wallet.On("PublishTransaction", + mock.Anything, mock.Anything).Return(nil).Once() + + // Create a test request. + inp := createTestInput(1000, input.WitnessKeyHash) + + // Create a testing bump request. + req := &BumpRequest{ + DeliveryAddress: changePkScript, + Inputs: []input.Input{&inp}, + Budget: btcutil.Amount(1000), + MaxFeeRate: feerate, + } + + // Send the req and expect no error. + resultChan, err := tp.Broadcast(req) + require.NoError(t, err) + + // Check the result is sent back. + select { + case <-time.After(time.Second): + t.Fatal("timeout waiting for subscriber to receive result") + + case result := <-resultChan: + // We expect the first result to be TxPublished. + require.Equal(t, TxPublished, result.Event) + } + + // Validate the record was stored. + require.Equal(t, 1, tp.records.Len()) + require.Equal(t, 1, tp.subscriberChans.Len()) +} + +// TestBroadcastFail checks the public `Broadcast` returns the error or a +// failed result when the broadcast fails. +func TestBroadcastFail(t *testing.T) { + t.Parallel() + + // Create a publisher using the mocks. + tp, m := createTestPublisher(t) + + // Create a test feerate. + feerate := chainfee.SatPerKWeight(1000) + + // Create a test request. + inp := createTestInput(1000, input.WitnessKeyHash) + + // Create a testing bump request. + req := &BumpRequest{ + DeliveryAddress: changePkScript, + Inputs: []input.Input{&inp}, + Budget: btcutil.Amount(1000), + MaxFeeRate: feerate, + } + + // Mock the fee estimator to return the testing fee rate. + // + // We are not testing `NewLinearFeeFunction` here, so the actual params + // used are irrelevant. + m.estimator.On("EstimateFeePerKW", mock.Anything).Return( + feerate, nil).Twice() + m.estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Twice() + + // Mock the signer to always return a valid script. + // + // NOTE: we are not testing the utility of creating valid txes here, so + // this is fine to be mocked. This behaves essentially as skipping the + // Signer check and alaways assume the tx has a valid sig. + script := &input.Script{} + m.signer.On("ComputeInputScript", mock.Anything, + mock.Anything).Return(script, nil) + + // Mock the testmempoolaccept to return an error. + m.wallet.On("CheckMempoolAcceptance", + mock.Anything).Return(errDummy).Once() + + // Send the req and expect an error returned. + resultChan, err := tp.Broadcast(req) + require.ErrorIs(t, err, errDummy) + require.Nil(t, resultChan) + + // Validate the record was NOT stored. + require.Equal(t, 0, tp.records.Len()) + require.Equal(t, 0, tp.subscriberChans.Len()) + + // Mock the testmempoolaccept again, this time it passes. + m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once() + + // Mock the wallet to fail on publish. + m.wallet.On("PublishTransaction", + mock.Anything, mock.Anything).Return(errDummy).Once() + + // Send the req and expect no error returned. + resultChan, err = tp.Broadcast(req) + require.NoError(t, err) + + // Check the result is sent back. + select { + case <-time.After(time.Second): + t.Fatal("timeout waiting for subscriber to receive result") + + case result := <-resultChan: + // We expect the result to be TxFailed and the error is set in + // the result. + require.Equal(t, TxFailed, result.Event) + require.ErrorIs(t, result.Err, errDummy) + } + + // Validate the record was removed. + require.Equal(t, 0, tp.records.Len()) + require.Equal(t, 0, tp.subscriberChans.Len()) +} + +// TestCreateAnPublishFail checks all the error cases are handled properly in +// the method createAndPublish. +func TestCreateAnPublishFail(t *testing.T) { + t.Parallel() + + // Create a publisher using the mocks. + tp, m := createTestPublisher(t) + + // Create a test requestID. + requestID := uint64(1) + + // Create a test feerate and return it from the mock fee function. + feerate := chainfee.SatPerKWeight(1000) + m.feeFunc.On("FeeRate").Return(feerate) + + // Create a testing monitor record. + req := createTestBumpRequest() + + // Overwrite the budget to make it smaller than the fee. + req.Budget = 100 + record := &monitorRecord{ + req: req, + feeFunction: m.feeFunc, + tx: &wire.MsgTx{}, + } + + // Mock the signer to always return a valid script. + // + // NOTE: we are not testing the utility of creating valid txes here, so + // this is fine to be mocked. This behaves essentially as skipping the + // Signer check and alaways assume the tx has a valid sig. + script := &input.Script{} + m.signer.On("ComputeInputScript", mock.Anything, + mock.Anything).Return(script, nil) + + // Call the createAndPublish method. + resultOpt := tp.createAndPublishTx(requestID, record) + result := resultOpt.UnwrapOrFail(t) + + // We expect the result to be TxFailed and the error is set in the + // result. + require.Equal(t, TxFailed, result.Event) + require.ErrorIs(t, result.Err, ErrNotEnoughBudget) + require.Equal(t, requestID, result.requestID) + + // Increase the budget and call it again. This time we will mock an + // error to be returned from CheckMempoolAcceptance. + req.Budget = 1000 + + // Mock the testmempoolaccept to return a fee related error that should + // be ignored. + m.wallet.On("CheckMempoolAcceptance", + mock.Anything).Return(lnwallet.ErrMempoolFee).Once() + + // Call the createAndPublish method and expect a none option. + resultOpt = tp.createAndPublishTx(requestID, record) + require.True(t, resultOpt.IsNone()) + + // Mock the testmempoolaccept to return a fee related error that should + // be ignored. + m.wallet.On("CheckMempoolAcceptance", + mock.Anything).Return(rpcclient.ErrInsufficientFee).Once() + + // Call the createAndPublish method and expect a none option. + resultOpt = tp.createAndPublishTx(requestID, record) + require.True(t, resultOpt.IsNone()) +} + +// TestCreateAnPublishSuccess checks the expected result is returned from the +// method createAndPublish. +func TestCreateAnPublishSuccess(t *testing.T) { + t.Parallel() + + // Create a publisher using the mocks. + tp, m := createTestPublisher(t) + + // Create a test requestID. + requestID := uint64(1) + + // Create a test feerate and return it from the mock fee function. + feerate := chainfee.SatPerKWeight(1000) + m.feeFunc.On("FeeRate").Return(feerate) + + // Create a testing monitor record. + req := createTestBumpRequest() + record := &monitorRecord{ + req: req, + feeFunction: m.feeFunc, + tx: &wire.MsgTx{}, + } + + // Mock the signer to always return a valid script. + // + // NOTE: we are not testing the utility of creating valid txes here, so + // this is fine to be mocked. This behaves essentially as skipping the + // Signer check and alaways assume the tx has a valid sig. + script := &input.Script{} + m.signer.On("ComputeInputScript", mock.Anything, + mock.Anything).Return(script, nil) + + // Mock the testmempoolaccept to return nil. + m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil) + + // Mock the wallet to publish and return an error. + m.wallet.On("PublishTransaction", + mock.Anything, mock.Anything).Return(errDummy).Once() + + // Call the createAndPublish method and expect a failure result. + resultOpt := tp.createAndPublishTx(requestID, record) + result := resultOpt.UnwrapOrFail(t) + + // We expect the result to be TxFailed and the error is set. + require.Equal(t, TxFailed, result.Event) + require.ErrorIs(t, result.Err, errDummy) + + // Although the replacement tx was failed to be published, the record + // should be stored. + require.NotNil(t, result.Tx) + require.NotNil(t, result.ReplacedTx) + _, found := tp.records.Load(requestID) + require.True(t, found) + + // We now check a successful RBF. + // + // Mock the wallet to publish successfully. + m.wallet.On("PublishTransaction", + mock.Anything, mock.Anything).Return(nil).Once() + + // Call the createAndPublish method and expect a success result. + resultOpt = tp.createAndPublishTx(requestID, record) + result = resultOpt.UnwrapOrFail(t) + require.True(t, resultOpt.IsSome()) + + // We expect the result to be TxReplaced and the error is nil. + require.Equal(t, TxReplaced, result.Event) + require.Nil(t, result.Err) + + // Check the Tx and ReplacedTx are set. + require.NotNil(t, result.Tx) + require.NotNil(t, result.ReplacedTx) + + // Check the record is stored. + _, found = tp.records.Load(requestID) + require.True(t, found) +} + +// TestHandleTxConfirmed checks the expected result is returned from the method +// handleTxConfirmed. +func TestHandleTxConfirmed(t *testing.T) { + t.Parallel() + + // Create a publisher using the mocks. + tp, m := createTestPublisher(t) + + // Create a test bump request. + req := createTestBumpRequest() + + // Create a test tx. + tx := &wire.MsgTx{LockTime: 1} + + // Create a testing record and put it in the map. + fee := btcutil.Amount(1000) + requestID := tp.storeRecord(tx, req, m.feeFunc, fee) + record, ok := tp.records.Load(requestID) + require.True(t, ok) + + // Create a subscription to the event. + subscriber := make(chan *BumpResult, 1) + tp.subscriberChans.Store(requestID, subscriber) + + // Mock the fee function to return a fee rate. + feerate := chainfee.SatPerKWeight(1000) + m.feeFunc.On("FeeRate").Return(feerate).Once() + + // Call the method and expect a result to be received. + // + // NOTE: must be called in a goroutine in case it blocks. + tp.wg.Add(1) + go tp.handleTxConfirmed(record, requestID) + + select { + case <-time.After(time.Second): + t.Fatal("timeout waiting for subscriber to receive result") + + case result := <-subscriber: + // We expect the result to be TxConfirmed and the tx is set. + require.Equal(t, TxConfirmed, result.Event) + require.Equal(t, tx, result.Tx) + require.Nil(t, result.Err) + require.Equal(t, requestID, result.requestID) + require.Equal(t, record.fee, result.Fee) + require.Equal(t, feerate, result.FeeRate) + } + + // We expect the record to be removed from the maps. + _, found := tp.records.Load(requestID) + require.False(t, found) + _, found = tp.subscriberChans.Load(requestID) + require.False(t, found) +} + +// TestHandleFeeBumpTx validates handleFeeBumpTx behaves as expected. +func TestHandleFeeBumpTx(t *testing.T) { + t.Parallel() + + // Create a publisher using the mocks. + tp, m := createTestPublisher(t) + + // Create a test tx. + tx := &wire.MsgTx{LockTime: 1} + + // Create a test current height. + testHeight := int32(800000) + + // Create a testing monitor record. + req := createTestBumpRequest() + record := &monitorRecord{ + req: req, + feeFunction: m.feeFunc, + tx: tx, + } + + // Create a testing record and put it in the map. + fee := btcutil.Amount(1000) + requestID := tp.storeRecord(tx, req, m.feeFunc, fee) + + // Create a subscription to the event. + subscriber := make(chan *BumpResult, 1) + tp.subscriberChans.Store(requestID, subscriber) + + // Create a test feerate and return it from the mock fee function. + feerate := chainfee.SatPerKWeight(1000) + m.feeFunc.On("FeeRate").Return(feerate) + + // Mock the fee function to skip the bump due to error. + m.feeFunc.On("IncreaseFeeRate", mock.Anything).Return( + false, errDummy).Once() + + // Call the method and expect no result received. + tp.wg.Add(1) + go tp.handleFeeBumpTx(requestID, record, testHeight) + + // Check there's no result sent back. + select { + case <-time.After(time.Second): + case result := <-subscriber: + t.Fatalf("unexpected result received: %v", result) + } + + // Mock the fee function to skip the bump. + m.feeFunc.On("IncreaseFeeRate", mock.Anything).Return(false, nil).Once() + + // Call the method and expect no result received. + tp.wg.Add(1) + go tp.handleFeeBumpTx(requestID, record, testHeight) + + // Check there's no result sent back. + select { + case <-time.After(time.Second): + case result := <-subscriber: + t.Fatalf("unexpected result received: %v", result) + } + + // Mock the fee function to perform the fee bump. + m.feeFunc.On("IncreaseFeeRate", mock.Anything).Return(true, nil) + + // Mock the signer to always return a valid script. + // + // NOTE: we are not testing the utility of creating valid txes here, so + // this is fine to be mocked. This behaves essentially as skipping the + // Signer check and alaways assume the tx has a valid sig. + script := &input.Script{} + m.signer.On("ComputeInputScript", mock.Anything, + mock.Anything).Return(script, nil) + + // Mock the testmempoolaccept to return nil. + m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil) + + // Mock the wallet to publish successfully. + m.wallet.On("PublishTransaction", + mock.Anything, mock.Anything).Return(nil).Once() + + // Call the method and expect a result to be received. + // + // NOTE: must be called in a goroutine in case it blocks. + tp.wg.Add(1) + go tp.handleFeeBumpTx(requestID, record, testHeight) + + select { + case <-time.After(time.Second): + t.Fatal("timeout waiting for subscriber to receive result") + + case result := <-subscriber: + // We expect the result to be TxReplaced. + require.Equal(t, TxReplaced, result.Event) + + // The new tx and old tx should be properly set. + require.NotEqual(t, tx, result.Tx) + require.Equal(t, tx, result.ReplacedTx) + + // No error should be set. + require.Nil(t, result.Err) + require.Equal(t, requestID, result.requestID) + } + + // We expect the record to NOT be removed from the maps. + _, found := tp.records.Load(requestID) + require.True(t, found) + _, found = tp.subscriberChans.Load(requestID) + require.True(t, found) +} + +// TestProcessRecords validates processRecords behaves as expected. +func TestProcessRecords(t *testing.T) { + t.Parallel() + + // Create a publisher using the mocks. + tp, m := createTestPublisher(t) + + // Create testing objects. + requestID1 := uint64(1) + req1 := createTestBumpRequest() + tx1 := &wire.MsgTx{LockTime: 1} + txid1 := tx1.TxHash() + + requestID2 := uint64(2) + req2 := createTestBumpRequest() + tx2 := &wire.MsgTx{LockTime: 2} + txid2 := tx2.TxHash() + + // Create a monitor record that's confirmed. + recordConfirmed := &monitorRecord{ + req: req1, + feeFunction: m.feeFunc, + tx: tx1, + } + m.wallet.On("GetTransactionDetails", &txid1).Return( + &lnwallet.TransactionDetail{ + NumConfirmations: 1, + }, nil, + ).Once() + + // Create a monitor record that's not confirmed. We know it's not + // confirmed because the num of confirms is zero. + recordFeeBump := &monitorRecord{ + req: req2, + feeFunction: m.feeFunc, + tx: tx2, + } + m.wallet.On("GetTransactionDetails", &txid2).Return( + &lnwallet.TransactionDetail{ + NumConfirmations: 0, + }, nil, + ).Once() + + // Setup the initial publisher state by adding the records to the maps. + subscriberConfirmed := make(chan *BumpResult, 1) + tp.subscriberChans.Store(requestID1, subscriberConfirmed) + tp.records.Store(requestID1, recordConfirmed) + + subscriberReplaced := make(chan *BumpResult, 1) + tp.subscriberChans.Store(requestID2, subscriberReplaced) + tp.records.Store(requestID2, recordFeeBump) + + // Create a test feerate and return it from the mock fee function. + feerate := chainfee.SatPerKWeight(1000) + m.feeFunc.On("FeeRate").Return(feerate) + + // The following methods should only be called once when creating the + // replacement tx. + // + // Mock the fee function to NOT skip the fee bump. + m.feeFunc.On("IncreaseFeeRate", mock.Anything).Return(true, nil).Once() + + // Mock the signer to always return a valid script. + m.signer.On("ComputeInputScript", mock.Anything, + mock.Anything).Return(&input.Script{}, nil).Once() + + // Mock the testmempoolaccept to return nil. + m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once() + + // Mock the wallet to publish successfully. + m.wallet.On("PublishTransaction", + mock.Anything, mock.Anything).Return(nil).Once() + + // Call processRecords and expect the results are notified back. + tp.processRecords() + + // We expect two results to be received. One for the confirmed tx and + // one for the replaced tx. + // + // Check the confirmed tx result. + select { + case <-time.After(time.Second): + t.Fatal("timeout waiting for subscriberConfirmed") + + case result := <-subscriberConfirmed: + // We expect the result to be TxConfirmed. + require.Equal(t, TxConfirmed, result.Event) + require.Equal(t, tx1, result.Tx) + + // No error should be set. + require.Nil(t, result.Err) + require.Equal(t, requestID1, result.requestID) + } + + // Now check the replaced tx result. + select { + case <-time.After(time.Second): + t.Fatal("timeout waiting for subscriberReplaced") + + case result := <-subscriberReplaced: + // We expect the result to be TxReplaced. + require.Equal(t, TxReplaced, result.Event) + + // The new tx and old tx should be properly set. + require.NotEqual(t, tx2, result.Tx) + require.Equal(t, tx2, result.ReplacedTx) + + // No error should be set. + require.Nil(t, result.Err) + require.Equal(t, requestID2, result.requestID) + } +} diff --git a/sweep/fee_function.go b/sweep/fee_function.go new file mode 100644 index 0000000000..955ca43a6c --- /dev/null +++ b/sweep/fee_function.go @@ -0,0 +1,288 @@ +package sweep + +import ( + "errors" + "fmt" + + "github.com/btcsuite/btcd/btcutil" + "github.com/lightningnetwork/lnd/lnwallet/chainfee" + "github.com/lightningnetwork/lnd/lnwire" +) + +var ( + // ErrMaxPosition is returned when trying to increase the position of + // the fee function while it's already at its max. + ErrMaxPosition = errors.New("position already at max") +) + +// mSatPerKWeight represents a fee rate in msat/kw. +// +// TODO(yy): unify all the units to be virtual bytes. +type mSatPerKWeight lnwire.MilliSatoshi + +// String returns a human-readable string of the fee rate. +func (m mSatPerKWeight) String() string { + s := lnwire.MilliSatoshi(m) + return fmt.Sprintf("%v/kw", s) +} + +// FeeFunction defines an interface that is used to calculate fee rates for +// transactions. It's expected the implementations use three params, the +// starting fee rate, the ending fee rate, and number of blocks till deadline +// block height, to build an algorithm to calculate the fee rate based on the +// current block height. +type FeeFunction interface { + // FeeRate returns the current fee rate calculated by the fee function. + FeeRate() chainfee.SatPerKWeight + + // Increment increases the fee rate by one step. The definition of one + // step is up to the implementation. After calling this method, it's + // expected to change the state of the fee function such that calling + // `FeeRate` again will return the increased value. + // + // It returns a boolean to indicate whether the fee rate is increased, + // as fee bump should not be attempted if the increased fee rate is not + // greater than the current fee rate, which may happen if the algorithm + // gives the same fee rates at two positions. + // + // An error is returned when the max fee rate is reached. + // + // NOTE: we intentionally don't return the new fee rate here, so both + // the implementation and the caller are aware of the state change. + Increment() (bool, error) + + // IncreaseFeeRate increases the fee rate to the new position + // calculated using (width - confTarget). It returns a boolean to + // indicate whether the fee rate is increased, and an error if the + // position is greater than the width. + // + // NOTE: this method is provided to allow the caller to increase the + // fee rate based on a conf target without taking care of the fee + // function's current state (position). + IncreaseFeeRate(confTarget uint32) (bool, error) +} + +// LinearFeeFunction implements the FeeFunction interface with a linear +// function: +// +// feeRate = startingFeeRate + position * delta. +// - width: deadlineBlockHeight - startingBlockHeight +// - delta: (endingFeeRate - startingFeeRate) / width +// - position: currentBlockHeight - startingBlockHeight +// +// The fee rate will be capped at endingFeeRate. +// +// TODO(yy): implement more functions specified here: +// - https://github.com/lightningnetwork/lnd/issues/4215 +type LinearFeeFunction struct { + // startingFeeRate specifies the initial fee rate to begin with. + startingFeeRate chainfee.SatPerKWeight + + // endingFeeRate specifies the max allowed fee rate. + endingFeeRate chainfee.SatPerKWeight + + // currentFeeRate specifies the current calculated fee rate. + currentFeeRate chainfee.SatPerKWeight + + // width is the number of blocks between the starting block height + // and the deadline block height. + width uint32 + + // position is the number of blocks between the starting block height + // and the current block height. + position uint32 + + // deltaFeeRate is the fee rate (msat/kw) increase per block. + // + // NOTE: this is used to increase precision. + deltaFeeRate mSatPerKWeight + + // estimator is the fee estimator used to estimate the fee rate. We use + // it to get the initial fee rate and, use it as a benchmark to decide + // whether we want to used the estimated fee rate or the calculated fee + // rate based on different strategies. + estimator chainfee.Estimator +} + +// Compile-time check to ensure LinearFeeFunction satisfies the FeeFunction. +var _ FeeFunction = (*LinearFeeFunction)(nil) + +// NewLinearFeeFunction creates a new linear fee function and initializes it +// with a starting fee rate which is an estimated value returned from the fee +// estimator using the initial conf target. +func NewLinearFeeFunction(maxFeeRate chainfee.SatPerKWeight, confTarget uint32, + estimator chainfee.Estimator) (*LinearFeeFunction, error) { + + // Sanity check conf target. + if confTarget == 0 { + return nil, fmt.Errorf("width must be greater than zero") + } + + l := &LinearFeeFunction{ + endingFeeRate: maxFeeRate, + width: confTarget, + estimator: estimator, + } + + // Estimate the initial fee rate. + // + // NOTE: estimateFeeRate guarantees the returned fee rate is capped by + // the ending fee rate, so we don't need to worry about overpay. + start, err := l.estimateFeeRate(confTarget) + if err != nil { + return nil, fmt.Errorf("estimate initial fee rate: %w", err) + } + + // Calculate how much fee rate should be increased per block. + end := l.endingFeeRate + + // The starting and ending fee rates are in sat/kw, so we need to + // convert them to msat/kw by multiplying by 1000. + delta := btcutil.Amount(end - start).MulF64(1000 / float64(confTarget)) + l.deltaFeeRate = mSatPerKWeight(delta) + + // We only allow the delta to be zero if the width is one - when the + // delta is zero, it means the starting and ending fee rates are the + // same, which means there's nothing to increase, so any width greater + // than 1 doesn't provide any utility. This could happen when the + // sweeper is offered to sweep an input that has passed its deadline. + if l.deltaFeeRate == 0 && l.width != 1 { + log.Errorf("Failed to init fee function: startingFeeRate=%v, "+ + "endingFeeRate=%v, width=%v, delta=%v", start, end, + confTarget, l.deltaFeeRate) + + return nil, fmt.Errorf("fee rate delta is zero") + } + + // Attach the calculated values to the fee function. + l.startingFeeRate = start + l.currentFeeRate = start + + log.Debugf("Linear fee function initialized with startingFeeRate=%v, "+ + "endingFeeRate=%v, width=%v, delta=%v", start, end, + confTarget, l.deltaFeeRate) + + return l, nil +} + +// FeeRate returns the current fee rate. +// +// NOTE: part of the FeeFunction interface. +func (l *LinearFeeFunction) FeeRate() chainfee.SatPerKWeight { + return l.currentFeeRate +} + +// Increment increases the fee rate by one position, returns a boolean to +// indicate whether the fee rate was increased, and an error if the position is +// greater than the width. The increased fee rate will be set as the current +// fee rate, and the internal position will be incremented. +// +// NOTE: this method will change the state of the fee function as it increases +// its current fee rate. +// +// NOTE: part of the FeeFunction interface. +func (l *LinearFeeFunction) Increment() (bool, error) { + return l.increaseFeeRate(l.position + 1) +} + +// IncreaseFeeRate calculate a new position using the given conf target, and +// increases the fee rate to the new position by calling the Increment method. +// +// NOTE: this method will change the state of the fee function as it increases +// its current fee rate. +// +// NOTE: part of the FeeFunction interface. +func (l *LinearFeeFunction) IncreaseFeeRate(confTarget uint32) (bool, error) { + // If the new position is already at the end, we return an error. + if confTarget == 0 { + return false, ErrMaxPosition + } + + newPosition := uint32(0) + + // Only calculate the new position when the conf target is less than + // the function's width - the width is the initial conf target, and we + // expect the current conf target to decrease over time. However, we + // still allow the supplied conf target to be greater than the width, + // and we won't increase the fee rate in that case. + if confTarget < l.width { + newPosition = l.width - confTarget + log.Tracef("Increasing position from %v to %v", l.position, + newPosition) + } + + if newPosition <= l.position { + log.Tracef("Skipped increase feerate: position=%v, "+ + "newPosition=%v ", l.position, newPosition) + + return false, nil + } + + return l.increaseFeeRate(newPosition) +} + +// increaseFeeRate increases the fee rate by the specified position, returns a +// boolean to indicate whether the fee rate was increased, and an error if the +// position is greater than the width. The increased fee rate will be set as +// the current fee rate, and the internal position will be set to the specified +// position. +// +// NOTE: this method will change the state of the fee function as it increases +// its current fee rate. +func (l *LinearFeeFunction) increaseFeeRate(position uint32) (bool, error) { + // If the new position is already at the end, we return an error. + if l.position >= l.width { + return false, ErrMaxPosition + } + + // Get the old fee rate. + oldFeeRate := l.currentFeeRate + + // Update its internal state. + l.position = position + l.currentFeeRate = l.feeRateAtPosition(position) + + log.Tracef("Fee rate increased from %v to %v at position %v", + oldFeeRate, l.currentFeeRate, l.position) + + return l.currentFeeRate > oldFeeRate, nil +} + +// feeRateAtPosition calculates the fee rate at a given position and caps it at +// the ending fee rate. +func (l *LinearFeeFunction) feeRateAtPosition(p uint32) chainfee.SatPerKWeight { + if p >= l.width { + return l.endingFeeRate + } + + // deltaFeeRate is in msat/kw, so we need to divide by 1000 to get the + // fee rate in sat/kw. + feeRateDelta := btcutil.Amount(l.deltaFeeRate).MulF64(float64(p) / 1000) + + feeRate := l.startingFeeRate + chainfee.SatPerKWeight(feeRateDelta) + if feeRate > l.endingFeeRate { + return l.endingFeeRate + } + + return feeRate +} + +// estimateFeeRate asks the fee estimator to estimate the fee rate based on its +// conf target. +func (l *LinearFeeFunction) estimateFeeRate( + confTarget uint32) (chainfee.SatPerKWeight, error) { + + fee := FeeEstimateInfo{ + ConfTarget: confTarget, + } + + // endingFeeRate comes from budget/txWeight, which means the returned + // fee rate will always be capped by this value, hence we don't need to + // worry about overpay. + estimatedFeeRate, err := fee.Estimate(l.estimator, l.endingFeeRate) + if err != nil { + return 0, err + } + + return estimatedFeeRate, nil +} diff --git a/sweep/fee_function_test.go b/sweep/fee_function_test.go new file mode 100644 index 0000000000..e549d8d643 --- /dev/null +++ b/sweep/fee_function_test.go @@ -0,0 +1,247 @@ +package sweep + +import ( + "testing" + + "github.com/lightningnetwork/lnd/lnwallet/chainfee" + "github.com/stretchr/testify/require" +) + +// TestLinearFeeFunctionNew tests the NewLinearFeeFunction function. +func TestLinearFeeFunctionNew(t *testing.T) { + t.Parallel() + + rt := require.New(t) + + // Create a mock fee estimator. + estimator := &chainfee.MockEstimator{} + + // Create testing params. + maxFeeRate := chainfee.SatPerKWeight(10000) + estimatedFeeRate := chainfee.SatPerKWeight(500) + confTarget := uint32(6) + + // Assert init fee function with zero conf value returns an error. + f, err := NewLinearFeeFunction(maxFeeRate, 0, estimator) + rt.ErrorContains(err, "width must be greater than zero") + rt.Nil(f) + + // When the fee estimator returns an error, it's returned. + // + // Mock the fee estimator to return an error. + estimator.On("EstimateFeePerKW", confTarget).Return( + chainfee.SatPerKWeight(0), errDummy).Once() + + f, err = NewLinearFeeFunction(maxFeeRate, confTarget, estimator) + rt.ErrorIs(err, errDummy) + rt.Nil(f) + + // When the starting feerate is greater than the ending feerate, the + // starting feerate is capped. + // + // Mock the fee estimator to return the fee rate. + smallConf := uint32(1) + estimator.On("EstimateFeePerKW", smallConf).Return( + // The fee rate is greater than the max fee rate. + maxFeeRate+1, nil).Once() + estimator.On("RelayFeePerKW").Return(estimatedFeeRate).Once() + + f, err = NewLinearFeeFunction(maxFeeRate, smallConf, estimator) + rt.NoError(err) + rt.NotNil(f) + + // When the calculated fee rate delta is 0, an error should be returned. + // + // Mock the fee estimator to return the fee rate. + estimator.On("EstimateFeePerKW", confTarget).Return( + // The starting fee rate is the max fee rate. + maxFeeRate, nil).Once() + estimator.On("RelayFeePerKW").Return(estimatedFeeRate).Once() + + f, err = NewLinearFeeFunction(maxFeeRate, confTarget, estimator) + rt.ErrorContains(err, "fee rate delta is zero") + rt.Nil(f) + + // Check a successfully created fee function. + // + // Mock the fee estimator to return the fee rate. + estimator.On("EstimateFeePerKW", confTarget).Return( + estimatedFeeRate, nil).Once() + estimator.On("RelayFeePerKW").Return(estimatedFeeRate).Once() + + f, err = NewLinearFeeFunction(maxFeeRate, confTarget, estimator) + rt.NoError(err) + rt.NotNil(f) + + // Assert the internal state. + rt.Equal(estimatedFeeRate, f.startingFeeRate) + rt.Equal(maxFeeRate, f.endingFeeRate) + rt.Equal(estimatedFeeRate, f.currentFeeRate) + rt.NotZero(f.deltaFeeRate) + rt.Equal(confTarget, f.width) +} + +// TestLinearFeeFunctionFeeRateAtPosition checks the expected feerate is +// calculated and returned. +func TestLinearFeeFunctionFeeRateAtPosition(t *testing.T) { + t.Parallel() + + rt := require.New(t) + + // Create a fee func which has three positions: + // - position 0: 1000 + // - position 1: 2000 + // - position 2: 3000 + f := &LinearFeeFunction{ + startingFeeRate: 1000, + endingFeeRate: 3000, + position: 0, + deltaFeeRate: 1_000_000, + width: 3, + } + + testCases := []struct { + name string + pos uint32 + expectedFeerate chainfee.SatPerKWeight + }{ + { + name: "position 0", + pos: 0, + expectedFeerate: 1000, + }, + { + name: "position 1", + pos: 1, + expectedFeerate: 2000, + }, + { + name: "position 2", + pos: 2, + expectedFeerate: 3000, + }, + { + name: "position 3", + pos: 3, + expectedFeerate: 3000, + }, + } + + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + result := f.feeRateAtPosition(tc.pos) + rt.Equal(tc.expectedFeerate, result) + }) + } +} + +// TestLinearFeeFunctionIncrement checks the internal state is updated +// correctly when the fee rate is incremented. +func TestLinearFeeFunctionIncrement(t *testing.T) { + t.Parallel() + + rt := require.New(t) + + // Create a mock fee estimator. + estimator := &chainfee.MockEstimator{} + + // Create testing params. These params are chosen so the delta value is + // 100. + maxFeeRate := chainfee.SatPerKWeight(1000) + estimatedFeeRate := chainfee.SatPerKWeight(100) + confTarget := uint32(9) + + // Mock the fee estimator to return the fee rate. + estimator.On("EstimateFeePerKW", confTarget).Return( + estimatedFeeRate, nil).Once() + estimator.On("RelayFeePerKW").Return(estimatedFeeRate).Once() + + f, err := NewLinearFeeFunction(maxFeeRate, confTarget, estimator) + rt.NoError(err) + + // We now increase the position from 1 to 9. + for i := uint32(1); i <= confTarget; i++ { + // Increase the fee rate. + increased, err := f.Increment() + rt.NoError(err) + rt.True(increased) + + // Assert the internal state. + rt.Equal(i, f.position) + + delta := chainfee.SatPerKWeight(i * 100) + rt.Equal(estimatedFeeRate+delta, f.currentFeeRate) + + // Check public method returns the expected fee rate. + rt.Equal(estimatedFeeRate+delta, f.FeeRate()) + } + + // Now the position is at 9th, increase it again should give us an + // error. + increased, err := f.Increment() + rt.ErrorIs(err, ErrMaxPosition) + rt.False(increased) +} + +// TestLinearFeeFunctionIncreaseFeeRate checks the internal state is updated +// correctly when the fee rate is increased using conf targets. +func TestLinearFeeFunctionIncreaseFeeRate(t *testing.T) { + t.Parallel() + + rt := require.New(t) + + // Create a mock fee estimator. + estimator := &chainfee.MockEstimator{} + + // Create testing params. These params are chosen so the delta value is + // 100. + maxFeeRate := chainfee.SatPerKWeight(1000) + estimatedFeeRate := chainfee.SatPerKWeight(100) + confTarget := uint32(9) + + // Mock the fee estimator to return the fee rate. + estimator.On("EstimateFeePerKW", confTarget).Return( + estimatedFeeRate, nil).Once() + estimator.On("RelayFeePerKW").Return(estimatedFeeRate).Once() + + f, err := NewLinearFeeFunction(maxFeeRate, confTarget, estimator) + rt.NoError(err) + + // If we are increasing the fee rate using the initial conf target, we + // should get a nil error and false. + increased, err := f.IncreaseFeeRate(confTarget) + rt.NoError(err) + rt.False(increased) + + // Test that we are allowed to use a larger conf target. + increased, err = f.IncreaseFeeRate(confTarget + 1) + rt.NoError(err) + rt.False(increased) + + // Test that when we use a conf target of 0, we get an error. + increased, err = f.IncreaseFeeRate(0) + rt.ErrorIs(err, ErrMaxPosition) + rt.False(increased) + + // We now increase the fee rate from conf target 8 to 1 and assert we + // get no error and true. + for i := uint32(1); i < confTarget; i++ { + // Increase the fee rate. + increased, err := f.IncreaseFeeRate(confTarget - i) + rt.NoError(err) + rt.True(increased) + + // Assert the internal state. + rt.Equal(i, f.position) + + delta := chainfee.SatPerKWeight(i * 100) + rt.Equal(estimatedFeeRate+delta, f.currentFeeRate) + + // Check public method returns the expected fee rate. + rt.Equal(estimatedFeeRate+delta, f.FeeRate()) + } +} diff --git a/sweep/interface.go b/sweep/interface.go index a9de8bc570..a6e5d21537 100644 --- a/sweep/interface.go +++ b/sweep/interface.go @@ -41,4 +41,14 @@ type Wallet interface { // used to ensure that invalid transactions (inputs spent) aren't // retried in the background. CancelRebroadcast(tx chainhash.Hash) + + // CheckMempoolAcceptance checks whether a transaction follows mempool + // policies and returns an error if it cannot be accepted into the + // mempool. + CheckMempoolAcceptance(tx *wire.MsgTx) error + + // GetTransactionDetails returns a detailed description of a tx given + // its transaction hash. + GetTransactionDetails(txHash *chainhash.Hash) ( + *lnwallet.TransactionDetail, error) } diff --git a/sweep/mock_test.go b/sweep/mock_test.go index fc7ff9c34d..6b23953c3a 100644 --- a/sweep/mock_test.go +++ b/sweep/mock_test.go @@ -5,8 +5,10 @@ import ( "testing" "time" + "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" @@ -44,6 +46,10 @@ func newMockBackend(t *testing.T, notifier *MockNotifier) *mockBackend { } } +func (b *mockBackend) CheckMempoolAcceptance(tx *wire.MsgTx) error { + return nil +} + func (b *mockBackend) publishTransaction(tx *wire.MsgTx) error { b.lock.Lock() defer b.lock.Unlock() @@ -169,6 +175,14 @@ func (b *mockBackend) FetchTx(chainhash.Hash) (*wire.MsgTx, error) { func (b *mockBackend) CancelRebroadcast(tx chainhash.Hash) { } +// GetTransactionDetails returns a detailed description of a tx given its +// transaction hash. +func (b *mockBackend) GetTransactionDetails(txHash *chainhash.Hash) ( + *lnwallet.TransactionDetail, error) { + + return nil, nil +} + // mockFeeEstimator implements a mock fee estimator. It closely resembles // lnwallet.StaticFeeEstimator with the addition that fees can be changed for // testing purposes in a thread safe manner. @@ -342,6 +356,14 @@ type MockWallet struct { // Compile-time constraint to ensure MockWallet implements Wallet. var _ Wallet = (*MockWallet)(nil) +// CheckMempoolAcceptance checks if the transaction can be accepted to the +// mempool. +func (m *MockWallet) CheckMempoolAcceptance(tx *wire.MsgTx) error { + args := m.Called(tx) + + return args.Error(0) +} + // PublishTransaction performs cursory validation (dust checks, etc) and // broadcasts the passed transaction to the Bitcoin network. func (m *MockWallet) PublishTransaction(tx *wire.MsgTx, label string) error { @@ -404,6 +426,20 @@ func (m *MockWallet) CancelRebroadcast(tx chainhash.Hash) { m.Called(tx) } +// GetTransactionDetails returns a detailed description of a tx given its +// transaction hash. +func (m *MockWallet) GetTransactionDetails(txHash *chainhash.Hash) ( + *lnwallet.TransactionDetail, error) { + + args := m.Called(txHash) + + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).(*lnwallet.TransactionDetail), args.Error(1) +} + // MockInputSet is a mock implementation of the InputSet interface. type MockInputSet struct { mock.Mock @@ -446,3 +482,65 @@ func (m *MockInputSet) NeedWalletInput() bool { return args.Bool(0) } + +// DeadlineHeight returns the deadline height for the set. +func (m *MockInputSet) DeadlineHeight() fn.Option[int32] { + args := m.Called() + + return args.Get(0).(fn.Option[int32]) +} + +// Budget givens the total amount that can be used as fees by this input set. +func (m *MockInputSet) Budget() btcutil.Amount { + args := m.Called() + + return args.Get(0).(btcutil.Amount) +} + +// MockBumper is a mock implementation of the interface Bumper. +type MockBumper struct { + mock.Mock +} + +// Compile-time constraint to ensure MockBumper implements Bumper. +var _ Bumper = (*MockBumper)(nil) + +// Broadcast broadcasts the transaction to the network. +func (m *MockBumper) Broadcast(req *BumpRequest) (<-chan *BumpResult, error) { + args := m.Called(req) + + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).(chan *BumpResult), args.Error(1) +} + +// MockFeeFunction is a mock implementation of the FeeFunction interface. +type MockFeeFunction struct { + mock.Mock +} + +// Compile-time constraint to ensure MockFeeFunction implements FeeFunction. +var _ FeeFunction = (*MockFeeFunction)(nil) + +// FeeRate returns the current fee rate calculated by the fee function. +func (m *MockFeeFunction) FeeRate() chainfee.SatPerKWeight { + args := m.Called() + + return args.Get(0).(chainfee.SatPerKWeight) +} + +// Increment adds one delta to the current fee rate. +func (m *MockFeeFunction) Increment() (bool, error) { + args := m.Called() + + return args.Bool(0), args.Error(1) +} + +// IncreaseFeeRate increases the fee rate by one step. +func (m *MockFeeFunction) IncreaseFeeRate(confTarget uint32) (bool, error) { + args := m.Called(confTarget) + + return args.Bool(0), args.Error(1) +} diff --git a/sweep/sweeper.go b/sweep/sweeper.go index bd266aaa0b..505292912a 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -13,7 +13,6 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" - "github.com/lightningnetwork/lnd/labels" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" ) @@ -41,6 +40,12 @@ var ( // an input is included in a publish attempt before giving up and // returning an error to the caller. DefaultMaxSweepAttempts = 10 + + // DefaultDeadlineDelta defines a default deadline delta (1 week) to be + // used when sweeping inputs with no deadline pressure. + // + // TODO(yy): make this configurable. + DefaultDeadlineDelta = int32(1008) ) // Params contains the parameters that control the sweeping process. @@ -52,11 +57,22 @@ type Params struct { // Force indicates whether the input should be swept regardless of // whether it is economical to do so. + // + // TODO(yy): Remove this param once deadline based sweeping is in place. Force bool // ExclusiveGroup is an identifier that, if set, prevents other inputs // with the same identifier from being batched together. ExclusiveGroup *uint64 + + // DeadlineHeight specifies an absolute block height that this input + // should be confirmed by. This value is used by the fee bumper to + // decide its urgency and adjust its feerate used. + DeadlineHeight fn.Option[int32] + + // Budget specifies the maximum amount of satoshis that can be spent on + // fees for this sweep. + Budget btcutil.Amount } // ParamsUpdate contains a new set of parameters to update a pending sweep with. @@ -196,6 +212,11 @@ type pendingInput struct { rbf fn.Option[RBFInfo] } +// String returns a human readable interpretation of the pending input. +func (p *pendingInput) String() string { + return fmt.Sprintf("%v (%v)", p.Input.OutPoint(), p.Input.WitnessType()) +} + // parameters returns the sweep parameters for this input. // // NOTE: Part of the txInput interface. @@ -301,6 +322,10 @@ type UtxoSweeper struct { // currentHeight is the best known height of the main chain. This is // updated whenever a new block epoch is received. currentHeight int32 + + // bumpResultChan is a channel that receives broadcast results from the + // TxPublisher. + bumpResultChan chan *BumpResult } // UtxoSweeperConfig contains dependencies of UtxoSweeper. @@ -348,6 +373,10 @@ type UtxoSweeperConfig struct { // Aggregator is used to group inputs into clusters based on its // implemention-specific strategy. Aggregator UtxoAggregator + + // Publisher is used to publish the sweep tx crafted here and monitors + // it for potential fee bumps. + Publisher Bumper } // Result is the struct that is pushed through the result channel. Callers can @@ -381,6 +410,7 @@ func New(cfg *UtxoSweeperConfig) *UtxoSweeper { pendingSweepsReqs: make(chan *pendingSweepsReq), quit: make(chan struct{}), pendingInputs: make(pendingInputs), + bumpResultChan: make(chan *BumpResult, 100), } } @@ -654,11 +684,16 @@ func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) { err: err, } - // A new block comes in, update the bestHeight. - // - // TODO(yy): this is where we check our published transactions - // and perform RBF if needed. We'd also like to consult our fee - // bumper to get an updated fee rate. + case result := <-s.bumpResultChan: + // Handle the bump event. + err := s.handleBumpEvent(result) + if err != nil { + log.Errorf("Failed to handle bump event: %v", + err) + } + + // A new block comes in, update the bestHeight, perform a check + // over all pending inputs and publish sweeping txns if needed. case epoch, ok := <-blockEpochs: if !ok { // We should stop the sweeper before stopping @@ -763,8 +798,8 @@ func (s *UtxoSweeper) signalResult(pi *pendingInput, result Result) { } } -// sweep takes a set of preselected inputs, creates a sweep tx and publishes the -// tx. The output address is only marked as used if the publish succeeds. +// sweep takes a set of preselected inputs, creates a sweep tx and publishes +// the tx. The output address is only marked as used if the publish succeeds. func (s *UtxoSweeper) sweep(set InputSet) error { // Generate an output script if there isn't an unused script available. if s.currentOutputScript == nil { @@ -775,83 +810,65 @@ func (s *UtxoSweeper) sweep(set InputSet) error { s.currentOutputScript = pkScript } - // Create sweep tx. - tx, fee, err := createSweepTx( - set.Inputs(), nil, s.currentOutputScript, - uint32(s.currentHeight), set.FeeRate(), - s.cfg.MaxFeeRate.FeePerKWeight(), s.cfg.Signer, - ) - if err != nil { - return fmt.Errorf("create sweep tx: %w", err) - } - - tr := &TxRecord{ - Txid: tx.TxHash(), - FeeRate: uint64(set.FeeRate()), - Fee: uint64(fee), + // Create a default deadline height, and replace it with set's + // DeadlineHeight if it's set. + deadlineHeight := s.currentHeight + DefaultDeadlineDelta + deadlineHeight = set.DeadlineHeight().UnwrapOr(deadlineHeight) + + // Create a fee bump request and ask the publisher to broadcast it. The + // publisher will then take over and start monitoring the tx for + // potential fee bump. + req := &BumpRequest{ + Inputs: set.Inputs(), + Budget: set.Budget(), + DeadlineHeight: deadlineHeight, + DeliveryAddress: s.currentOutputScript, + MaxFeeRate: s.cfg.MaxFeeRate.FeePerKWeight(), + // TODO(yy): pass the strategy here. } // Reschedule the inputs that we just tried to sweep. This is done in // case the following publish fails, we'd like to update the inputs' // publish attempts and rescue them in the next sweep. - err = s.markInputsPendingPublish(tr, tx.TxIn) - if err != nil { - return err - } + s.markInputsPendingPublish(set) - log.Debugf("Publishing sweep tx %v, num_inputs=%v, height=%v", - tx.TxHash(), len(tx.TxIn), s.currentHeight) - - // Publish the sweeping tx with customized label. - err = s.cfg.Wallet.PublishTransaction( - tx, labels.MakeLabel(labels.LabelTypeSweepTransaction, nil), - ) + // Broadcast will return a read-only chan that we will listen to for + // this publish result and future RBF attempt. + resp, err := s.cfg.Publisher.Broadcast(req) if err != nil { + outpoints := make([]wire.OutPoint, len(set.Inputs())) + for i, inp := range set.Inputs() { + outpoints[i] = *inp.OutPoint() + } + // TODO(yy): find out which input is causing the failure. - s.markInputsPublishFailed(tx.TxIn) + s.markInputsPublishFailed(outpoints) return err } - // Inputs have been successfully published so we update their states. - err = s.markInputsPublished(tr, tx.TxIn) - if err != nil { - return err - } - - // If there's no error, remove the output script. Otherwise keep it so - // that it can be reused for the next transaction and causes no address - // inflation. - s.currentOutputScript = nil + // Successfully sent the broadcast attempt, we now handle the result by + // subscribing to the result chan and listen for future updates about + // this tx. + s.wg.Add(1) + go s.monitorFeeBumpResult(resp) return nil } -// markInputsPendingPublish saves the sweeping tx to db and updates the pending -// inputs with the given tx inputs. It also increments the `publishAttempts`. -func (s *UtxoSweeper) markInputsPendingPublish(tr *TxRecord, - inputs []*wire.TxIn) error { - - // Add tx to db before publication, so that we will always know that a - // spend by this tx is ours. Otherwise if the publish doesn't return, - // but did publish, we'd lose track of this tx. Even republication on - // startup doesn't prevent this, because that call returns a double - // spend error then and would also not add the hash to the store. - err := s.cfg.Store.StoreTx(tr) - if err != nil { - return fmt.Errorf("store tx: %w", err) - } - +// markInputsPendingPublish updates the pending inputs with the given tx +// inputs. It also increments the `publishAttempts`. +func (s *UtxoSweeper) markInputsPendingPublish(set InputSet) { // Reschedule sweep. - for _, input := range inputs { - pi, ok := s.pendingInputs[input.PreviousOutPoint] + for _, input := range set.Inputs() { + pi, ok := s.pendingInputs[*input.OutPoint()] if !ok { // It could be that this input is an additional wallet // input that was attached. In that case there also // isn't a pending input to update. log.Debugf("Skipped marking input as pending "+ "published: %v not found in pending inputs", - input.PreviousOutPoint) + input.OutPoint()) continue } @@ -863,7 +880,7 @@ func (s *UtxoSweeper) markInputsPendingPublish(tr *TxRecord, if pi.terminated() { log.Errorf("Expect input %v to not have terminated "+ "state, instead it has %v", - input.PreviousOutPoint, pi.state) + input.OutPoint, pi.state) continue } @@ -871,19 +888,9 @@ func (s *UtxoSweeper) markInputsPendingPublish(tr *TxRecord, // Update the input's state. pi.state = StatePendingPublish - // Record the fees and fee rate of this tx to prepare possible - // RBF. - pi.rbf = fn.Some(RBFInfo{ - Txid: tr.Txid, - FeeRate: chainfee.SatPerKWeight(tr.FeeRate), - Fee: btcutil.Amount(tr.Fee), - }) - // Record another publish attempt. pi.publishAttempts++ } - - return nil } // markInputsPublished updates the sweeping tx in db and marks the list of @@ -932,17 +939,16 @@ func (s *UtxoSweeper) markInputsPublished(tr *TxRecord, } // markInputsPublishFailed marks the list of inputs as failed to be published. -func (s *UtxoSweeper) markInputsPublishFailed(inputs []*wire.TxIn) { +func (s *UtxoSweeper) markInputsPublishFailed(outpoints []wire.OutPoint) { // Reschedule sweep. - for _, input := range inputs { - pi, ok := s.pendingInputs[input.PreviousOutPoint] + for _, op := range outpoints { + pi, ok := s.pendingInputs[op] if !ok { // It could be that this input is an additional wallet // input that was attached. In that case there also // isn't a pending input to update. log.Debugf("Skipped marking input as publish failed: "+ - "%v not found in pending inputs", - input.PreviousOutPoint) + "%v not found in pending inputs", op) continue } @@ -950,13 +956,12 @@ func (s *UtxoSweeper) markInputsPublishFailed(inputs []*wire.TxIn) { // Valdiate that the input is in an expected state. if pi.state != StatePendingPublish { log.Errorf("Expect input %v to have %v, instead it "+ - "has %v", input.PreviousOutPoint, - StatePendingPublish, pi.state) + "has %v", op, StatePendingPublish, pi.state) continue } - log.Warnf("Failed to publish input %v", input.PreviousOutPoint) + log.Warnf("Failed to publish input %v", op) // Update the input's state. pi.state = StatePublishFailed @@ -1563,3 +1568,167 @@ func (s *UtxoSweeper) sweepPendingInputs(inputs pendingInputs) { } } } + +// monitorFeeBumpResult subscribes to the passed result chan to listen for +// future updates about the sweeping tx. +// +// NOTE: must run as a goroutine. +func (s *UtxoSweeper) monitorFeeBumpResult(resultChan <-chan *BumpResult) { + defer s.wg.Done() + + for { + select { + case r := <-resultChan: + // Validate the result is valid. + if err := r.Validate(); err != nil { + log.Errorf("Received invalid result: %v", err) + continue + } + + // Send the result back to the main event loop. + select { + case s.bumpResultChan <- r: + case <-s.quit: + log.Debug("Sweeper shutting down, skip " + + "sending bump result") + + return + } + + // The sweeping tx has been confirmed, we can exit the + // monitor now. + // + // TODO(yy): can instead remove the spend subscription + // in sweeper and rely solely on this event to mark + // inputs as Swept? + if r.Event == TxConfirmed || r.Event == TxFailed { + log.Debugf("Received %v for sweep tx %v, exit "+ + "fee bump monitor", r.Event, + r.Tx.TxHash()) + + return + } + + case <-s.quit: + log.Debugf("Sweeper shutting down, exit fee " + + "bump handler") + + return + } + } +} + +// handleBumpEventTxFailed handles the case where the tx has been failed to +// publish. +func (s *UtxoSweeper) handleBumpEventTxFailed(r *BumpResult) error { + tx, err := r.Tx, r.Err + + log.Errorf("Fee bump attempt failed for tx=%v: %v", tx.TxHash(), err) + + outpoints := make([]wire.OutPoint, 0, len(tx.TxIn)) + for _, inp := range tx.TxIn { + outpoints = append(outpoints, inp.PreviousOutPoint) + } + + // TODO(yy): should we also remove the failed tx from db? + s.markInputsPublishFailed(outpoints) + + return err +} + +// handleBumpEventTxReplaced handles the case where the sweeping tx has been +// replaced by a new one. +func (s *UtxoSweeper) handleBumpEventTxReplaced(r *BumpResult) error { + oldTx := r.ReplacedTx + newTx := r.Tx + + // Prepare a new record to replace the old one. + tr := &TxRecord{ + Txid: newTx.TxHash(), + FeeRate: uint64(r.FeeRate), + Fee: uint64(r.Fee), + } + + // Get the old record for logging purpose. + oldTxid := oldTx.TxHash() + record, err := s.cfg.Store.GetTx(oldTxid) + if err != nil { + log.Errorf("Fetch tx record for %v: %v", oldTxid, err) + return err + } + + log.Infof("RBFed tx=%v(fee=%v, feerate=%v) with new tx=%v(fee=%v, "+ + "feerate=%v)", record.Txid, record.Fee, record.FeeRate, + tr.Txid, tr.Fee, tr.FeeRate) + + // The old sweeping tx has been replaced by a new one, we will update + // the tx record in the sweeper db. + // + // TODO(yy): we may also need to update the inputs in this tx to a new + // state. Suppose a replacing tx only spends a subset of the inputs + // here, we'd end up with the rest being marked as `StatePublished` and + // won't be aggregated in the next sweep. Atm it's fine as we always + // RBF the same input set. + if err := s.cfg.Store.DeleteTx(oldTxid); err != nil { + log.Errorf("Delete tx record for %v: %v", oldTxid, err) + return err + } + + // Mark the inputs as published using the replacing tx. + return s.markInputsPublished(tr, r.Tx.TxIn) +} + +// handleBumpEventTxPublished handles the case where the sweeping tx has been +// successfully published. +func (s *UtxoSweeper) handleBumpEventTxPublished(r *BumpResult) error { + tx := r.Tx + tr := &TxRecord{ + Txid: tx.TxHash(), + FeeRate: uint64(r.FeeRate), + Fee: uint64(r.Fee), + } + + // Inputs have been successfully published so we update their + // states. + err := s.markInputsPublished(tr, tx.TxIn) + if err != nil { + return err + } + + log.Debugf("Published sweep tx %v, num_inputs=%v, height=%v", + tx.TxHash(), len(tx.TxIn), s.currentHeight) + + // If there's no error, remove the output script. Otherwise + // keep it so that it can be reused for the next transaction + // and causes no address inflation. + s.currentOutputScript = nil + + return nil +} + +// handleBumpEvent handles the result sent from the bumper based on its event +// type. +// +// NOTE: TxConfirmed event is not handled, since we already subscribe to the +// input's spending event, we don't need to do anything here. +func (s *UtxoSweeper) handleBumpEvent(r *BumpResult) error { + log.Debugf("Received bump event [%v] for tx %v", r.Event, r.Tx.TxHash()) + + switch r.Event { + // The tx has been published, we update the inputs' state and create a + // record to be stored in the sweeper db. + case TxPublished: + return s.handleBumpEventTxPublished(r) + + // The tx has failed, we update the inputs' state. + case TxFailed: + return s.handleBumpEventTxFailed(r) + + // The tx has been replaced, we will remove the old tx and replace it + // with the new one. + case TxReplaced: + return s.handleBumpEventTxReplaced(r) + } + + return nil +} diff --git a/sweep/sweeper_test.go b/sweep/sweeper_test.go index 519bbdbb2a..ee143e3d31 100644 --- a/sweep/sweeper_test.go +++ b/sweep/sweeper_test.go @@ -2,8 +2,10 @@ package sweep import ( "errors" + "fmt" "os" "runtime/pprof" + "sync/atomic" "testing" "time" @@ -19,6 +21,7 @@ import ( "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" lnmock "github.com/lightningnetwork/lnd/lntest/mock" + "github.com/lightningnetwork/lnd/lntest/wait" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/stretchr/testify/mock" @@ -33,6 +36,8 @@ var ( testMaxInputsPerTx = uint32(3) defaultFeePref = Params{Fee: FeeEstimateInfo{ConfTarget: 1}} + + errDummy = errors.New("dummy error") ) type sweeperTestContext struct { @@ -43,6 +48,7 @@ type sweeperTestContext struct { estimator *mockFeeEstimator backend *mockBackend store SweeperStore + publisher *MockBumper publishChan chan wire.MsgTx currentHeight int32 @@ -50,7 +56,7 @@ type sweeperTestContext struct { var ( spendableInputs []*input.BaseInput - testInputCount int + testInputCount atomic.Uint64 testPubKey, _ = btcec.ParsePubKey([]byte{ 0x04, 0x11, 0xdb, 0x93, 0xe1, 0xdc, 0xdb, 0x8a, @@ -67,7 +73,7 @@ var ( func createTestInput(value int64, witnessType input.WitnessType) input.BaseInput { hash := chainhash.Hash{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - byte(testInputCount + 1)} + byte(testInputCount.Add(1))} input := input.MakeBaseInput( &wire.OutPoint{ @@ -86,8 +92,6 @@ func createTestInput(value int64, witnessType input.WitnessType) input.BaseInput nil, ) - testInputCount++ - return input } @@ -127,6 +131,12 @@ func createSweeperTestContext(t *testing.T) *sweeperTestContext { testMaxInputsPerTx, ) + // Create a mock fee bumper. + mockBumper := &MockBumper{} + t.Cleanup(func() { + mockBumper.AssertExpectations(t) + }) + ctx := &sweeperTestContext{ notifier: notifier, publishChan: backend.publishChan, @@ -135,6 +145,7 @@ func createSweeperTestContext(t *testing.T) *sweeperTestContext { backend: backend, store: store, currentHeight: mockChainHeight, + publisher: mockBumper, } ctx.sweeper = New(&UtxoSweeperConfig{ @@ -153,6 +164,7 @@ func createSweeperTestContext(t *testing.T) *sweeperTestContext { MaxSweepAttempts: testMaxSweepAttempts, MaxFeeRate: DefaultMaxFeeRate, Aggregator: aggregator, + Publisher: mockBumper, }) ctx.sweeper.Start() @@ -338,27 +350,80 @@ func assertTxFeeRate(t *testing.T, tx *wire.MsgTx, } } +// assertNumSweeps asserts that the expected number of sweeps has been found in +// the sweeper's store. +func assertNumSweeps(t *testing.T, sweeper *UtxoSweeper, num int) { + err := wait.NoError(func() error { + sweeps, err := sweeper.ListSweeps() + if err != nil { + return err + } + + if len(sweeps) != num { + return fmt.Errorf("want %d sweeps, got %d", + num, len(sweeps)) + } + + return nil + }, 5*time.Second) + require.NoError(t, err, "timeout checking num of sweeps") +} + // TestSuccess tests the sweeper happy flow. func TestSuccess(t *testing.T) { ctx := createSweeperTestContext(t) + inp := spendableInputs[0] + // Sweeping an input without a fee preference should result in an error. - _, err := ctx.sweeper.SweepInput(spendableInputs[0], Params{ + _, err := ctx.sweeper.SweepInput(inp, Params{ Fee: &FeeEstimateInfo{}, }) - if err != ErrNoFeePreference { - t.Fatalf("expected ErrNoFeePreference, got %v", err) - } + require.ErrorIs(t, err, ErrNoFeePreference) + + // Mock the Broadcast method to succeed. + bumpResultChan := make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan, nil).Run(func(args mock.Arguments) { + // Create a fake sweep tx. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{{ + PreviousOutPoint: *inp.OutPoint(), + }}, + } - resultChan, err := ctx.sweeper.SweepInput( - spendableInputs[0], defaultFeePref, - ) - if err != nil { - t.Fatal(err) - } + // Send the first event. + bumpResultChan <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }) + + resultChan, err := ctx.sweeper.SweepInput(inp, defaultFeePref) + require.NoError(t, err) sweepTx := ctx.receiveTx() + // Wait until the sweep tx has been saved to db. + assertNumSweeps(t, ctx.sweeper, 1) + + // Mock a confirmed event. + bumpResultChan <- &BumpResult{ + Event: TxConfirmed, + Tx: &sweepTx, + FeeRate: 10, + Fee: 100, + } + + // Mine a block to confirm the sweep tx. ctx.backend.mine() select { @@ -402,15 +467,52 @@ func TestDust(t *testing.T) { // Sweep another input that brings the tx output above the dust limit. largeInput := createTestInput(100000, input.CommitmentTimeLock) + // Mock the Broadcast method to succeed. + bumpResultChan := make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan, nil).Run(func(args mock.Arguments) { + // Create a fake sweep tx. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: *largeInput.OutPoint()}, + {PreviousOutPoint: *dustInput.OutPoint()}, + }, + } + + // Send the first event. + bumpResultChan <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }) + _, err = ctx.sweeper.SweepInput(&largeInput, defaultFeePref) require.NoError(t, err) // The second input brings the sweep output above the dust limit. We // expect a sweep tx now. - sweepTx := ctx.receiveTx() require.Len(t, sweepTx.TxIn, 2, "unexpected num of tx inputs") + // Wait until the sweep tx has been saved to db. + assertNumSweeps(t, ctx.sweeper, 1) + + // Mock a confirmed event. + bumpResultChan <- &BumpResult{ + Event: TxConfirmed, + Tx: &sweepTx, + FeeRate: 10, + Fee: 100, + } + ctx.backend.mine() ctx.finish(1) @@ -433,29 +535,53 @@ func TestWalletUtxo(t *testing.T) { // sats. The tx yield becomes then 294-180 = 114 sats. dustInput := createTestInput(294, input.WitnessKeyHash) + // Mock the Broadcast method to succeed. + bumpResultChan := make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan, nil).Run(func(args mock.Arguments) { + // Create a fake sweep tx. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: *dustInput.OutPoint()}, + }, + } + + // Send the first event. + bumpResultChan <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }) + _, err := ctx.sweeper.SweepInput( &dustInput, Params{Fee: FeeEstimateInfo{FeeRate: chainfee.FeePerKwFloor}}, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) sweepTx := ctx.receiveTx() - if len(sweepTx.TxIn) != 2 { - t.Fatalf("Expected tx to sweep 2 inputs, but contains %v "+ - "inputs instead", len(sweepTx.TxIn)) - } - // Calculate expected output value based on wallet utxo of 1_000_000 - // sats. - expectedOutputValue := int64(294 + 1_000_000 - 180) - if sweepTx.TxOut[0].Value != expectedOutputValue { - t.Fatalf("Expected output value of %v, but got %v", - expectedOutputValue, sweepTx.TxOut[0].Value) - } + // Wait until the sweep tx has been saved to db. + assertNumSweeps(t, ctx.sweeper, 1) ctx.backend.mine() + + // Mock a confirmed event. + bumpResultChan <- &BumpResult{ + Event: TxConfirmed, + Tx: &sweepTx, + FeeRate: 10, + Fee: 100, + } + ctx.finish(1) } @@ -470,28 +596,50 @@ func TestNegativeInput(t *testing.T) { largeInputResult, err := ctx.sweeper.SweepInput( &largeInput, defaultFeePref, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // Sweep an additional input with a negative net yield. The weight of // the HtlcAcceptedRemoteSuccess input type adds more in fees than its // value at the current fee level. negInput := createTestInput(2900, input.HtlcOfferedRemoteTimeout) negInputResult, err := ctx.sweeper.SweepInput(&negInput, defaultFeePref) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // Sweep a third input that has a smaller output than the previous one, // but yields positively because of its lower weight. positiveInput := createTestInput(2800, input.CommitmentNoDelay) + + // Mock the Broadcast method to succeed. + bumpResultChan := make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan, nil).Run(func(args mock.Arguments) { + // Create a fake sweep tx. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: *largeInput.OutPoint()}, + {PreviousOutPoint: *positiveInput.OutPoint()}, + }, + } + + // Send the first event. + bumpResultChan <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }).Once() + positiveInputResult, err := ctx.sweeper.SweepInput( &positiveInput, defaultFeePref, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // We expect that a sweep tx is published now, but it should only // contain the large input. The negative input should stay out of sweeps @@ -499,8 +647,19 @@ func TestNegativeInput(t *testing.T) { sweepTx1 := ctx.receiveTx() assertTxSweepsInputs(t, &sweepTx1, &largeInput, &positiveInput) + // Wait until the sweep tx has been saved to db. + assertNumSweeps(t, ctx.sweeper, 1) + ctx.backend.mine() + // Mock a confirmed event. + bumpResultChan <- &BumpResult{ + Event: TxConfirmed, + Tx: &sweepTx1, + FeeRate: 10, + Fee: 100, + } + ctx.expectResult(largeInputResult, nil) ctx.expectResult(positiveInputResult, nil) @@ -509,18 +668,56 @@ func TestNegativeInput(t *testing.T) { // Create another large input. secondLargeInput := createTestInput(100000, input.CommitmentNoDelay) + + // Mock the Broadcast method to succeed. + bumpResultChan = make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan, nil).Run(func(args mock.Arguments) { + // Create a fake sweep tx. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: *negInput.OutPoint()}, + {PreviousOutPoint: *secondLargeInput. + OutPoint()}, + }, + } + + // Send the first event. + bumpResultChan <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }).Once() + secondLargeInputResult, err := ctx.sweeper.SweepInput( &secondLargeInput, defaultFeePref, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) sweepTx2 := ctx.receiveTx() assertTxSweepsInputs(t, &sweepTx2, &secondLargeInput, &negInput) + // Wait until the sweep tx has been saved to db. + assertNumSweeps(t, ctx.sweeper, 2) + ctx.backend.mine() + // Mock a confirmed event. + bumpResultChan <- &BumpResult{ + Event: TxConfirmed, + Tx: &sweepTx2, + FeeRate: 10, + Fee: 100, + } + ctx.expectResult(secondLargeInputResult, nil) ctx.expectResult(negInputResult, nil) @@ -531,30 +728,96 @@ func TestNegativeInput(t *testing.T) { func TestChunks(t *testing.T) { ctx := createSweeperTestContext(t) + // Mock the Broadcast method to succeed on the first chunk. + bumpResultChan1 := make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan1, nil).Run(func(args mock.Arguments) { + // Create a fake sweep tx. + //nolint:lll + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: *spendableInputs[0].OutPoint()}, + {PreviousOutPoint: *spendableInputs[1].OutPoint()}, + {PreviousOutPoint: *spendableInputs[2].OutPoint()}, + }, + } + + // Send the first event. + bumpResultChan1 <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }).Once() + + // Mock the Broadcast method to succeed on the second chunk. + bumpResultChan2 := make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan2, nil).Run(func(args mock.Arguments) { + // Create a fake sweep tx. + //nolint:lll + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: *spendableInputs[3].OutPoint()}, + {PreviousOutPoint: *spendableInputs[4].OutPoint()}, + }, + } + + // Send the first event. + bumpResultChan2 <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }).Once() + // Sweep five inputs. for _, input := range spendableInputs[:5] { _, err := ctx.sweeper.SweepInput(input, defaultFeePref) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) } // We expect two txes to be published because of the max input count of // three. sweepTx1 := ctx.receiveTx() - if len(sweepTx1.TxIn) != 3 { - t.Fatalf("Expected first tx to sweep 3 inputs, but contains %v "+ - "inputs instead", len(sweepTx1.TxIn)) - } + require.Len(t, sweepTx1.TxIn, 3) sweepTx2 := ctx.receiveTx() - if len(sweepTx2.TxIn) != 2 { - t.Fatalf("Expected first tx to sweep 2 inputs, but contains %v "+ - "inputs instead", len(sweepTx1.TxIn)) - } + require.Len(t, sweepTx2.TxIn, 2) + + // Wait until the sweep tx has been saved to db. + assertNumSweeps(t, ctx.sweeper, 2) ctx.backend.mine() + // Mock a confirmed event. + bumpResultChan1 <- &BumpResult{ + Event: TxConfirmed, + Tx: &sweepTx1, + FeeRate: 10, + Fee: 100, + } + bumpResultChan2 <- &BumpResult{ + Event: TxConfirmed, + Tx: &sweepTx2, + FeeRate: 10, + Fee: 100, + } + ctx.finish(1) } @@ -572,39 +835,60 @@ func TestRemoteSpend(t *testing.T) { func testRemoteSpend(t *testing.T, postSweep bool) { ctx := createSweeperTestContext(t) + // Create a fake sweep tx that spends the second input as the first + // will be spent by the remote. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: *spendableInputs[1].OutPoint()}, + }, + } + + // Mock the Broadcast method to succeed. + bumpResultChan := make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan, nil).Run(func(args mock.Arguments) { + // Send the first event. + bumpResultChan <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }).Once() + resultChan1, err := ctx.sweeper.SweepInput( spendableInputs[0], defaultFeePref, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) resultChan2, err := ctx.sweeper.SweepInput( spendableInputs[1], defaultFeePref, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // Spend the input with an unknown tx. remoteTx := &wire.MsgTx{ TxIn: []*wire.TxIn{ - { - PreviousOutPoint: *(spendableInputs[0].OutPoint()), - }, + {PreviousOutPoint: *(spendableInputs[0].OutPoint())}, }, } err = ctx.backend.publishTransaction(remoteTx) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) if postSweep { - // Tx publication by sweeper returns ErrDoubleSpend. Sweeper // will retry the inputs without reporting a result. It could be // spent by the remote party. ctx.receiveTx() + + // Wait until the sweep tx has been saved to db. + assertNumSweeps(t, ctx.sweeper, 1) } ctx.backend.mine() @@ -624,13 +908,21 @@ func testRemoteSpend(t *testing.T, postSweep bool) { if !postSweep { // Assert that the sweeper sweeps the remaining input. sweepTx := ctx.receiveTx() + require.Len(t, sweepTx.TxIn, 1) - if len(sweepTx.TxIn) != 1 { - t.Fatal("expected sweep to only sweep the one remaining output") - } + // Wait until the sweep tx has been saved to db. + assertNumSweeps(t, ctx.sweeper, 1) ctx.backend.mine() + // Mock a confirmed event. + bumpResultChan <- &BumpResult{ + Event: TxConfirmed, + Tx: &sweepTx, + FeeRate: 10, + Fee: 100, + } + ctx.expectResult(resultChan2, nil) ctx.finish(1) @@ -640,8 +932,10 @@ func testRemoteSpend(t *testing.T, postSweep bool) { ctx.finish(2) select { - case <-resultChan2: - t.Fatalf("no result expected for error input") + case r := <-resultChan2: + require.NoError(t, r.Err) + require.Equal(t, r.Tx.TxHash(), tx.TxHash()) + default: } } @@ -653,26 +947,58 @@ func TestIdempotency(t *testing.T) { ctx := createSweeperTestContext(t) input := spendableInputs[0] + + // Mock the Broadcast method to succeed. + bumpResultChan := make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan, nil).Run(func(args mock.Arguments) { + // Create a fake sweep tx. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: *input.OutPoint()}, + }, + } + + // Send the first event. + bumpResultChan <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }).Once() + resultChan1, err := ctx.sweeper.SweepInput(input, defaultFeePref) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) resultChan2, err := ctx.sweeper.SweepInput(input, defaultFeePref) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) + + sweepTx := ctx.receiveTx() - ctx.receiveTx() + // Wait until the sweep tx has been saved to db. + assertNumSweeps(t, ctx.sweeper, 1) resultChan3, err := ctx.sweeper.SweepInput(input, defaultFeePref) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // Spend the input of the sweep tx. ctx.backend.mine() + // Mock a confirmed event. + bumpResultChan <- &BumpResult{ + Event: TxConfirmed, + Tx: &sweepTx, + FeeRate: 10, + Fee: 100, + } + ctx.expectResult(resultChan1, nil) ctx.expectResult(resultChan2, nil) ctx.expectResult(resultChan3, nil) @@ -683,9 +1009,7 @@ func TestIdempotency(t *testing.T) { // Because the sweeper kept track of all of its sweep txes, it will // recognize the spend as its own. resultChan4, err := ctx.sweeper.SweepInput(input, defaultFeePref) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) ctx.expectResult(resultChan4, nil) // Timer is still running, but spend notification was delivered before @@ -708,25 +1032,78 @@ func TestRestart(t *testing.T) { // Sweep input and expect sweep tx. input1 := spendableInputs[0] + + // Mock the Broadcast method to succeed. + bumpResultChan1 := make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan1, nil).Run(func(args mock.Arguments) { + // Create a fake sweep tx. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: *input1.OutPoint()}, + }, + } + + // Send the first event. + bumpResultChan1 <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }).Once() + _, err := ctx.sweeper.SweepInput(input1, defaultFeePref) require.NoError(t, err) - ctx.receiveTx() + sweepTx1 := ctx.receiveTx() + + // Wait until the sweep tx has been saved to db. + assertNumSweeps(t, ctx.sweeper, 1) // Restart sweeper. ctx.restartSweeper() // Simulate other subsystem (e.g. contract resolver) re-offering inputs. spendChan1, err := ctx.sweeper.SweepInput(input1, defaultFeePref) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) input2 := spendableInputs[1] + + // Mock the Broadcast method to succeed. + bumpResultChan2 := make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan2, nil).Run(func(args mock.Arguments) { + // Create a fake sweep tx. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: *input2.OutPoint()}, + }, + } + + // Send the first event. + bumpResultChan2 <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }).Once() + spendChan2, err := ctx.sweeper.SweepInput(input2, defaultFeePref) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // Spend inputs of sweep txes and verify that spend channels signal // spends. @@ -745,10 +1122,27 @@ func TestRestart(t *testing.T) { // Timer tick should trigger republishing a sweep for the remaining // input. - ctx.receiveTx() + sweepTx2 := ctx.receiveTx() + + // Wait until the sweep tx has been saved to db. + assertNumSweeps(t, ctx.sweeper, 2) ctx.backend.mine() + // Mock a confirmed event. + bumpResultChan1 <- &BumpResult{ + Event: TxConfirmed, + Tx: &sweepTx1, + FeeRate: 10, + Fee: 100, + } + bumpResultChan2 <- &BumpResult{ + Event: TxConfirmed, + Tx: &sweepTx2, + FeeRate: 10, + Fee: 100, + } + select { case result := <-spendChan2: if result.Err != nil { @@ -769,51 +1163,104 @@ func TestRestart(t *testing.T) { func TestRestartRemoteSpend(t *testing.T) { ctx := createSweeperTestContext(t) - // Sweep input. + // Get testing inputs. input1 := spendableInputs[0] + input2 := spendableInputs[1] + + // Create a fake sweep tx that spends the second input as the first + // will be spent by the remote. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: *input2.OutPoint()}, + }, + } + + // Mock the Broadcast method to succeed. + bumpResultChan := make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan, nil).Run(func(args mock.Arguments) { + // Send the first event. + bumpResultChan <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }).Once() + _, err := ctx.sweeper.SweepInput(input1, defaultFeePref) require.NoError(t, err) // Sweep another input. - input2 := spendableInputs[1] _, err = ctx.sweeper.SweepInput(input2, defaultFeePref) require.NoError(t, err) sweepTx := ctx.receiveTx() + // Wait until the sweep tx has been saved to db. + assertNumSweeps(t, ctx.sweeper, 1) + // Restart sweeper. ctx.restartSweeper() - // Replace the sweep tx with a remote tx spending input 1. + // Replace the sweep tx with a remote tx spending input 2. ctx.backend.deleteUnconfirmed(sweepTx.TxHash()) remoteTx := &wire.MsgTx{ TxIn: []*wire.TxIn{ - { - PreviousOutPoint: *(input2.OutPoint()), - }, + {PreviousOutPoint: *input1.OutPoint()}, }, } - if err := ctx.backend.publishTransaction(remoteTx); err != nil { - t.Fatal(err) - } + err = ctx.backend.publishTransaction(remoteTx) + require.NoError(t, err) // Mine remote spending tx. ctx.backend.mine() + // Mock the Broadcast method to succeed. + bumpResultChan = make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan, nil).Run(func(args mock.Arguments) { + // Send the first event. + bumpResultChan <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }).Once() + // Simulate other subsystem (e.g. contract resolver) re-offering input - // 0. - spendChan, err := ctx.sweeper.SweepInput(input1, defaultFeePref) - if err != nil { - t.Fatal(err) - } + // 2. + spendChan, err := ctx.sweeper.SweepInput(input2, defaultFeePref) + require.NoError(t, err) // Expect sweeper to construct a new tx, because input 1 was spend // remotely. - ctx.receiveTx() + sweepTx = ctx.receiveTx() ctx.backend.mine() + // Mock a confirmed event. + bumpResultChan <- &BumpResult{ + Event: TxConfirmed, + Tx: &sweepTx, + FeeRate: 10, + Fee: 100, + } + ctx.expectResult(spendChan, nil) ctx.finish(1) @@ -826,11 +1273,40 @@ func TestRestartConfirmed(t *testing.T) { // Sweep input. input := spendableInputs[0] - if _, err := ctx.sweeper.SweepInput(input, defaultFeePref); err != nil { - t.Fatal(err) - } - ctx.receiveTx() + // Mock the Broadcast method to succeed. + bumpResultChan := make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan, nil).Run(func(args mock.Arguments) { + // Create a fake sweep tx. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: *input.OutPoint()}, + }, + } + + // Send the first event. + bumpResultChan <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }).Once() + + _, err := ctx.sweeper.SweepInput(input, defaultFeePref) + require.NoError(t, err) + + sweepTx := ctx.receiveTx() + + // Wait until the sweep tx has been saved to db. + assertNumSweeps(t, ctx.sweeper, 1) // Restart sweeper. ctx.restartSweeper() @@ -838,9 +1314,18 @@ func TestRestartConfirmed(t *testing.T) { // Mine the sweep tx. ctx.backend.mine() + // Mock a confirmed event. + bumpResultChan <- &BumpResult{ + Event: TxConfirmed, + Tx: &sweepTx, + FeeRate: 10, + Fee: 100, + } + // Simulate other subsystem (e.g. contract resolver) re-offering input // 0. spendChan, err := ctx.sweeper.SweepInput(input, defaultFeePref) + require.NoError(t, err) if err != nil { t.Fatal(err) } @@ -855,29 +1340,96 @@ func TestRestartConfirmed(t *testing.T) { func TestRetry(t *testing.T) { ctx := createSweeperTestContext(t) - resultChan0, err := ctx.sweeper.SweepInput( - spendableInputs[0], defaultFeePref, - ) - if err != nil { - t.Fatal(err) - } + inp0 := spendableInputs[0] + inp1 := spendableInputs[1] + + // Mock the Broadcast method to succeed. + bumpResultChan1 := make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan1, nil).Run(func(args mock.Arguments) { + // Create a fake sweep tx. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: *inp0.OutPoint()}, + }, + } - // We expect a sweep to be published. - ctx.receiveTx() + // Send the first event. + bumpResultChan1 <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }).Once() + + resultChan0, err := ctx.sweeper.SweepInput(inp0, defaultFeePref) + require.NoError(t, err) + + // We expect a sweep to be published. + sweepTx1 := ctx.receiveTx() + + // Wait until the sweep tx has been saved to db. + assertNumSweeps(t, ctx.sweeper, 1) + + // Mock the Broadcast method to succeed on the second sweep. + bumpResultChan2 := make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan2, nil).Run(func(args mock.Arguments) { + // Create a fake sweep tx. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: *inp1.OutPoint()}, + }, + } + + // Send the first event. + bumpResultChan2 <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }).Once() // Offer a fresh input. - resultChan1, err := ctx.sweeper.SweepInput( - spendableInputs[1], defaultFeePref, - ) - if err != nil { - t.Fatal(err) - } + resultChan1, err := ctx.sweeper.SweepInput(inp1, defaultFeePref) + require.NoError(t, err) // A single tx is expected to be published. - ctx.receiveTx() + sweepTx2 := ctx.receiveTx() + + // Wait until the sweep tx has been saved to db. + assertNumSweeps(t, ctx.sweeper, 2) ctx.backend.mine() + // Mock a confirmed event. + bumpResultChan1 <- &BumpResult{ + Event: TxConfirmed, + Tx: &sweepTx1, + FeeRate: 10, + Fee: 100, + } + bumpResultChan2 <- &BumpResult{ + Event: TxConfirmed, + Tx: &sweepTx2, + FeeRate: 10, + Fee: 100, + } + ctx.expectResult(resultChan0, nil) ctx.expectResult(resultChan1, nil) @@ -903,44 +1455,105 @@ func TestDifferentFeePreferences(t *testing.T) { ctx.estimator.blocksToFee[highFeePref.ConfTarget] = highFeeRate input1 := spendableInputs[0] + input2 := spendableInputs[1] + input3 := spendableInputs[2] + + // Mock the Broadcast method to succeed on the first sweep. + bumpResultChan1 := make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan1, nil).Run(func(args mock.Arguments) { + // Create a fake sweep tx. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: *input1.OutPoint()}, + {PreviousOutPoint: *input2.OutPoint()}, + }, + } + + // Send the first event. + bumpResultChan1 <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }).Once() + + // Mock the Broadcast method to succeed on the second sweep. + bumpResultChan2 := make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan2, nil).Run(func(args mock.Arguments) { + // Create a fake sweep tx. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: *input3.OutPoint()}, + }, + } + + // Send the first event. + bumpResultChan2 <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }).Once() + resultChan1, err := ctx.sweeper.SweepInput( input1, Params{Fee: highFeePref}, ) - if err != nil { - t.Fatal(err) - } - input2 := spendableInputs[1] + require.NoError(t, err) + resultChan2, err := ctx.sweeper.SweepInput( input2, Params{Fee: highFeePref}, ) - if err != nil { - t.Fatal(err) - } - input3 := spendableInputs[2] + require.NoError(t, err) + resultChan3, err := ctx.sweeper.SweepInput( input3, Params{Fee: lowFeePref}, ) - if err != nil { - t.Fatal(err) - } - - // Generate the same type of sweep script that was used for weight - // estimation. - changePk, err := ctx.sweeper.cfg.GenSweepScript() require.NoError(t, err) - // The first transaction broadcast should be the one spending the higher - // fee rate inputs. + // The first transaction broadcast should be the one spending the + // higher fee rate inputs. sweepTx1 := ctx.receiveTx() - assertTxFeeRate(t, &sweepTx1, highFeeRate, changePk, input1, input2) // The second should be the one spending the lower fee rate inputs. sweepTx2 := ctx.receiveTx() - assertTxFeeRate(t, &sweepTx2, lowFeeRate, changePk, input3) + + // Wait until the sweep tx has been saved to db. + assertNumSweeps(t, ctx.sweeper, 2) // With the transactions broadcast, we'll mine a block to so that the // result is delivered to each respective client. ctx.backend.mine() + + // Mock a confirmed event. + bumpResultChan1 <- &BumpResult{ + Event: TxConfirmed, + Tx: &sweepTx1, + FeeRate: 10, + Fee: 100, + } + bumpResultChan2 <- &BumpResult{ + Event: TxConfirmed, + Tx: &sweepTx2, + FeeRate: 10, + Fee: 100, + } + resultChans := []chan Result{resultChan1, resultChan2, resultChan3} for _, resultChan := range resultChans { ctx.expectResult(resultChan, nil) @@ -974,37 +1587,105 @@ func TestPendingInputs(t *testing.T) { ctx.estimator.blocksToFee[highFeePref.ConfTarget] = highFeeRate input1 := spendableInputs[0] + input2 := spendableInputs[1] + input3 := spendableInputs[2] + + // Mock the Broadcast method to succeed on the first sweep. + bumpResultChan1 := make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan1, nil).Run(func(args mock.Arguments) { + // Create a fake sweep tx. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: *input1.OutPoint()}, + {PreviousOutPoint: *input2.OutPoint()}, + }, + } + + // Send the first event. + bumpResultChan1 <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }).Once() + + // Mock the Broadcast method to succeed on the second sweep. + bumpResultChan2 := make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan2, nil).Run(func(args mock.Arguments) { + // Create a fake sweep tx. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: *input3.OutPoint()}, + }, + } + + // Send the first event. + bumpResultChan2 <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }).Once() + resultChan1, err := ctx.sweeper.SweepInput( input1, Params{Fee: highFeePref}, ) - if err != nil { - t.Fatal(err) - } - input2 := spendableInputs[1] + require.NoError(t, err) + _, err = ctx.sweeper.SweepInput( input2, Params{Fee: highFeePref}, ) - if err != nil { - t.Fatal(err) - } - input3 := spendableInputs[2] + require.NoError(t, err) + resultChan3, err := ctx.sweeper.SweepInput( input3, Params{Fee: lowFeePref}, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // We should expect to see all inputs pending. ctx.assertPendingInputs(input1, input2, input3) // We should expect to see both sweep transactions broadcast - one for // the higher feerate, the other for the lower. - ctx.receiveTx() - ctx.receiveTx() + sweepTx1 := ctx.receiveTx() + sweepTx2 := ctx.receiveTx() + + // Wait until the sweep tx has been saved to db. + assertNumSweeps(t, ctx.sweeper, 2) // Mine these txns, and we should expect to see the results delivered. ctx.backend.mine() + + // Mock a confirmed event. + bumpResultChan1 <- &BumpResult{ + Event: TxConfirmed, + Tx: &sweepTx1, + FeeRate: 10, + Fee: 100, + } + bumpResultChan2 <- &BumpResult{ + Event: TxConfirmed, + Tx: &sweepTx2, + FeeRate: 10, + Fee: 100, + } + ctx.expectResult(resultChan1, nil) ctx.expectResult(resultChan3, nil) ctx.assertPendingInputs() @@ -1012,79 +1693,91 @@ func TestPendingInputs(t *testing.T) { ctx.finish(1) } -// TestBumpFeeRBF ensures that the UtxoSweeper can properly handle a fee bump -// request for an input it is currently attempting to sweep. When sweeping the -// input with the higher fee rate, a replacement transaction is created. -func TestBumpFeeRBF(t *testing.T) { +// TestExclusiveGroup tests the sweeper exclusive group functionality. +func TestExclusiveGroup(t *testing.T) { ctx := createSweeperTestContext(t) - lowFeePref := FeeEstimateInfo{ConfTarget: 144} - lowFeeRate := chainfee.FeePerKwFloor - ctx.estimator.blocksToFee[lowFeePref.ConfTarget] = lowFeeRate - - // We'll first try to bump the fee of an output currently unknown to the - // UtxoSweeper. Doing so should result in a lnwallet.ErrNotMine error. - _, err := ctx.sweeper.UpdateParams( - wire.OutPoint{}, ParamsUpdate{Fee: lowFeePref}, - ) - if err != lnwallet.ErrNotMine { - t.Fatalf("expected error lnwallet.ErrNotMine, got \"%v\"", err) - } - - // We'll then attempt to sweep an input, which we'll use to bump its fee - // later on. - input := createTestInput( - btcutil.SatoshiPerBitcoin, input.CommitmentTimeLock, - ) - sweepResult, err := ctx.sweeper.SweepInput( - &input, Params{Fee: lowFeePref}, - ) - if err != nil { - t.Fatal(err) - } - - // Generate the same type of change script used so we can have accurate - // weight estimation. - changePk, err := ctx.sweeper.cfg.GenSweepScript() - require.NoError(t, err) - - // Ensure that a transaction is broadcast with the lower fee preference. - lowFeeTx := ctx.receiveTx() - assertTxFeeRate(t, &lowFeeTx, lowFeeRate, changePk, &input) + input1 := spendableInputs[0] + input2 := spendableInputs[1] + input3 := spendableInputs[2] - // We'll then attempt to bump its fee rate. - highFeePref := FeeEstimateInfo{ConfTarget: 6} - highFeeRate := DefaultMaxFeeRate.FeePerKWeight() - ctx.estimator.blocksToFee[highFeePref.ConfTarget] = highFeeRate + // Mock the Broadcast method to succeed on the first sweep. + bumpResultChan1 := make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan1, nil).Run(func(args mock.Arguments) { + // Create a fake sweep tx. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: *input1.OutPoint()}, + }, + } - // We should expect to see an error if a fee preference isn't provided. - _, err = ctx.sweeper.UpdateParams(*input.OutPoint(), ParamsUpdate{ - Fee: &FeeEstimateInfo{}, - }) - if err != ErrNoFeePreference { - t.Fatalf("expected ErrNoFeePreference, got %v", err) - } + // Send the first event. + bumpResultChan1 <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } - bumpResult, err := ctx.sweeper.UpdateParams( - *input.OutPoint(), ParamsUpdate{Fee: highFeePref}, - ) - require.NoError(t, err, "unable to bump input's fee") + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }).Once() + + // Mock the Broadcast method to succeed on the second sweep. + bumpResultChan2 := make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan2, nil).Run(func(args mock.Arguments) { + // Create a fake sweep tx. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: *input2.OutPoint()}, + }, + } - // A higher fee rate transaction should be immediately broadcast. - highFeeTx := ctx.receiveTx() - assertTxFeeRate(t, &highFeeTx, highFeeRate, changePk, &input) + // Send the first event. + bumpResultChan2 <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } - // We'll finish our test by mining the sweep transaction. - ctx.backend.mine() - ctx.expectResult(sweepResult, nil) - ctx.expectResult(bumpResult, nil) + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }).Once() + + // Mock the Broadcast method to succeed on the third sweep. + bumpResultChan3 := make(chan *BumpResult, 1) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan3, nil).Run(func(args mock.Arguments) { + // Create a fake sweep tx. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: *input3.OutPoint()}, + }, + } - ctx.finish(1) -} + // Send the first event. + bumpResultChan3 <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } -// TestExclusiveGroup tests the sweeper exclusive group functionality. -func TestExclusiveGroup(t *testing.T) { - ctx := createSweeperTestContext(t) + // Due to a mix of new and old test frameworks, we need to + // manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them will + // mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }).Once() // Sweep three inputs in the same exclusive group. var results []chan Result @@ -1096,32 +1789,45 @@ func TestExclusiveGroup(t *testing.T) { ExclusiveGroup: &exclusiveGroup, }, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) results = append(results, result) } // We expect all inputs to be published in separate transactions, even // though they share the same fee preference. - for i := 0; i < 3; i++ { - sweepTx := ctx.receiveTx() - if len(sweepTx.TxOut) != 1 { - t.Fatal("expected a single tx out in the sweep tx") - } + sweepTx1 := ctx.receiveTx() + require.Len(t, sweepTx1.TxIn, 1) + + sweepTx2 := ctx.receiveTx() + sweepTx3 := ctx.receiveTx() - // Remove all txes except for the one that sweeps the first - // input. This simulates the sweeps being conflicting. - if sweepTx.TxIn[0].PreviousOutPoint != - *spendableInputs[0].OutPoint() { + // Remove all txes except for the one that sweeps the first + // input. This simulates the sweeps being conflicting. + ctx.backend.deleteUnconfirmed(sweepTx2.TxHash()) + ctx.backend.deleteUnconfirmed(sweepTx3.TxHash()) - ctx.backend.deleteUnconfirmed(sweepTx.TxHash()) - } - } + // Wait until the sweep tx has been saved to db. + assertNumSweeps(t, ctx.sweeper, 3) // Mine the first sweep tx. ctx.backend.mine() + // Mock a confirmed event. + bumpResultChan1 <- &BumpResult{ + Event: TxConfirmed, + Tx: &sweepTx1, + FeeRate: 10, + Fee: 100, + } + bumpResultChan2 <- &BumpResult{ + Event: TxFailed, + Tx: &sweepTx2, + } + bumpResultChan2 <- &BumpResult{ + Event: TxFailed, + Tx: &sweepTx3, + } + // Expect the first input to be swept by the confirmed sweep tx. result0 := <-results[0] if result0.Err != nil { @@ -1141,69 +1847,6 @@ func TestExclusiveGroup(t *testing.T) { } } -// TestCpfp tests that the sweeper spends cpfp inputs at a fee rate that exceeds -// the parent tx fee rate. -func TestCpfp(t *testing.T) { - ctx := createSweeperTestContext(t) - - ctx.estimator.updateFees(1000, chainfee.FeePerKwFloor) - - // Offer an input with an unconfirmed parent tx to the sweeper. The - // parent tx pays 3000 sat/kw. - hash := chainhash.Hash{1} - input := input.MakeBaseInput( - &wire.OutPoint{Hash: hash}, - input.CommitmentTimeLock, - &input.SignDescriptor{ - Output: &wire.TxOut{ - Value: 330, - }, - KeyDesc: keychain.KeyDescriptor{ - PubKey: testPubKey, - }, - }, - 0, - &input.TxInfo{ - Weight: 300, - Fee: 900, - }, - ) - - feePref := FeeEstimateInfo{ConfTarget: 6} - result, err := ctx.sweeper.SweepInput( - &input, Params{Fee: feePref, Force: true}, - ) - require.NoError(t, err) - - // Increase the fee estimate to above the parent tx fee rate. - ctx.estimator.updateFees(5000, chainfee.FeePerKwFloor) - - // Signal a new block. This is a trigger for the sweeper to refresh fee - // estimates. - ctx.notifier.NotifyEpoch(1000) - - // Now we do expect a sweep transaction to be published with our input - // and an attached wallet utxo. - tx := ctx.receiveTx() - require.Len(t, tx.TxIn, 2) - require.Len(t, tx.TxOut, 1) - - // As inputs we have 10000 sats from the wallet and 330 sats from the - // cpfp input. The sweep tx is weight expected to be 759 units. There is - // an additional 300 weight units from the parent to include in the - // package, making a total of 1059. At 5000 sat/kw, the required fee for - // the package is 5295 sats. The parent already paid 900 sats, so there - // is 4395 sat remaining to be paid. The expected output value is - // therefore: 1_000_000 + 330 - 4395 = 995 935. - require.Equal(t, int64(995_935), tx.TxOut[0].Value) - - // Mine the tx and assert that the result is passed back. - ctx.backend.mine() - ctx.expectResult(result, nil) - - ctx.finish(1) -} - type testInput struct { *input.BaseInput @@ -1299,8 +1942,10 @@ func TestLockTimes(t *testing.T) { // Sweep 8 inputs, using 4 different lock times. var ( - results []chan Result - inputs = make(map[wire.OutPoint]input.Input) + results []chan Result + inputs = make(map[wire.OutPoint]input.Input) + clusters = make(map[uint32][]input.Input) + bumpResultChans = make([]chan *BumpResult, 0, 4) ) for i := 0; i < numSweeps*2; i++ { lt := uint32(10 + (i % numSweeps)) @@ -1309,53 +1954,84 @@ func TestLockTimes(t *testing.T) { locktime: <, } - result, err := ctx.sweeper.SweepInput( - inp, Params{ - Fee: FeeEstimateInfo{ConfTarget: 6}, - }, - ) - if err != nil { - t.Fatal(err) - } - results = append(results, result) - op := inp.OutPoint() inputs[*op] = inp + + cluster, ok := clusters[lt] + if !ok { + cluster = make([]input.Input, 0) + } + cluster = append(cluster, inp) + clusters[lt] = cluster } - // We also add 3 regular inputs that don't require any specific lock - // time. for i := 0; i < 3; i++ { inp := spendableInputs[i+numSweeps*2] + inputs[*inp.OutPoint()] = inp + + lt := uint32(10 + (i % numSweeps)) + clusters[lt] = append(clusters[lt], inp) + } + + for lt, cluster := range clusters { + // Create a fake sweep tx. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{}, + LockTime: lt, + } + + // Append the inputs. + for _, inp := range cluster { + txIn := &wire.TxIn{ + PreviousOutPoint: *inp.OutPoint(), + } + tx.TxIn = append(tx.TxIn, txIn) + } + + // Mock the Broadcast method to succeed on current sweep. + bumpResultChan := make(chan *BumpResult, 1) + bumpResultChans = append(bumpResultChans, bumpResultChan) + ctx.publisher.On("Broadcast", mock.Anything).Return( + bumpResultChan, nil).Run(func(args mock.Arguments) { + // Send the first event. + bumpResultChan <- &BumpResult{ + Event: TxPublished, + Tx: tx, + } + + // Due to a mix of new and old test frameworks, we need + // to manually call the method to get the test to pass. + // + // TODO(yy): remove the test context and replace them + // will mocks. + err := ctx.backend.PublishTransaction(tx, "") + require.NoError(t, err) + }).Once() + } + + // Make all the sweeps. + for _, inp := range inputs { result, err := ctx.sweeper.SweepInput( inp, Params{ Fee: FeeEstimateInfo{ConfTarget: 6}, }, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) results = append(results, result) - - op := inp.OutPoint() - inputs[*op] = inp } // Check the sweeps transactions, ensuring all inputs are there, and // all the locktimes are satisfied. + sweepTxes := make([]wire.MsgTx, 0, numSweeps) for i := 0; i < numSweeps; i++ { sweepTx := ctx.receiveTx() - if len(sweepTx.TxOut) != 1 { - t.Fatal("expected a single tx out in the sweep tx") - } + sweepTxes = append(sweepTxes, sweepTx) for _, txIn := range sweepTx.TxIn { op := txIn.PreviousOutPoint inp, ok := inputs[op] - if !ok { - t.Fatalf("Unexpected outpoint: %v", op) - } + require.True(t, ok) delete(inputs, op) @@ -1366,25 +2042,33 @@ func TestLockTimes(t *testing.T) { continue } - if lt != sweepTx.LockTime { - t.Fatalf("Input required locktime %v, sweep "+ - "tx had locktime %v", lt, sweepTx.LockTime) - } + require.EqualValues(t, lt, sweepTx.LockTime) } } - // The should be no inputs not foud in any of the sweeps. - if len(inputs) != 0 { - t.Fatalf("had unsweeped inputs: %v", inputs) - } + // Wait until the sweep tx has been saved to db. + assertNumSweeps(t, ctx.sweeper, 4) - // Mine the first sweeps + // Mine the sweeps. ctx.backend.mine() + for i, bumpResultChan := range bumpResultChans { + // Mock a confirmed event. + bumpResultChan <- &BumpResult{ + Event: TxConfirmed, + Tx: &sweepTxes[i], + FeeRate: 10, + Fee: 100, + } + } + + // The should be no inputs not foud in any of the sweeps. + require.Empty(t, inputs) + // Results should all come back. - for i := range results { + for i, resultChan := range results { select { - case result := <-results[i]: + case result := <-resultChan: require.NoError(t, result.Err) case <-time.After(1 * time.Second): t.Fatalf("result %v did not come back", i) @@ -1392,463 +2076,6 @@ func TestLockTimes(t *testing.T) { } } -// TestRequiredTxOuts checks that inputs having a required TxOut gets swept with -// sweep transactions paying into these outputs. -func TestRequiredTxOuts(t *testing.T) { - // Create some test inputs and locktime vars. - var inputs []*input.BaseInput - for i := 0; i < 20; i++ { - input := createTestInput( - int64(btcutil.SatoshiPerBitcoin+i*500), - input.CommitmentTimeLock, - ) - - inputs = append(inputs, &input) - } - - locktime1 := uint32(51) - locktime2 := uint32(52) - locktime3 := uint32(53) - - aPkScript := make([]byte, input.P2WPKHSize) - aPkScript[0] = 'a' - - bPkScript := make([]byte, input.P2WSHSize) - bPkScript[0] = 'b' - - cPkScript := make([]byte, input.P2PKHSize) - cPkScript[0] = 'c' - - dPkScript := make([]byte, input.P2SHSize) - dPkScript[0] = 'd' - - ePkScript := make([]byte, input.UnknownWitnessSize) - ePkScript[0] = 'e' - - fPkScript := make([]byte, input.P2WSHSize) - fPkScript[0] = 'f' - - testCases := []struct { - name string - inputs []*testInput - assertSweeps func(*testing.T, map[wire.OutPoint]*testInput, - []*wire.MsgTx) - }{ - { - // Single input with a required TX out that is smaller. - // We expect a change output to be added. - name: "single input, leftover change", - inputs: []*testInput{ - { - BaseInput: inputs[0], - reqTxOut: &wire.TxOut{ - PkScript: aPkScript, - Value: 100000, - }, - }, - }, - - // Since the required output value is small, we expect - // the rest after fees to go into a change output. - assertSweeps: func(t *testing.T, - _ map[wire.OutPoint]*testInput, - txs []*wire.MsgTx) { - - require.Equal(t, 1, len(txs)) - - tx := txs[0] - require.Equal(t, 1, len(tx.TxIn)) - - // We should have two outputs, the required - // output must be the first one. - require.Equal(t, 2, len(tx.TxOut)) - out := tx.TxOut[0] - require.Equal(t, aPkScript, out.PkScript) - require.Equal(t, int64(100000), out.Value) - }, - }, - { - // An input committing to a slightly smaller output, so - // it will pay its own fees. - name: "single input, no change", - inputs: []*testInput{ - { - BaseInput: inputs[0], - reqTxOut: &wire.TxOut{ - PkScript: aPkScript, - - // Fee will be about 5340 sats. - // Subtract a bit more to - // ensure no dust change output - // is manifested. - Value: inputs[0].SignDesc().Output.Value - 6300, - }, - }, - }, - - // We expect this single input/output pair. - assertSweeps: func(t *testing.T, - _ map[wire.OutPoint]*testInput, - txs []*wire.MsgTx) { - - require.Equal(t, 1, len(txs)) - - tx := txs[0] - require.Equal(t, 1, len(tx.TxIn)) - - require.Equal(t, 1, len(tx.TxOut)) - out := tx.TxOut[0] - require.Equal(t, aPkScript, out.PkScript) - require.Equal( - t, - inputs[0].SignDesc().Output.Value-6300, - out.Value, - ) - }, - }, - { - // Two inputs, where the first one required no tx out. - name: "two inputs, one with required tx out", - inputs: []*testInput{ - { - - // We add a normal, non-requiredTxOut - // input. We use test input 10, to make - // sure this has a higher yield than - // the other input, and will be - // attempted added first to the sweep - // tx. - BaseInput: inputs[10], - }, - { - // The second input requires a TxOut. - BaseInput: inputs[0], - reqTxOut: &wire.TxOut{ - PkScript: aPkScript, - Value: inputs[0].SignDesc().Output.Value, - }, - }, - }, - - // We expect the inputs to have been reordered. - assertSweeps: func(t *testing.T, - _ map[wire.OutPoint]*testInput, - txs []*wire.MsgTx) { - - require.Equal(t, 1, len(txs)) - - tx := txs[0] - require.Equal(t, 2, len(tx.TxIn)) - require.Equal(t, 2, len(tx.TxOut)) - - // The required TxOut should be the first one. - out := tx.TxOut[0] - require.Equal(t, aPkScript, out.PkScript) - require.Equal( - t, inputs[0].SignDesc().Output.Value, - out.Value, - ) - - // The first input should be the one having the - // required TxOut. - require.Len(t, tx.TxIn, 2) - require.Equal( - t, inputs[0].OutPoint(), - &tx.TxIn[0].PreviousOutPoint, - ) - - // Second one is the one without a required tx - // out. - require.Equal( - t, inputs[10].OutPoint(), - &tx.TxIn[1].PreviousOutPoint, - ) - }, - }, - - { - // An input committing to an output of equal value, just - // add input to pay fees. - name: "single input, extra fee input", - inputs: []*testInput{ - { - BaseInput: inputs[0], - reqTxOut: &wire.TxOut{ - PkScript: aPkScript, - Value: inputs[0].SignDesc().Output.Value, - }, - }, - }, - - // We expect an extra input and output. - assertSweeps: func(t *testing.T, - _ map[wire.OutPoint]*testInput, - txs []*wire.MsgTx) { - - require.Equal(t, 1, len(txs)) - - tx := txs[0] - require.Equal(t, 2, len(tx.TxIn)) - - require.Equal(t, 2, len(tx.TxOut)) - out := tx.TxOut[0] - require.Equal(t, aPkScript, out.PkScript) - require.Equal( - t, inputs[0].SignDesc().Output.Value, - out.Value, - ) - }, - }, - { - // Three inputs added, should be combined into a single - // sweep. - name: "three inputs", - inputs: []*testInput{ - { - BaseInput: inputs[0], - reqTxOut: &wire.TxOut{ - PkScript: aPkScript, - Value: inputs[0].SignDesc().Output.Value, - }, - }, - { - BaseInput: inputs[1], - reqTxOut: &wire.TxOut{ - PkScript: bPkScript, - Value: inputs[1].SignDesc().Output.Value, - }, - }, - { - BaseInput: inputs[2], - reqTxOut: &wire.TxOut{ - PkScript: cPkScript, - Value: inputs[2].SignDesc().Output.Value, - }, - }, - }, - - // We expect an extra input and output to pay fees. - assertSweeps: func(t *testing.T, - testInputs map[wire.OutPoint]*testInput, - txs []*wire.MsgTx) { - - require.Equal(t, 1, len(txs)) - - tx := txs[0] - require.Equal(t, 4, len(tx.TxIn)) - require.Equal(t, 4, len(tx.TxOut)) - - // The inputs and outputs must be in the same - // order. - for i, in := range tx.TxIn { - // Last one is the change input/output - // pair, so we'll skip it. - if i == 3 { - continue - } - - // Get this input to ensure the output - // on index i coresponsd to this one. - inp := testInputs[in.PreviousOutPoint] - require.NotNil(t, inp) - - require.Equal( - t, tx.TxOut[i].Value, - inp.SignDesc().Output.Value, - ) - } - }, - }, - { - // Six inputs added, which 3 different locktimes. - // Should result in 3 sweeps. - name: "six inputs", - inputs: []*testInput{ - { - BaseInput: inputs[0], - locktime: &locktime1, - reqTxOut: &wire.TxOut{ - PkScript: aPkScript, - Value: inputs[0].SignDesc().Output.Value, - }, - }, - { - BaseInput: inputs[1], - locktime: &locktime1, - reqTxOut: &wire.TxOut{ - PkScript: bPkScript, - Value: inputs[1].SignDesc().Output.Value, - }, - }, - { - BaseInput: inputs[2], - locktime: &locktime2, - reqTxOut: &wire.TxOut{ - PkScript: cPkScript, - Value: inputs[2].SignDesc().Output.Value, - }, - }, - { - BaseInput: inputs[3], - locktime: &locktime2, - reqTxOut: &wire.TxOut{ - PkScript: dPkScript, - Value: inputs[3].SignDesc().Output.Value, - }, - }, - { - BaseInput: inputs[4], - locktime: &locktime3, - reqTxOut: &wire.TxOut{ - PkScript: ePkScript, - Value: inputs[4].SignDesc().Output.Value, - }, - }, - { - BaseInput: inputs[5], - locktime: &locktime3, - reqTxOut: &wire.TxOut{ - PkScript: fPkScript, - Value: inputs[5].SignDesc().Output.Value, - }, - }, - }, - - // We expect three sweeps, each having two of our - // inputs, one extra input and output to pay fees. - assertSweeps: func(t *testing.T, - testInputs map[wire.OutPoint]*testInput, - txs []*wire.MsgTx) { - - require.Equal(t, 3, len(txs)) - - for _, tx := range txs { - require.Equal(t, 3, len(tx.TxIn)) - require.Equal(t, 3, len(tx.TxOut)) - - // The inputs and outputs must be in - // the same order. - for i, in := range tx.TxIn { - // Last one is the change - // output, so we'll skip it. - if i == 2 { - continue - } - - // Get this input to ensure the - // output on index i coresponsd - // to this one. - inp := testInputs[in.PreviousOutPoint] - require.NotNil(t, inp) - - require.Equal( - t, tx.TxOut[i].Value, - inp.SignDesc().Output.Value, - ) - - // Check that the locktimes are - // kept intact. - require.Equal( - t, tx.LockTime, - *inp.locktime, - ) - } - } - }, - }, - } - - for _, testCase := range testCases { - testCase := testCase - - t.Run(testCase.name, func(t *testing.T) { - ctx := createSweeperTestContext(t) - - // We increase the number of max inputs to a tx so that - // won't impact our test. - ctx.sweeper.cfg.MaxInputsPerTx = 100 - - // Sweep all test inputs. - var ( - inputs = make(map[wire.OutPoint]*testInput) - results = make(map[wire.OutPoint]chan Result) - ) - for _, inp := range testCase.inputs { - result, err := ctx.sweeper.SweepInput( - inp, Params{ - Fee: FeeEstimateInfo{ - ConfTarget: 6, - }, - }, - ) - if err != nil { - t.Fatal(err) - } - - op := inp.OutPoint() - results[*op] = result - inputs[*op] = inp - } - - // Send a new block epoch to trigger the sweeper to - // sweep the inputs. - ctx.notifier.NotifyEpoch(ctx.sweeper.currentHeight + 1) - - // Check the sweeps transactions, ensuring all inputs - // are there, and all the locktimes are satisfied. - var sweeps []*wire.MsgTx - Loop: - for { - select { - case tx := <-ctx.publishChan: - sweeps = append(sweeps, &tx) - case <-time.After(200 * time.Millisecond): - break Loop - } - } - - // Mine the sweeps. - ctx.backend.mine() - - // Results should all come back. - for _, resultChan := range results { - result := <-resultChan - if result.Err != nil { - t.Fatalf("expected input to be "+ - "swept: %v", result.Err) - } - } - - // Assert the transactions are what we expect. - testCase.assertSweeps(t, inputs, sweeps) - - // Finally we assert that all our test inputs were part - // of the sweeps, and that they were signed correctly. - sweptInputs := make(map[wire.OutPoint]struct{}) - for _, sweep := range sweeps { - swept := assertSignedIndex(t, sweep, inputs) - for op := range swept { - if _, ok := sweptInputs[op]; ok { - t.Fatalf("outpoint %v part of "+ - "previous sweep", op) - } - - sweptInputs[op] = struct{}{} - } - } - - require.Equal(t, len(inputs), len(sweptInputs)) - for op := range sweptInputs { - _, ok := inputs[op] - if !ok { - t.Fatalf("swept input %v not part of "+ - "test inputs", op) - } - } - }) - } -} - // TestSweeperShutdownHandling tests that we notify callers when the sweeper // cannot handle requests since it's in the process of shutting down. func TestSweeperShutdownHandling(t *testing.T) { @@ -1897,87 +2124,74 @@ func TestMarkInputsPendingPublish(t *testing.T) { require := require.New(t) - // Create a mock sweeper store. - mockStore := NewMockSweeperStore() - - // Create a test TxRecord and a dummy error. - dummyTR := &TxRecord{} - dummyErr := errors.New("dummy error") - // Create a test sweeper. - s := New(&UtxoSweeperConfig{ - Store: mockStore, - }) + s := New(&UtxoSweeperConfig{}) + + // Create a mock input set. + set := &MockInputSet{} + defer set.AssertExpectations(t) // Create three testing inputs. // // inputNotExist specifies an input that's not found in the sweeper's // `pendingInputs` map. - inputNotExist := &wire.TxIn{ - PreviousOutPoint: wire.OutPoint{Index: 1}, - } + inputNotExist := &input.MockInput{} + defer inputNotExist.AssertExpectations(t) + + inputNotExist.On("OutPoint").Return(&wire.OutPoint{Index: 0}) // inputInit specifies a newly created input. - inputInit := &wire.TxIn{ - PreviousOutPoint: wire.OutPoint{Index: 2}, - } - s.pendingInputs[inputInit.PreviousOutPoint] = &pendingInput{ + inputInit := &input.MockInput{} + defer inputInit.AssertExpectations(t) + + inputInit.On("OutPoint").Return(&wire.OutPoint{Index: 1}) + + s.pendingInputs[*inputInit.OutPoint()] = &pendingInput{ state: StateInit, } // inputPendingPublish specifies an input that's about to be published. - inputPendingPublish := &wire.TxIn{ - PreviousOutPoint: wire.OutPoint{Index: 3}, - } - s.pendingInputs[inputPendingPublish.PreviousOutPoint] = &pendingInput{ + inputPendingPublish := &input.MockInput{} + defer inputPendingPublish.AssertExpectations(t) + + inputPendingPublish.On("OutPoint").Return(&wire.OutPoint{Index: 2}) + + s.pendingInputs[*inputPendingPublish.OutPoint()] = &pendingInput{ state: StatePendingPublish, } // inputTerminated specifies an input that's terminated. - inputTerminated := &wire.TxIn{ - PreviousOutPoint: wire.OutPoint{Index: 4}, - } - s.pendingInputs[inputTerminated.PreviousOutPoint] = &pendingInput{ - state: StateExcluded, - } + inputTerminated := &input.MockInput{} + defer inputTerminated.AssertExpectations(t) - // First, check that when an error is returned from db, it's properly - // returned here. - mockStore.On("StoreTx", dummyTR).Return(dummyErr).Once() - err := s.markInputsPendingPublish(dummyTR, nil) - require.ErrorIs(err, dummyErr) + inputTerminated.On("OutPoint").Return(&wire.OutPoint{Index: 3}) - // Then, check that the target input has will be correctly marked as - // published. - // - // Mock the store to return nil - mockStore.On("StoreTx", dummyTR).Return(nil).Once() + s.pendingInputs[*inputTerminated.OutPoint()] = &pendingInput{ + state: StateExcluded, + } // Mark the test inputs. We expect the non-exist input and the // inputTerminated to be skipped, and the rest to be marked as pending // publish. - err = s.markInputsPendingPublish(dummyTR, []*wire.TxIn{ + set.On("Inputs").Return([]input.Input{ inputNotExist, inputInit, inputPendingPublish, inputTerminated, }) - require.NoError(err) + s.markInputsPendingPublish(set) // We expect unchanged number of pending inputs. require.Len(s.pendingInputs, 3) // We expect the init input's state to become pending publish. require.Equal(StatePendingPublish, - s.pendingInputs[inputInit.PreviousOutPoint].state) + s.pendingInputs[*inputInit.OutPoint()].state) // We expect the pending-publish to stay unchanged. require.Equal(StatePendingPublish, - s.pendingInputs[inputPendingPublish.PreviousOutPoint].state) + s.pendingInputs[*inputPendingPublish.OutPoint()].state) // We expect the terminated to stay unchanged. require.Equal(StateExcluded, - s.pendingInputs[inputTerminated.PreviousOutPoint].state) - - // Assert mocked statements are executed as expected. - mockStore.AssertExpectations(t) + s.pendingInputs[*inputTerminated.OutPoint()].state) } // TestMarkInputsPublished checks that given a list of inputs with different @@ -2108,8 +2322,10 @@ func TestMarkInputsPublishFailed(t *testing.T) { // Mark the test inputs. We expect the non-exist input and the // inputInit to be skipped, and the final input to be marked as // published. - s.markInputsPublishFailed([]*wire.TxIn{ - inputNotExist, inputInit, inputPendingPublish, + s.markInputsPublishFailed([]wire.OutPoint{ + inputNotExist.PreviousOutPoint, + inputInit.PreviousOutPoint, + inputPendingPublish.PreviousOutPoint, }) // We expect unchanged number of pending inputs. @@ -2421,16 +2637,27 @@ func TestSweepPendingInputs(t *testing.T) { // Create a mock wallet and aggregator. wallet := &MockWallet{} + defer wallet.AssertExpectations(t) + aggregator := &mockUtxoAggregator{} + defer aggregator.AssertExpectations(t) + + publisher := &MockBumper{} + defer publisher.AssertExpectations(t) // Create a test sweeper. s := New(&UtxoSweeperConfig{ Wallet: wallet, Aggregator: aggregator, + Publisher: publisher, + GenSweepScript: func() ([]byte, error) { + return testPubKey.SerializeCompressed(), nil + }, }) // Create an input set that needs wallet inputs. setNeedWallet := &MockInputSet{} + defer setNeedWallet.AssertExpectations(t) // Mock this set to ask for wallet input. setNeedWallet.On("NeedWalletInput").Return(true).Once() @@ -2441,15 +2668,18 @@ func TestSweepPendingInputs(t *testing.T) { // Create an input set that doesn't need wallet inputs. normalSet := &MockInputSet{} + defer normalSet.AssertExpectations(t) + normalSet.On("NeedWalletInput").Return(false).Once() // Mock the methods used in `sweep`. This is not important for this // unit test. - feeRate := chainfee.SatPerKWeight(1000) - setNeedWallet.On("Inputs").Return(nil).Once() - setNeedWallet.On("FeeRate").Return(feeRate).Once() - normalSet.On("Inputs").Return(nil).Once() - normalSet.On("FeeRate").Return(feeRate).Once() + setNeedWallet.On("Inputs").Return(nil).Times(4) + setNeedWallet.On("DeadlineHeight").Return(fn.None[int32]()).Once() + setNeedWallet.On("Budget").Return(btcutil.Amount(1)).Once() + normalSet.On("Inputs").Return(nil).Times(4) + normalSet.On("DeadlineHeight").Return(fn.None[int32]()).Once() + normalSet.On("Budget").Return(btcutil.Amount(1)).Once() // Make pending inputs for testing. We don't need real values here as // the returned clusters are mocked. @@ -2460,19 +2690,369 @@ func TestSweepPendingInputs(t *testing.T) { setNeedWallet, normalSet, }) - // Set change output script to an invalid value. This should cause the + // Mock `Broadcast` to return an error. This should cause the // `createSweepTx` inside `sweep` to fail. This is done so we can // terminate the method early as we are only interested in testing the // workflow in `sweepPendingInputs`. We don't need to test `sweep` here // as it should be tested in its own unit test. - s.currentOutputScript = []byte{1} + dummyErr := errors.New("dummy error") + publisher.On("Broadcast", mock.Anything).Return(nil, dummyErr).Twice() // Call the method under test. s.sweepPendingInputs(pis) +} + +// TestHandleBumpEventTxFailed checks that the sweeper correctly handles the +// case where the bump event tx fails to be published. +func TestHandleBumpEventTxFailed(t *testing.T) { + t.Parallel() + + // Create a test sweeper. + s := New(&UtxoSweeperConfig{}) + + var ( + // Create four testing outpoints. + op1 = wire.OutPoint{Hash: chainhash.Hash{1}} + op2 = wire.OutPoint{Hash: chainhash.Hash{2}} + op3 = wire.OutPoint{Hash: chainhash.Hash{3}} + opNotExist = wire.OutPoint{Hash: chainhash.Hash{4}} + ) + + // Create three mock inputs. + input1 := &input.MockInput{} + defer input1.AssertExpectations(t) + + input2 := &input.MockInput{} + defer input2.AssertExpectations(t) - // Assert mocked methods are called as expected. - wallet.AssertExpectations(t) - aggregator.AssertExpectations(t) - setNeedWallet.AssertExpectations(t) - normalSet.AssertExpectations(t) + input3 := &input.MockInput{} + defer input3.AssertExpectations(t) + + // Construct the initial state for the sweeper. + s.pendingInputs = pendingInputs{ + op1: &pendingInput{Input: input1, state: StatePendingPublish}, + op2: &pendingInput{Input: input2, state: StatePendingPublish}, + op3: &pendingInput{Input: input3, state: StatePendingPublish}, + } + + // Create a testing tx that spends the first two inputs. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: op1}, + {PreviousOutPoint: op2}, + {PreviousOutPoint: opNotExist}, + }, + } + + // Create a testing bump result. + br := &BumpResult{ + Tx: tx, + Event: TxFailed, + Err: errDummy, + } + + // Call the method under test. + err := s.handleBumpEvent(br) + require.ErrorIs(t, err, errDummy) + + // Assert the states of the first two inputs are updated. + require.Equal(t, StatePublishFailed, s.pendingInputs[op1].state) + require.Equal(t, StatePublishFailed, s.pendingInputs[op2].state) + + // Assert the state of the third input is not updated. + require.Equal(t, StatePendingPublish, s.pendingInputs[op3].state) + + // Assert the non-existing input is not added to the pending inputs. + require.NotContains(t, s.pendingInputs, opNotExist) +} + +// TestHandleBumpEventTxReplaced checks that the sweeper correctly handles the +// case where the bump event tx is replaced. +func TestHandleBumpEventTxReplaced(t *testing.T) { + t.Parallel() + + // Create a mock store. + store := &MockSweeperStore{} + defer store.AssertExpectations(t) + + // Create a test sweeper. + s := New(&UtxoSweeperConfig{ + Store: store, + }) + + // Create a testing outpoint. + op := wire.OutPoint{Hash: chainhash.Hash{1}} + + // Create a mock input. + inp := &input.MockInput{} + defer inp.AssertExpectations(t) + + // Construct the initial state for the sweeper. + s.pendingInputs = pendingInputs{ + op: &pendingInput{Input: inp, state: StatePendingPublish}, + } + + // Create a testing tx that spends the input. + tx := &wire.MsgTx{ + LockTime: 1, + TxIn: []*wire.TxIn{ + {PreviousOutPoint: op}, + }, + } + + // Create a replacement tx. + replacementTx := &wire.MsgTx{ + LockTime: 2, + TxIn: []*wire.TxIn{ + {PreviousOutPoint: op}, + }, + } + + // Create a testing bump result. + br := &BumpResult{ + Tx: replacementTx, + ReplacedTx: tx, + Event: TxReplaced, + } + + // Mock the store to return an error. + dummyErr := errors.New("dummy error") + store.On("GetTx", tx.TxHash()).Return(nil, dummyErr).Once() + + // Call the method under test and assert the error is returned. + err := s.handleBumpEventTxReplaced(br) + require.ErrorIs(t, err, dummyErr) + + // Mock the store to return the old tx record. + store.On("GetTx", tx.TxHash()).Return(&TxRecord{ + Txid: tx.TxHash(), + }, nil).Once() + + // Mock an error returned when deleting the old tx record. + store.On("DeleteTx", tx.TxHash()).Return(dummyErr).Once() + + // Call the method under test and assert the error is returned. + err = s.handleBumpEventTxReplaced(br) + require.ErrorIs(t, err, dummyErr) + + // Mock the store to return the old tx record and delete it without + // error. + store.On("GetTx", tx.TxHash()).Return(&TxRecord{ + Txid: tx.TxHash(), + }, nil).Once() + store.On("DeleteTx", tx.TxHash()).Return(nil).Once() + + // Mock the store to save the new tx record. + store.On("StoreTx", &TxRecord{ + Txid: replacementTx.TxHash(), + Published: true, + }).Return(nil).Once() + + // Call the method under test. + err = s.handleBumpEventTxReplaced(br) + require.NoError(t, err) + + // Assert the state of the input is updated. + require.Equal(t, StatePublished, s.pendingInputs[op].state) +} + +// TestHandleBumpEventTxPublished checks that the sweeper correctly handles the +// case where the bump event tx is published. +func TestHandleBumpEventTxPublished(t *testing.T) { + t.Parallel() + + // Create a mock store. + store := &MockSweeperStore{} + defer store.AssertExpectations(t) + + // Create a test sweeper. + s := New(&UtxoSweeperConfig{ + Store: store, + }) + + // Create a testing outpoint. + op := wire.OutPoint{Hash: chainhash.Hash{1}} + + // Create a mock input. + inp := &input.MockInput{} + defer inp.AssertExpectations(t) + + // Construct the initial state for the sweeper. + s.pendingInputs = pendingInputs{ + op: &pendingInput{Input: inp, state: StatePendingPublish}, + } + + // Create a testing tx that spends the input. + tx := &wire.MsgTx{ + LockTime: 1, + TxIn: []*wire.TxIn{ + {PreviousOutPoint: op}, + }, + } + + // Create a testing bump result. + br := &BumpResult{ + Tx: tx, + Event: TxPublished, + } + + // Mock the store to save the new tx record. + store.On("StoreTx", &TxRecord{ + Txid: tx.TxHash(), + Published: true, + }).Return(nil).Once() + + // Call the method under test. + err := s.handleBumpEventTxPublished(br) + require.NoError(t, err) + + // Assert the state of the input is updated. + require.Equal(t, StatePublished, s.pendingInputs[op].state) +} + +// TestMonitorFeeBumpResult checks that the fee bump monitor loop correctly +// exits when the sweeper is stopped, the tx is confirmed or failed. +func TestMonitorFeeBumpResult(t *testing.T) { + // Create a mock store. + store := &MockSweeperStore{} + defer store.AssertExpectations(t) + + // Create a test sweeper. + s := New(&UtxoSweeperConfig{ + Store: store, + }) + + // Create a testing outpoint. + op := wire.OutPoint{Hash: chainhash.Hash{1}} + + // Create a mock input. + inp := &input.MockInput{} + defer inp.AssertExpectations(t) + + // Construct the initial state for the sweeper. + s.pendingInputs = pendingInputs{ + op: &pendingInput{Input: inp, state: StatePendingPublish}, + } + + // Create a testing tx that spends the input. + tx := &wire.MsgTx{ + LockTime: 1, + TxIn: []*wire.TxIn{ + {PreviousOutPoint: op}, + }, + } + + testCases := []struct { + name string + setupResultChan func() <-chan *BumpResult + shouldExit bool + }{ + { + // When a tx confirmed event is received, we expect to + // exit the monitor loop. + name: "tx confirmed", + // We send a result with TxConfirmed event to the + // result channel. + setupResultChan: func() <-chan *BumpResult { + // Create a result chan. + resultChan := make(chan *BumpResult, 1) + resultChan <- &BumpResult{ + Tx: tx, + Event: TxConfirmed, + Fee: 10000, + FeeRate: 100, + } + + return resultChan + }, + shouldExit: true, + }, + { + // When a tx failed event is received, we expect to + // exit the monitor loop. + name: "tx failed", + // We send a result with TxConfirmed event to the + // result channel. + setupResultChan: func() <-chan *BumpResult { + // Create a result chan. + resultChan := make(chan *BumpResult, 1) + resultChan <- &BumpResult{ + Tx: tx, + Event: TxFailed, + Err: errDummy, + } + + return resultChan + }, + shouldExit: true, + }, + { + // When processing non-confirmed events, the monitor + // should not exit. + name: "no exit on normal event", + // We send a result with TxPublished and mock the + // method `StoreTx` to return nil. + setupResultChan: func() <-chan *BumpResult { + // Create a result chan. + resultChan := make(chan *BumpResult, 1) + resultChan <- &BumpResult{ + Tx: tx, + Event: TxPublished, + } + + return resultChan + }, + shouldExit: false, + }, { + // When the sweeper is shutting down, the monitor loop + // should exit. + name: "exit on sweeper shutdown", + // We don't send anything but quit the sweeper. + setupResultChan: func() <-chan *BumpResult { + close(s.quit) + + return nil + }, + shouldExit: true, + }, + } + + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + // Setup the testing result channel. + resultChan := tc.setupResultChan() + + // Create a done chan that's used to signal the monitor + // has exited. + done := make(chan struct{}) + + s.wg.Add(1) + go func() { + s.monitorFeeBumpResult(resultChan) + close(done) + }() + + // The monitor is expected to exit, we check it's done + // in one second or fail. + if tc.shouldExit { + select { + case <-done: + case <-time.After(1 * time.Second): + require.Fail(t, "monitor not exited") + } + + return + } + + // The monitor should not exit, check it doesn't close + // the `done` channel within one second. + select { + case <-done: + require.Fail(t, "monitor exited") + case <-time.After(1 * time.Second): + } + }) + } } diff --git a/sweep/tx_input_set.go b/sweep/tx_input_set.go index b80ea2db09..d2e47c57b2 100644 --- a/sweep/tx_input_set.go +++ b/sweep/tx_input_set.go @@ -8,6 +8,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" @@ -34,6 +35,14 @@ var ( // ErrNotEnoughInputs is returned when there are not enough wallet // inputs to construct a non-dust change output for an input set. ErrNotEnoughInputs = fmt.Errorf("not enough inputs") + + // ErrDeadlinesMismatch is returned when the deadlines of the input + // sets do not match. + ErrDeadlinesMismatch = fmt.Errorf("deadlines mismatch") + + // ErrDustOutput is returned when the output value is below the dust + // limit. + ErrDustOutput = fmt.Errorf("dust output") ) // InputSet defines an interface that's responsible for filtering a set of @@ -42,9 +51,6 @@ type InputSet interface { // Inputs returns the set of inputs that should be used to create a tx. Inputs() []input.Input - // FeeRate returns the fee rate that should be used for the tx. - FeeRate() chainfee.SatPerKWeight - // AddWalletInputs adds wallet inputs to the set until a non-dust // change output can be made. Return an error if there are not enough // wallet inputs. @@ -53,6 +59,24 @@ type InputSet interface { // NeedWalletInput returns true if the input set needs more wallet // inputs. NeedWalletInput() bool + + // DeadlineHeight returns an optional absolute block height to express + // the time-sensitivity of the input set. The outputs from a force + // close tx have different time preferences: + // - to_local: no time pressure as it can only be swept by us. + // - first level outgoing HTLC: must be swept before its corresponding + // incoming HTLC's CLTV is reached. + // - first level incoming HTLC: must be swept before its CLTV is + // reached. + // - second level HTLCs: no time pressure. + // - anchor: for CPFP-purpose anchor, it must be swept before any of + // the above CLTVs is reached. For non-CPFP purpose anchor, there's + // no time pressure. + DeadlineHeight() fn.Option[int32] + + // Budget givens the total amount that can be used as fees by this + // input set. + Budget() btcutil.Amount } type txInputSetState struct { @@ -167,9 +191,18 @@ func (t *txInputSet) Inputs() []input.Input { return t.inputs } -// FeeRate returns the fee rate that should be used for the tx. -func (t *txInputSet) FeeRate() chainfee.SatPerKWeight { - return t.feeRate +// Budget gives the total amount that can be used as fees by this input set. +// +// NOTE: this field is only used for `BudgetInputSet`. +func (t *txInputSet) Budget() btcutil.Amount { + return t.totalOutput() +} + +// DeadlineHeight gives the block height that this set must be confirmed by. +// +// NOTE: this field is only used for `BudgetInputSet`. +func (t *txInputSet) DeadlineHeight() fn.Option[int32] { + return fn.None[int32]() } // NeedWalletInput returns true if the input set needs more wallet inputs. @@ -509,3 +542,260 @@ func createWalletTxInput(utxo *lnwallet.Utxo) (input.Input, error) { &utxo.OutPoint, witnessType, signDesc, heightHint, ), nil } + +// BudgetInputSet implements the interface `InputSet`. It takes a list of +// pending inputs which share the same deadline height and groups them into a +// set conditionally based on their economical values. +type BudgetInputSet struct { + // inputs is the set of inputs that have been added to the set after + // considering their economical contribution. + inputs []*pendingInput + + // deadlineHeight is the height which the inputs in this set must be + // confirmed by. + deadlineHeight fn.Option[int32] +} + +// Compile-time constraint to ensure budgetInputSet implements InputSet. +var _ InputSet = (*BudgetInputSet)(nil) + +// validateInputs is used when creating new BudgetInputSet to ensure there are +// no duplicate inputs and they all share the same deadline heights, if set. +func validateInputs(inputs []pendingInput) error { + // Sanity check the input slice to ensure it's non-empty. + if len(inputs) == 0 { + return fmt.Errorf("inputs slice is empty") + } + + // dedupInputs is a map used to track unique outpoints of the inputs. + dedupInputs := make(map[*wire.OutPoint]struct{}) + + // deadlineSet stores unique deadline heights. + deadlineSet := make(map[fn.Option[int32]]struct{}) + + for _, input := range inputs { + input.params.DeadlineHeight.WhenSome(func(h int32) { + deadlineSet[input.params.DeadlineHeight] = struct{}{} + }) + + dedupInputs[input.OutPoint()] = struct{}{} + } + + // Make sure the inputs share the same deadline height when there is + // one. + if len(deadlineSet) > 1 { + return fmt.Errorf("inputs have different deadline heights") + } + + // Provide a defensive check to ensure that we don't have any duplicate + // inputs within the set. + if len(dedupInputs) != len(inputs) { + return fmt.Errorf("duplicate inputs") + } + + return nil +} + +// NewBudgetInputSet creates a new BudgetInputSet. +func NewBudgetInputSet(inputs []pendingInput) (*BudgetInputSet, error) { + // Validate the supplied inputs. + if err := validateInputs(inputs); err != nil { + return nil, err + } + + // TODO(yy): all the inputs share the same deadline height, which means + // there exists an opportunity to refactor the deadline height to be + // tracked on the set-level, not per input. This would allow us to + // avoid the overhead of tracking the same height for each input in the + // set. + deadlineHeight := inputs[0].params.DeadlineHeight + bi := &BudgetInputSet{ + deadlineHeight: deadlineHeight, + inputs: make([]*pendingInput, 0, len(inputs)), + } + + for _, input := range inputs { + bi.addInput(input) + } + + log.Tracef("Created %v", bi.String()) + + return bi, nil +} + +// String returns a human-readable description of the input set. +func (b *BudgetInputSet) String() string { + deadlineDesc := "none" + b.deadlineHeight.WhenSome(func(h int32) { + deadlineDesc = fmt.Sprintf("%d", h) + }) + + inputsDesc := "" + for _, input := range b.inputs { + inputsDesc += fmt.Sprintf("\n%v", input) + } + + return fmt.Sprintf("BudgetInputSet(budget=%v, deadline=%v, "+ + "inputs=[%v])", b.Budget(), deadlineDesc, inputsDesc) +} + +// addInput adds an input to the input set. +func (b *BudgetInputSet) addInput(input pendingInput) { + b.inputs = append(b.inputs, &input) +} + +// NeedWalletInput returns true if the input set needs more wallet inputs. +// +// A set may need wallet inputs when it has a required output or its total +// value cannot cover its total budget. +func (b *BudgetInputSet) NeedWalletInput() bool { + var ( + // budgetNeeded is the amount that needs to be covered from + // other inputs. + budgetNeeded btcutil.Amount + + // budgetBorrowable is the amount that can be borrowed from + // other inputs. + budgetBorrowable btcutil.Amount + ) + + for _, inp := range b.inputs { + // If this input has a required output, we can assume it's a + // second-level htlc txns input. Although this input must have + // a value that can cover its budget, it cannot be used to pay + // fees. Instead, we need to borrow budget from other inputs to + // make the sweep happen. Once swept, the input value will be + // credited to the wallet. + if inp.RequiredTxOut() != nil { + budgetNeeded += inp.params.Budget + continue + } + + // Get the amount left after covering the input's own budget. + // This amount can then be lent to the above input. + budget := inp.params.Budget + output := btcutil.Amount(inp.SignDesc().Output.Value) + budgetBorrowable += output - budget + + // If the input's budget is not even covered by itself, we need + // to borrow outputs from other inputs. + if budgetBorrowable < 0 { + log.Debugf("Input %v specified a budget that exceeds "+ + "its output value: %v > %v", inp, budget, + output) + } + } + + log.Tracef("NeedWalletInput: budgetNeeded=%v, budgetBorrowable=%v", + budgetNeeded, budgetBorrowable) + + // If we don't have enough extra budget to borrow, we need wallet + // inputs. + return budgetBorrowable < budgetNeeded +} + +// copyInputs returns a copy of the slice of the inputs in the set. +func (b *BudgetInputSet) copyInputs() []*pendingInput { + inputs := make([]*pendingInput, len(b.inputs)) + copy(inputs, b.inputs) + return inputs +} + +// AddWalletInputs adds wallet inputs to the set until the specified budget is +// met. When sweeping inputs with required outputs, although there's budget +// specified, it cannot be directly spent from these required outputs. Instead, +// we need to borrow budget from other inputs to make the sweep happen. +// There are two sources to borrow from: 1) other inputs, 2) wallet utxos. If +// we are calling this method, it means other inputs cannot cover the specified +// budget, so we need to borrow from wallet utxos. +// +// Return an error if there are not enough wallet inputs, and the budget set is +// set to its initial state by removing any wallet inputs added. +// +// NOTE: must be called with the wallet lock held via `WithCoinSelectLock`. +func (b *BudgetInputSet) AddWalletInputs(wallet Wallet) error { + // Retrieve wallet utxos. Only consider confirmed utxos to prevent + // problems around RBF rules for unconfirmed inputs. This currently + // ignores the configured coin selection strategy. + utxos, err := wallet.ListUnspentWitnessFromDefaultAccount( + 1, math.MaxInt32, + ) + if err != nil { + return fmt.Errorf("list unspent witness: %w", err) + } + + // Sort the UTXOs by putting smaller values at the start of the slice + // to avoid locking large UTXO for sweeping. + // + // TODO(yy): add more choices to CoinSelectionStrategy and use the + // configured value here. + sort.Slice(utxos, func(i, j int) bool { + return utxos[i].Value < utxos[j].Value + }) + + // Make a copy of the current inputs. If the wallet doesn't have enough + // utxos to cover the budget, we will revert the current set to its + // original state by removing the added wallet inputs. + originalInputs := b.copyInputs() + + // Add wallet inputs to the set until the specified budget is covered. + for _, utxo := range utxos { + input, err := createWalletTxInput(utxo) + if err != nil { + return err + } + + pi := pendingInput{ + Input: input, + params: Params{ + // Inherit the deadline height from the input + // set. + DeadlineHeight: b.deadlineHeight, + }, + } + + b.addInput(pi) + + // Return if we've reached the minimum output amount. + if !b.NeedWalletInput() { + return nil + } + } + + // The wallet doesn't have enough utxos to cover the budget. Revert the + // input set to its original state. + b.inputs = originalInputs + + return ErrNotEnoughInputs +} + +// Budget returns the total budget of the set. +// +// NOTE: part of the InputSet interface. +func (b *BudgetInputSet) Budget() btcutil.Amount { + budget := btcutil.Amount(0) + for _, input := range b.inputs { + budget += input.params.Budget + } + + return budget +} + +// DeadlineHeight returns the deadline height of the set. +// +// NOTE: part of the InputSet interface. +func (b *BudgetInputSet) DeadlineHeight() fn.Option[int32] { + return b.deadlineHeight +} + +// Inputs returns the inputs that should be used to create a tx. +// +// NOTE: part of the InputSet interface. +func (b *BudgetInputSet) Inputs() []input.Input { + inputs := make([]input.Input, 0, len(b.inputs)) + for _, inp := range b.inputs { + inputs = append(inputs, inp.Input) + } + + return inputs +} diff --git a/sweep/tx_input_set_test.go b/sweep/tx_input_set_test.go index 51afff7b77..32a08fba4d 100644 --- a/sweep/tx_input_set_test.go +++ b/sweep/tx_input_set_test.go @@ -1,10 +1,14 @@ package sweep import ( + "errors" + "math" "testing" "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet" "github.com/stretchr/testify/require" @@ -237,3 +241,432 @@ func TestTxInputSetRequiredOutput(t *testing.T) { } require.True(t, set.enoughInput()) } + +// TestNewBudgetInputSet checks `NewBudgetInputSet` correctly validates the +// supplied inputs and returns the error. +func TestNewBudgetInputSet(t *testing.T) { + t.Parallel() + + rt := require.New(t) + + // Pass an empty slice and expect an error. + set, err := NewBudgetInputSet([]pendingInput{}) + rt.ErrorContains(err, "inputs slice is empty") + rt.Nil(set) + + // Create two inputs with different deadline heights. + inp0 := createP2WKHInput(1000) + inp1 := createP2WKHInput(1000) + inp2 := createP2WKHInput(1000) + input0 := pendingInput{ + Input: inp0, + params: Params{ + Budget: 100, + DeadlineHeight: fn.None[int32](), + }, + } + input1 := pendingInput{ + Input: inp1, + params: Params{ + Budget: 100, + DeadlineHeight: fn.Some(int32(1)), + }, + } + input2 := pendingInput{ + Input: inp2, + params: Params{ + Budget: 100, + DeadlineHeight: fn.Some(int32(2)), + }, + } + + // Pass a slice of inputs with different deadline heights. + set, err = NewBudgetInputSet([]pendingInput{input1, input2}) + rt.ErrorContains(err, "inputs have different deadline heights") + rt.Nil(set) + + // Pass a slice of inputs that only one input has the deadline height. + set, err = NewBudgetInputSet([]pendingInput{input0, input2}) + rt.NoError(err) + rt.NotNil(set) + + // Pass a slice of inputs that are duplicates. + set, err = NewBudgetInputSet([]pendingInput{input1, input1}) + rt.ErrorContains(err, "duplicate inputs") + rt.Nil(set) +} + +// TestBudgetInputSetAddInput checks that `addInput` correctly updates the +// budget of the input set. +func TestBudgetInputSetAddInput(t *testing.T) { + t.Parallel() + + // Create a testing input with a budget of 100 satoshis. + input := createP2WKHInput(1000) + pi := &pendingInput{ + Input: input, + params: Params{ + Budget: 100, + }, + } + + // Initialize an input set, which adds the above input. + set, err := NewBudgetInputSet([]pendingInput{*pi}) + require.NoError(t, err) + + // Add the input to the set again. + set.addInput(*pi) + + // The set should now have two inputs. + require.Len(t, set.inputs, 2) + require.Equal(t, pi, set.inputs[0]) + require.Equal(t, pi, set.inputs[1]) + + // The set should have a budget of 200 satoshis. + require.Equal(t, btcutil.Amount(200), set.Budget()) +} + +// TestNeedWalletInput checks that NeedWalletInput correctly determines if a +// wallet input is needed. +func TestNeedWalletInput(t *testing.T) { + t.Parallel() + + // Create a mock input that doesn't have required outputs. + mockInput := &input.MockInput{} + mockInput.On("RequiredTxOut").Return(nil) + defer mockInput.AssertExpectations(t) + + // Create a mock input that has required outputs. + mockInputRequireOutput := &input.MockInput{} + mockInputRequireOutput.On("RequiredTxOut").Return(&wire.TxOut{}) + defer mockInputRequireOutput.AssertExpectations(t) + + // We now create two pending inputs each has a budget of 100 satoshis. + const budget = 100 + + // Create the pending input that doesn't have a required output. + piBudget := &pendingInput{ + Input: mockInput, + params: Params{Budget: budget}, + } + + // Create the pending input that has a required output. + piRequireOutput := &pendingInput{ + Input: mockInputRequireOutput, + params: Params{Budget: budget}, + } + + testCases := []struct { + name string + setupInputs func() []*pendingInput + need bool + }{ + { + // When there are no pending inputs, we won't need a + // wallet input. Technically this should be an invalid + // state. + name: "no inputs", + setupInputs: func() []*pendingInput { + return nil + }, + need: false, + }, + { + // When there's no required output, we don't need a + // wallet input. + name: "no required outputs", + setupInputs: func() []*pendingInput { + // Create a sign descriptor to be used in the + // pending input when calculating budgets can + // be borrowed. + sd := &input.SignDescriptor{ + Output: &wire.TxOut{ + Value: budget, + }, + } + mockInput.On("SignDesc").Return(sd).Once() + + return []*pendingInput{piBudget} + }, + need: false, + }, + { + // When the output value cannot cover the budget, we + // need a wallet input. + name: "output value cannot cover budget", + setupInputs: func() []*pendingInput { + // Create a sign descriptor to be used in the + // pending input when calculating budgets can + // be borrowed. + sd := &input.SignDescriptor{ + Output: &wire.TxOut{ + Value: budget - 1, + }, + } + mockInput.On("SignDesc").Return(sd).Once() + + // These two methods are only invoked when the + // unit test is running with a logger. + mockInput.On("OutPoint").Return( + &wire.OutPoint{Hash: chainhash.Hash{1}}, + ).Maybe() + mockInput.On("WitnessType").Return( + input.CommitmentAnchor, + ).Maybe() + + return []*pendingInput{piBudget} + }, + need: true, + }, + { + // When there's only inputs that require outputs, we + // need wallet inputs. + name: "only required outputs", + setupInputs: func() []*pendingInput { + return []*pendingInput{piRequireOutput} + }, + need: true, + }, + { + // When there's a mix of inputs, but the borrowable + // budget cannot cover the required, we need a wallet + // input. + name: "not enough budget to be borrowed", + setupInputs: func() []*pendingInput { + // Create a sign descriptor to be used in the + // pending input when calculating budgets can + // be borrowed. + // + // NOTE: the value is exactly the same as the + // budget so we can't borrow any more. + sd := &input.SignDescriptor{ + Output: &wire.TxOut{ + Value: budget, + }, + } + mockInput.On("SignDesc").Return(sd).Once() + + return []*pendingInput{ + piBudget, piRequireOutput, + } + }, + need: true, + }, + { + // When there's a mix of inputs, and the budget can be + // borrowed covers the required, we don't need wallet + // inputs. + name: "enough budget to be borrowed", + setupInputs: func() []*pendingInput { + // Create a sign descriptor to be used in the + // pending input when calculating budgets can + // be borrowed. + // + // NOTE: the value is exactly the same as the + // budget so we can't borrow any more. + sd := &input.SignDescriptor{ + Output: &wire.TxOut{ + Value: budget * 2, + }, + } + mockInput.On("SignDesc").Return(sd).Once() + piBudget.Input = mockInput + + return []*pendingInput{ + piBudget, piRequireOutput, + } + }, + need: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Setup testing inputs. + inputs := tc.setupInputs() + + // Initialize an input set, which adds the testing + // inputs. + set := &BudgetInputSet{inputs: inputs} + + result := set.NeedWalletInput() + require.Equal(t, tc.need, result) + }) + } +} + +// TestAddWalletInputReturnErr tests the three possible errors returned from +// AddWalletInputs: +// - error from ListUnspentWitnessFromDefaultAccount. +// - error from createWalletTxInput. +// - error when wallet doesn't have utxos. +func TestAddWalletInputReturnErr(t *testing.T) { + t.Parallel() + + wallet := &MockWallet{} + defer wallet.AssertExpectations(t) + + // Initialize an empty input set. + set := &BudgetInputSet{} + + // Specify the min and max confs used in + // ListUnspentWitnessFromDefaultAccount. + min, max := int32(1), int32(math.MaxInt32) + + // Mock the wallet to return an error. + dummyErr := errors.New("dummy error") + wallet.On("ListUnspentWitnessFromDefaultAccount", + min, max).Return(nil, dummyErr).Once() + + // Check that the error is returned from + // ListUnspentWitnessFromDefaultAccount. + err := set.AddWalletInputs(wallet) + require.ErrorIs(t, err, dummyErr) + + // Create an utxo with unknown address type to trigger an error. + utxo := &lnwallet.Utxo{ + AddressType: lnwallet.UnknownAddressType, + } + + // Mock the wallet to return the above utxo. + wallet.On("ListUnspentWitnessFromDefaultAccount", + min, max).Return([]*lnwallet.Utxo{utxo}, nil).Once() + + // Check that the error is returned from createWalletTxInput. + err = set.AddWalletInputs(wallet) + require.Error(t, err) + + // Mock the wallet to return empty utxos. + wallet.On("ListUnspentWitnessFromDefaultAccount", + min, max).Return([]*lnwallet.Utxo{}, nil).Once() + + // Check that the error is returned from not having wallet inputs. + err = set.AddWalletInputs(wallet) + require.ErrorIs(t, err, ErrNotEnoughInputs) +} + +// TestAddWalletInputNotEnoughInputs checks that when there are not enough +// wallet utxos, an error is returned and the budget set is reset to its +// initial state. +func TestAddWalletInputNotEnoughInputs(t *testing.T) { + t.Parallel() + + wallet := &MockWallet{} + defer wallet.AssertExpectations(t) + + // Specify the min and max confs used in + // ListUnspentWitnessFromDefaultAccount. + min, max := int32(1), int32(math.MaxInt32) + + // Assume the desired budget is 10k satoshis. + const budget = 10_000 + + // Create a mock input that has required outputs. + mockInput := &input.MockInput{} + mockInput.On("RequiredTxOut").Return(&wire.TxOut{}) + defer mockInput.AssertExpectations(t) + + // Create a pending input that requires 10k satoshis. + pi := &pendingInput{ + Input: mockInput, + params: Params{Budget: budget}, + } + + // Create a wallet utxo that cannot cover the budget. + utxo := &lnwallet.Utxo{ + AddressType: lnwallet.WitnessPubKey, + Value: budget - 1, + } + + // Mock the wallet to return the above utxo. + wallet.On("ListUnspentWitnessFromDefaultAccount", + min, max).Return([]*lnwallet.Utxo{utxo}, nil).Once() + + // Initialize an input set with the pending input. + set := BudgetInputSet{inputs: []*pendingInput{pi}} + + // Add wallet inputs to the input set, which should give us an error as + // the wallet cannot cover the budget. + err := set.AddWalletInputs(wallet) + require.ErrorIs(t, err, ErrNotEnoughInputs) + + // Check that the budget set is reverted to its initial state. + require.Len(t, set.inputs, 1) + require.Equal(t, pi, set.inputs[0]) +} + +// TestAddWalletInputSuccess checks that when there are enough wallet utxos, +// they are added to the input set. +func TestAddWalletInputSuccess(t *testing.T) { + t.Parallel() + + wallet := &MockWallet{} + defer wallet.AssertExpectations(t) + + // Specify the min and max confs used in + // ListUnspentWitnessFromDefaultAccount. + min, max := int32(1), int32(math.MaxInt32) + + // Assume the desired budget is 10k satoshis. + const budget = 10_000 + + // Create a mock input that has required outputs. + mockInput := &input.MockInput{} + mockInput.On("RequiredTxOut").Return(&wire.TxOut{}) + defer mockInput.AssertExpectations(t) + + // Create a pending input that requires 10k satoshis. + deadline := int32(1000) + pi := &pendingInput{ + Input: mockInput, + params: Params{ + Budget: budget, + DeadlineHeight: fn.Some(deadline), + }, + } + + // Mock methods used in loggings. + // + // NOTE: these methods are not functional as they are only used for + // loggings in debug or trace mode so we use arbitrary values. + mockInput.On("OutPoint").Return(&wire.OutPoint{Hash: chainhash.Hash{1}}) + mockInput.On("WitnessType").Return(input.CommitmentAnchor) + + // Create a wallet utxo that cannot cover the budget. + utxo := &lnwallet.Utxo{ + AddressType: lnwallet.WitnessPubKey, + Value: budget - 1, + } + + // Mock the wallet to return the two utxos which can cover the budget. + wallet.On("ListUnspentWitnessFromDefaultAccount", + min, max).Return([]*lnwallet.Utxo{utxo, utxo}, nil).Once() + + // Initialize an input set with the pending input. + set, err := NewBudgetInputSet([]pendingInput{*pi}) + require.NoError(t, err) + + // Add wallet inputs to the input set, which should give us an error as + // the wallet cannot cover the budget. + err = set.AddWalletInputs(wallet) + require.NoError(t, err) + + // Check that the budget set is updated. + require.Len(t, set.inputs, 3) + + // The first input is the pending input. + require.Equal(t, pi, set.inputs[0]) + + // The second and third inputs are wallet inputs that have + // DeadlineHeight set. + input2Deadline := set.inputs[1].params.DeadlineHeight + require.Equal(t, deadline, input2Deadline.UnsafeFromSome()) + input3Deadline := set.inputs[2].params.DeadlineHeight + require.Equal(t, deadline, input3Deadline.UnsafeFromSome()) + + // Finally, check the interface methods. + require.EqualValues(t, budget, set.Budget()) + require.Equal(t, deadline, set.DeadlineHeight().UnsafeFromSome()) + // Weak check, a strong check is to open the slice and check each item. + require.Len(t, set.inputs, 3) +} diff --git a/sweep/txgenerator.go b/sweep/txgenerator.go index 0cab9a6e22..57e0cce17b 100644 --- a/sweep/txgenerator.go +++ b/sweep/txgenerator.go @@ -205,15 +205,14 @@ func createSweepTx(inputs []input.Input, outputs []*wire.TxOut, } } - log.Infof("Creating sweep transaction %v for %v inputs (%s) "+ - "using %v sat/kw, tx_weight=%v, tx_fee=%v, parents_count=%v, "+ - "parents_fee=%v, parents_weight=%v", + log.Debugf("Creating sweep transaction %v for %v inputs (%s) "+ + "using %v, tx_weight=%v, tx_fee=%v, parents_count=%v, "+ + "parents_fee=%v, parents_weight=%v, current_height=%v", sweepTx.TxHash(), len(inputs), inputTypeSummary(inputs), feeRate, estimator.weight(), txFee, len(estimator.parents), estimator.parentsFee, - estimator.parentsWeight, - ) + estimator.parentsWeight, currentBlockHeight) return sweepTx, txFee, nil }