diff --git a/contractcourt/anchor_resolver.go b/contractcourt/anchor_resolver.go index d969600268..6aa25cd228 100644 --- a/contractcourt/anchor_resolver.go +++ b/contractcourt/anchor_resolver.go @@ -115,7 +115,7 @@ func (c *anchorResolver) Resolve() (ContractResolver, error) { resultChan, err := c.Sweeper.SweepInput( &anchorInput, sweep.Params{ - Fee: sweep.FeePreference{ + Fee: sweep.FeeEstimateInfo{ FeeRate: relayFeeRate, }, }, diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index e6f91d3cd0..dd29ab41f0 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -1359,7 +1359,7 @@ func (c *ChannelArbitrator) sweepAnchors(anchors *lnwallet.AnchorResolutions, _, err = c.cfg.Sweeper.SweepInput( &anchorInput, sweep.Params{ - Fee: sweep.FeePreference{ + Fee: sweep.FeeEstimateInfo{ ConfTarget: deadline, }, Force: force, diff --git a/contractcourt/commit_sweep_resolver.go b/contractcourt/commit_sweep_resolver.go index cd59f9654a..195d423c9d 100644 --- a/contractcourt/commit_sweep_resolver.go +++ b/contractcourt/commit_sweep_resolver.go @@ -351,7 +351,7 @@ func (c *commitSweepResolver) Resolve() (ContractResolver, error) { // sweeper. c.log.Infof("sweeping commit output") - feePref := sweep.FeePreference{ConfTarget: commitOutputConfTarget} + feePref := sweep.FeeEstimateInfo{ConfTarget: commitOutputConfTarget} resultChan, err := c.Sweeper.SweepInput(inp, sweep.Params{Fee: feePref}) if err != nil { c.log.Errorf("unable to sweep input: %v", err) diff --git a/contractcourt/commit_sweep_resolver_test.go b/contractcourt/commit_sweep_resolver_test.go index 0583ce8ead..e864fb6084 100644 --- a/contractcourt/commit_sweep_resolver_test.go +++ b/contractcourt/commit_sweep_resolver_test.go @@ -1,6 +1,7 @@ package contractcourt import ( + "fmt" "testing" "time" @@ -127,9 +128,15 @@ func (s *mockSweeper) SweepInput(input input.Input, params sweep.Params) ( s.sweptInputs <- input + // TODO(yy): use `mock.Mock` to avoid the conversion. + fee, ok := params.Fee.(sweep.FeeEstimateInfo) + if !ok { + return nil, fmt.Errorf("unexpected fee type: %T", params.Fee) + } + // Update the deadlines used if it's set. - if params.Fee.ConfTarget != 0 { - s.deadlines = append(s.deadlines, int(params.Fee.ConfTarget)) + if fee.ConfTarget != 0 { + s.deadlines = append(s.deadlines, int(fee.ConfTarget)) } result := make(chan sweep.Result, 1) @@ -140,8 +147,8 @@ func (s *mockSweeper) SweepInput(input input.Input, params sweep.Params) ( return result, nil } -func (s *mockSweeper) CreateSweepTx(inputs []input.Input, feePref sweep.FeePreference, - currentBlockHeight uint32) (*wire.MsgTx, error) { +func (s *mockSweeper) CreateSweepTx(inputs []input.Input, + feePref sweep.FeeEstimateInfo) (*wire.MsgTx, error) { // We will wait for the test to supply the sweep tx to return. sweepTx := <-s.createSweepTxChan diff --git a/contractcourt/htlc_success_resolver.go b/contractcourt/htlc_success_resolver.go index 545a70f9fb..d37ed012f9 100644 --- a/contractcourt/htlc_success_resolver.go +++ b/contractcourt/htlc_success_resolver.go @@ -263,7 +263,7 @@ func (h *htlcSuccessResolver) broadcastReSignedSuccessTx() ( _, err := h.Sweeper.SweepInput( &secondLevelInput, sweep.Params{ - Fee: sweep.FeePreference{ + Fee: sweep.FeeEstimateInfo{ ConfTarget: secondLevelConfTarget, }, }, @@ -375,7 +375,7 @@ func (h *htlcSuccessResolver) broadcastReSignedSuccessTx() ( _, err = h.Sweeper.SweepInput( inp, sweep.Params{ - Fee: sweep.FeePreference{ + Fee: sweep.FeeEstimateInfo{ ConfTarget: sweepConfTarget, }, }, @@ -432,17 +432,13 @@ func (h *htlcSuccessResolver) resolveRemoteCommitOutput() ( // transaction, that we'll use to move these coins back into // the backing wallet. // - // TODO: Set tx lock time to current block height instead of - // zero. Will be taken care of once sweeper implementation is - // complete. - // // TODO: Use time-based sweeper and result chan. var err error h.sweepTx, err = h.Sweeper.CreateSweepTx( []input.Input{inp}, - sweep.FeePreference{ + sweep.FeeEstimateInfo{ ConfTarget: sweepConfTarget, - }, 0, + }, ) if err != nil { return nil, err diff --git a/contractcourt/htlc_timeout_resolver.go b/contractcourt/htlc_timeout_resolver.go index 8adcb63b3b..63bfa58fc8 100644 --- a/contractcourt/htlc_timeout_resolver.go +++ b/contractcourt/htlc_timeout_resolver.go @@ -486,7 +486,7 @@ func (h *htlcTimeoutResolver) sweepSecondLevelTx() error { _, err := h.Sweeper.SweepInput( inp, sweep.Params{ - Fee: sweep.FeePreference{ + Fee: sweep.FeeEstimateInfo{ ConfTarget: secondLevelConfTarget, }, Force: true, @@ -702,7 +702,7 @@ func (h *htlcTimeoutResolver) handleCommitSpend( _, err = h.Sweeper.SweepInput( inp, sweep.Params{ - Fee: sweep.FeePreference{ + Fee: sweep.FeeEstimateInfo{ ConfTarget: sweepConfTarget, }, }, diff --git a/contractcourt/interfaces.go b/contractcourt/interfaces.go index 45cd75735a..b4fdf35291 100644 --- a/contractcourt/interfaces.go +++ b/contractcourt/interfaces.go @@ -52,8 +52,8 @@ type UtxoSweeper interface { // CreateSweepTx accepts a list of inputs and signs and generates a txn // that spends from them. This method also makes an accurate fee // estimate before generating the required witnesses. - CreateSweepTx(inputs []input.Input, feePref sweep.FeePreference, - currentBlockHeight uint32) (*wire.MsgTx, error) + CreateSweepTx(inputs []input.Input, + feePref sweep.FeeEstimateInfo) (*wire.MsgTx, error) // RelayFeePerKW returns the minimum fee rate required for transactions // to be relayed. diff --git a/contractcourt/utxonursery.go b/contractcourt/utxonursery.go index 57a3709e96..6b8742255b 100644 --- a/contractcourt/utxonursery.go +++ b/contractcourt/utxonursery.go @@ -823,7 +823,7 @@ func (u *UtxoNursery) sweepMatureOutputs(classHeight uint32, utxnLog.Infof("Sweeping %v CSV-delayed outputs with sweep tx for "+ "height %v", len(kgtnOutputs), classHeight) - feePref := sweep.FeePreference{ConfTarget: kgtnOutputConfTarget} + feePref := sweep.FeeEstimateInfo{ConfTarget: kgtnOutputConfTarget} for _, output := range kgtnOutputs { // Create local copy to prevent pointer to loop variable to be // passed in with disastrous consequences. diff --git a/docs/release-notes/release-notes-0.18.0.md b/docs/release-notes/release-notes-0.18.0.md index 86e1576f65..2b48be116e 100644 --- a/docs/release-notes/release-notes-0.18.0.md +++ b/docs/release-notes/release-notes-0.18.0.md @@ -251,6 +251,14 @@ * [Allow callers of ListSweeps to specify the start height]( https://github.com/lightningnetwork/lnd/pull/7372). +* Previously when callng `SendCoins`, `SendMany`, `OpenChannel` and + `CloseChannel` for coop close, it is allowed to specify both an empty + `SatPerVbyte` and `TargetConf`, and a default conf target of 6 will be used. + This is [no longer allowed]( + https://github.com/lightningnetwork/lnd/pull/8422) and the caller must + specify either `SatPerVbyte` or `TargetConf` so the fee estimator can do a + proper fee estimation. + ## lncli Updates * [Documented all available lncli commands](https://github.com/lightningnetwork/lnd/pull/8181). diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index 6770c949f1..1d15a39d02 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -81,6 +81,7 @@ func (m *mockPreimageCache) SubscribeUpdates( return nil, nil } +// TODO(yy): replace it with chainfee.MockEstimator. type mockFeeEstimator struct { byteFeeIn chan chainfee.SatPerKWeight relayFee chan chainfee.SatPerKWeight diff --git a/itest/lnd_channel_funding_fund_max_test.go b/itest/lnd_channel_funding_fund_max_test.go index 4a063bef13..63d4d471df 100644 --- a/itest/lnd_channel_funding_fund_max_test.go +++ b/itest/lnd_channel_funding_fund_max_test.go @@ -322,8 +322,9 @@ func sweepNodeWalletAndAssert(ht *lntest.HarnessTest, node *node.HarnessNode) { // Send all funds back to the miner node. node.RPC.SendCoins(&lnrpc.SendCoinsRequest{ - Addr: minerAddr.String(), - SendAll: true, + Addr: minerAddr.String(), + SendAll: true, + TargetConf: 6, }) // Ensures we don't leave any transaction in the mempool after sweeping. diff --git a/itest/lnd_coop_close_with_htlcs_test.go b/itest/lnd_coop_close_with_htlcs_test.go index 4c437cd32c..50f1a3401d 100644 --- a/itest/lnd_coop_close_with_htlcs_test.go +++ b/itest/lnd_coop_close_with_htlcs_test.go @@ -85,6 +85,7 @@ func coopCloseWithHTLCs(ht *lntest.HarnessTest) { closeClient := alice.RPC.CloseChannel(&lnrpc.CloseChannelRequest{ ChannelPoint: chanPoint, NoWait: true, + TargetConf: 6, }) ht.AssertChannelInactive(bob, chanPoint) @@ -184,6 +185,7 @@ func coopCloseWithHTLCsWithRestart(ht *lntest.HarnessTest) { ChannelPoint: chanPoint, NoWait: true, DeliveryAddress: newAddr.Address, + TargetConf: 6, }) // Assert that both nodes see the channel as waiting for close. diff --git a/itest/lnd_funding_test.go b/itest/lnd_funding_test.go index c519a360ba..97613429eb 100644 --- a/itest/lnd_funding_test.go +++ b/itest/lnd_funding_test.go @@ -528,8 +528,9 @@ func sendAllCoinsConfirm(ht *lntest.HarnessTest, node *node.HarnessNode, addr string) { sweepReq := &lnrpc.SendCoinsRequest{ - Addr: addr, - SendAll: true, + Addr: addr, + SendAll: true, + TargetConf: 6, } node.RPC.SendCoins(sweepReq) ht.MineBlocksAndAssertNumTxes(1, 1) diff --git a/itest/lnd_misc_test.go b/itest/lnd_misc_test.go index a88abb7b32..e0caee2860 100644 --- a/itest/lnd_misc_test.go +++ b/itest/lnd_misc_test.go @@ -777,16 +777,18 @@ func testSweepAllCoins(ht *lntest.HarnessTest) { // Ensure that we can't send coins to our own Pubkey. ainz.RPC.SendCoinsAssertErr(&lnrpc.SendCoinsRequest{ - Addr: ainz.RPC.GetInfo().IdentityPubkey, - SendAll: true, - Label: sendCoinsLabel, + Addr: ainz.RPC.GetInfo().IdentityPubkey, + SendAll: true, + Label: sendCoinsLabel, + TargetConf: 6, }) // Ensure that we can't send coins to another user's Pubkey. ainz.RPC.SendCoinsAssertErr(&lnrpc.SendCoinsRequest{ - Addr: ht.Alice.RPC.GetInfo().IdentityPubkey, - SendAll: true, - Label: sendCoinsLabel, + Addr: ht.Alice.RPC.GetInfo().IdentityPubkey, + SendAll: true, + Label: sendCoinsLabel, + TargetConf: 6, }) // With the two coins above mined, we'll now instruct Ainz to sweep all @@ -798,23 +800,34 @@ func testSweepAllCoins(ht *lntest.HarnessTest) { // Send coins to a testnet3 address. ainz.RPC.SendCoinsAssertErr(&lnrpc.SendCoinsRequest{ - Addr: "tb1qfc8fusa98jx8uvnhzavxccqlzvg749tvjw82tg", - SendAll: true, - Label: sendCoinsLabel, + Addr: "tb1qfc8fusa98jx8uvnhzavxccqlzvg749tvjw82tg", + SendAll: true, + Label: sendCoinsLabel, + TargetConf: 6, }) // Send coins to a mainnet address. ainz.RPC.SendCoinsAssertErr(&lnrpc.SendCoinsRequest{ - Addr: "1MPaXKp5HhsLNjVSqaL7fChE3TVyrTMRT3", + Addr: "1MPaXKp5HhsLNjVSqaL7fChE3TVyrTMRT3", + SendAll: true, + Label: sendCoinsLabel, + TargetConf: 6, + }) + + // Send coins to a compatible address without specifying fee rate or + // conf target. + ainz.RPC.SendCoinsAssertErr(&lnrpc.SendCoinsRequest{ + Addr: ht.Miner.NewMinerAddress().String(), SendAll: true, Label: sendCoinsLabel, }) // Send coins to a compatible address. ainz.RPC.SendCoins(&lnrpc.SendCoinsRequest{ - Addr: ht.Miner.NewMinerAddress().String(), - SendAll: true, - Label: sendCoinsLabel, + Addr: ht.Miner.NewMinerAddress().String(), + SendAll: true, + Label: sendCoinsLabel, + TargetConf: 6, }) // We'll mine a block which should include the sweep transaction we @@ -911,10 +924,11 @@ func testSweepAllCoins(ht *lntest.HarnessTest) { // If we try again, but this time specifying an amount, then the call // should fail. ainz.RPC.SendCoinsAssertErr(&lnrpc.SendCoinsRequest{ - Addr: ht.Miner.NewMinerAddress().String(), - Amount: 10000, - SendAll: true, - Label: sendCoinsLabel, + Addr: ht.Miner.NewMinerAddress().String(), + Amount: 10000, + SendAll: true, + Label: sendCoinsLabel, + TargetConf: 6, }) // With all the edge cases tested, we'll now test the happy paths of @@ -940,8 +954,9 @@ func testSweepAllCoins(ht *lntest.HarnessTest) { // Let's send some coins to the main address. const amt = 123456 resp := ainz.RPC.SendCoins(&lnrpc.SendCoinsRequest{ - Addr: mainAddrResp.Address, - Amount: amt, + Addr: mainAddrResp.Address, + Amount: amt, + TargetConf: 6, }) block := ht.MineBlocksAndAssertNumTxes(1, 1)[0] sweepTx := block.Transactions[1] @@ -1024,6 +1039,7 @@ func testListAddresses(ht *lntest.HarnessTest) { Addr: addr, Amount: addressDetail.Balance, SpendUnconfirmed: true, + TargetConf: 6, }) } diff --git a/itest/lnd_onchain_test.go b/itest/lnd_onchain_test.go index 5b8080387c..9eac66cc32 100644 --- a/itest/lnd_onchain_test.go +++ b/itest/lnd_onchain_test.go @@ -240,8 +240,9 @@ func runCPFP(ht *lntest.HarnessTest, alice, bob *node.HarnessNode) { // Send the coins from Alice to Bob. We should expect a transaction to // be broadcast and seen in the mempool. sendReq := &lnrpc.SendCoinsRequest{ - Addr: resp.Address, - Amount: btcutil.SatoshiPerBitcoin, + Addr: resp.Address, + Amount: btcutil.SatoshiPerBitcoin, + TargetConf: 6, } alice.RPC.SendCoins(sendReq) txid := ht.Miner.AssertNumTxsInMempool(1)[0] @@ -383,8 +384,9 @@ func testAnchorReservedValue(ht *lntest.HarnessTest) { resp := alice.RPC.NewAddress(req) sweepReq := &lnrpc.SendCoinsRequest{ - Addr: resp.Address, - SendAll: true, + Addr: resp.Address, + SendAll: true, + TargetConf: 6, } alice.RPC.SendCoins(sweepReq) @@ -432,8 +434,9 @@ func testAnchorReservedValue(ht *lntest.HarnessTest) { minerAddr := ht.Miner.NewMinerAddress() sweepReq = &lnrpc.SendCoinsRequest{ - Addr: minerAddr.String(), - SendAll: true, + Addr: minerAddr.String(), + SendAll: true, + TargetConf: 6, } alice.RPC.SendCoins(sweepReq) @@ -469,8 +472,9 @@ func testAnchorReservedValue(ht *lntest.HarnessTest) { // We'll wait for the balance to reflect that the channel has been // closed and the funds are in the wallet. sweepReq = &lnrpc.SendCoinsRequest{ - Addr: minerAddr.String(), - SendAll: true, + Addr: minerAddr.String(), + SendAll: true, + TargetConf: 6, } alice.RPC.SendCoins(sweepReq) @@ -602,6 +606,7 @@ func testAnchorThirdPartySpend(ht *lntest.HarnessTest) { sweepReq := &lnrpc.SendCoinsRequest{ Addr: minerAddr.String(), SendAll: true, + TargetConf: 6, MinConfs: 0, SpendUnconfirmed: true, } @@ -755,8 +760,9 @@ func testRemoveTx(ht *lntest.HarnessTest) { // We send half the amount to that address generating two unconfirmed // outpoints in our internal wallet. sendReq := &lnrpc.SendCoinsRequest{ - Addr: resp.Address, - Amount: initialWalletAmt / 2, + Addr: resp.Address, + Amount: initialWalletAmt / 2, + TargetConf: 6, } alice.RPC.SendCoins(sendReq) txID := ht.Miner.AssertNumTxsInMempool(1)[0] diff --git a/itest/lnd_psbt_test.go b/itest/lnd_psbt_test.go index d70d5a8dec..a3b5f757b9 100644 --- a/itest/lnd_psbt_test.go +++ b/itest/lnd_psbt_test.go @@ -1539,6 +1539,7 @@ func sendAllCoinsToAddrType(ht *lntest.HarnessTest, Addr: resp.Address, SendAll: true, SpendUnconfirmed: true, + TargetConf: 6, }) ht.MineBlocksAndAssertNumTxes(1, 1) diff --git a/itest/lnd_recovery_test.go b/itest/lnd_recovery_test.go index ef251f0428..c3e6efccd3 100644 --- a/itest/lnd_recovery_test.go +++ b/itest/lnd_recovery_test.go @@ -254,8 +254,9 @@ func testOnchainFundRecovery(ht *lntest.HarnessTest) { minerAddr := ht.Miner.NewMinerAddress() req := &lnrpc.SendCoinsRequest{ - Addr: minerAddr.String(), - Amount: minerAmt, + Addr: minerAddr.String(), + Amount: minerAmt, + TargetConf: 6, } resp := node.RPC.SendCoins(req) diff --git a/itest/lnd_signer_test.go b/itest/lnd_signer_test.go index 52b42a29a9..23773eb71d 100644 --- a/itest/lnd_signer_test.go +++ b/itest/lnd_signer_test.go @@ -289,8 +289,9 @@ func assertSignOutputRaw(ht *lntest.HarnessTest, // Send some coins to the generated p2wpkh address. req := &lnrpc.SendCoinsRequest{ - Addr: targetAddr.String(), - Amount: 800_000, + Addr: targetAddr.String(), + Amount: 800_000, + TargetConf: 6, } alice.RPC.SendCoins(req) diff --git a/itest/lnd_taproot_test.go b/itest/lnd_taproot_test.go index c03c1b5e02..ed37e04e8f 100644 --- a/itest/lnd_taproot_test.go +++ b/itest/lnd_taproot_test.go @@ -101,8 +101,9 @@ func testTaprootSendCoinsKeySpendBip86(ht *lntest.HarnessTest, // Send the coins from Alice's wallet to her own, but to the new p2tr // address. alice.RPC.SendCoins(&lnrpc.SendCoinsRequest{ - Addr: p2trResp.Address, - Amount: 0.5 * btcutil.SatoshiPerBitcoin, + Addr: p2trResp.Address, + Amount: 0.5 * btcutil.SatoshiPerBitcoin, + TargetConf: 6, }) txid := ht.Miner.AssertNumTxsInMempool(1)[0] @@ -125,8 +126,9 @@ func testTaprootSendCoinsKeySpendBip86(ht *lntest.HarnessTest, }) alice.RPC.SendCoins(&lnrpc.SendCoinsRequest{ - Addr: p2trResp.Address, - SendAll: true, + Addr: p2trResp.Address, + SendAll: true, + TargetConf: 6, }) // Make sure the coins sent to the address are confirmed correctly, @@ -152,8 +154,9 @@ func testTaprootComputeInputScriptKeySpendBip86(ht *lntest.HarnessTest, // Send the coins from Alice's wallet to her own, but to the new p2tr // address. req := &lnrpc.SendCoinsRequest{ - Addr: p2trAddr.String(), - Amount: testAmount, + Addr: p2trAddr.String(), + Amount: testAmount, + TargetConf: 6, } alice.RPC.SendCoins(req) @@ -1469,8 +1472,9 @@ func sendToTaprootOutput(ht *lntest.HarnessTest, hn *node.HarnessNode, // Send some coins to the generated tapscript address. req := &lnrpc.SendCoinsRequest{ - Addr: tapScriptAddr.String(), - Amount: testAmount, + Addr: tapScriptAddr.String(), + Amount: testAmount, + TargetConf: 6, } hn.RPC.SendCoins(req) diff --git a/lnrpc/rpc_utils.go b/lnrpc/rpc_utils.go index 9d9dea320f..9792cf35cf 100644 --- a/lnrpc/rpc_utils.go +++ b/lnrpc/rpc_utils.go @@ -222,12 +222,12 @@ func CalculateFeeRate(satPerByte, satPerVByte uint64, targetConf uint32, // Based on the passed fee related parameters, we'll determine an // appropriate fee rate for this transaction. - feeRate, err := sweep.DetermineFeePerKw( - estimator, sweep.FeePreference{ - ConfTarget: targetConf, - FeeRate: satPerKw, - }, - ) + feePref := sweep.FeeEstimateInfo{ + ConfTarget: targetConf, + FeeRate: satPerKw, + } + // TODO(yy): need to pass the configured max fee here. + feeRate, err := feePref.Estimate(estimator, 0) if err != nil { return feeRate, err } diff --git a/lnrpc/walletrpc/walletkit_server.go b/lnrpc/walletrpc/walletkit_server.go index 3e71070aae..e9d5ca66a9 100644 --- a/lnrpc/walletrpc/walletkit_server.go +++ b/lnrpc/walletrpc/walletkit_server.go @@ -889,7 +889,13 @@ func (w *WalletKit) PendingSweeps(ctx context.Context, broadcastAttempts := uint32(pendingInput.BroadcastAttempts) nextBroadcastHeight := uint32(pendingInput.NextBroadcastHeight) - requestedFee := pendingInput.Params.Fee + feePref := pendingInput.Params.Fee + requestedFee, ok := feePref.(sweep.FeeEstimateInfo) + if !ok { + return nil, fmt.Errorf("unknown fee preference type: "+ + "%v", feePref) + } + requestedFeeRate := uint64(requestedFee.FeeRate.FeePerVByte()) rpcPendingSweeps = append(rpcPendingSweeps, &PendingSweep{ @@ -970,7 +976,7 @@ func (w *WalletKit) BumpFee(ctx context.Context, in.SatPerByte * 1000, ).FeePerKWeight() } - feePreference := sweep.FeePreference{ + feePreference := sweep.FeeEstimateInfo{ ConfTarget: uint32(in.TargetConf), FeeRate: satPerKw, } diff --git a/lntest/harness.go b/lntest/harness.go index 1b9d18faec..6e0da417d5 100644 --- a/lntest/harness.go +++ b/lntest/harness.go @@ -13,6 +13,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/go-errors/errors" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/kvdb/etcd" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnrpc/routerrpc" @@ -922,6 +923,10 @@ type OpenChannelParams struct { // virtual byte of the transaction. SatPerVByte btcutil.Amount + // ConfTarget is the number of blocks that the funding transaction + // should be confirmed in. + ConfTarget fn.Option[int32] + // CommitmentType is the commitment type that should be used for the // channel to be opened. CommitmentType lnrpc.CommitmentType @@ -992,18 +997,27 @@ func (h *HarnessTest) prepareOpenChannel(srcNode, destNode *node.HarnessNode, minConfs = 0 } + // Get the requested conf target. If not set, default to 6. + confTarget := p.ConfTarget.UnwrapOr(6) + + // If there's fee rate set, unset the conf target. + if p.SatPerVByte != 0 { + confTarget = 0 + } + // Prepare the request. return &lnrpc.OpenChannelRequest{ NodePubkey: destNode.PubKey[:], LocalFundingAmount: int64(p.Amt), PushSat: int64(p.PushAmt), Private: p.Private, + TargetConf: confTarget, MinConfs: minConfs, SpendUnconfirmed: p.SpendUnconfirmed, MinHtlcMsat: int64(p.MinHtlc), RemoteMaxHtlcs: uint32(p.RemoteMaxHtlcs), FundingShim: p.FundingShim, - SatPerByte: int64(p.SatPerVByte), + SatPerVbyte: uint64(p.SatPerVByte), CommitmentType: p.CommitmentType, ZeroConf: p.ZeroConf, ScidAlias: p.ScidAlias, @@ -1210,6 +1224,11 @@ func (h *HarnessTest) CloseChannelAssertPending(hn *node.HarnessNode, NoWait: true, } + // For coop close, we use a default confg target of 6. + if !force { + closeReq.TargetConf = 6 + } + var ( stream rpc.CloseChanClient event *lnrpc.CloseStatusUpdate diff --git a/lnwallet/chainfee/mocks.go b/lnwallet/chainfee/mocks.go index 03d40e11e1..e14340d91f 100644 --- a/lnwallet/chainfee/mocks.go +++ b/lnwallet/chainfee/mocks.go @@ -1,6 +1,8 @@ package chainfee -import "github.com/stretchr/testify/mock" +import ( + "github.com/stretchr/testify/mock" +) type mockFeeSource struct { mock.Mock @@ -15,3 +17,50 @@ func (m *mockFeeSource) GetFeeMap() (map[uint32]uint32, error) { return args.Get(0).(map[uint32]uint32), args.Error(1) } + +// MockEstimator implements the `Estimator` interface and is used by +// other packages for mock testing. +type MockEstimator struct { + mock.Mock +} + +// Compile time assertion that MockEstimator implements Estimator. +var _ Estimator = (*MockEstimator)(nil) + +// EstimateFeePerKW takes in a target for the number of blocks until an initial +// confirmation and returns the estimated fee expressed in sat/kw. +func (m *MockEstimator) EstimateFeePerKW( + numBlocks uint32) (SatPerKWeight, error) { + + args := m.Called(numBlocks) + + if args.Get(0) == nil { + return 0, args.Error(1) + } + + return args.Get(0).(SatPerKWeight), args.Error(1) +} + +// Start signals the Estimator to start any processes or goroutines it needs to +// perform its duty. +func (m *MockEstimator) Start() error { + args := m.Called() + + return args.Error(0) +} + +// Stop stops any spawned goroutines and cleans up the resources used by the +// fee estimator. +func (m *MockEstimator) Stop() error { + args := m.Called() + + return args.Error(0) +} + +// RelayFeePerKW returns the minimum fee rate required for transactions to be +// relayed. This is also the basis for calculation of the dust limit. +func (m *MockEstimator) RelayFeePerKW() SatPerKWeight { + args := m.Called() + + return args.Get(0).(SatPerKWeight) +} diff --git a/rpcserver.go b/rpcserver.go index 8c3d6b05c5..6bb543644c 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -1174,11 +1174,13 @@ func (r *rpcServer) EstimateFee(ctx context.Context, // Query the fee estimator for the fee rate for the given confirmation // target. target := in.TargetConf - feePerKw, err := sweep.DetermineFeePerKw( - r.server.cc.FeeEstimator, sweep.FeePreference{ - ConfTarget: uint32(target), - }, - ) + feePref := sweep.FeeEstimateInfo{ + ConfTarget: uint32(target), + } + + // Since we are providing a fee estimation as an RPC response, there's + // no need to set a max feerate here, so we use 0. + feePerKw, err := feePref.Estimate(r.server.cc.FeeEstimator, 0) if err != nil { return nil, err } @@ -2098,17 +2100,22 @@ func (r *rpcServer) parseOpenChannelReq(in *lnrpc.OpenChannelRequest, return nil, fmt.Errorf("cannot open channel to self") } - // Calculate an appropriate fee rate for this transaction. - feeRate, err := lnrpc.CalculateFeeRate( - uint64(in.SatPerByte), in.SatPerVbyte, // nolint:staticcheck - uint32(in.TargetConf), r.server.cc.FeeEstimator, - ) - if err != nil { - return nil, err - } + var feeRate chainfee.SatPerKWeight - rpcsLog.Debugf("[openchannel]: using fee of %v sat/kw for funding tx", - int64(feeRate)) + // Skip estimating fee rate for PSBT funding. + if in.FundingShim == nil || in.FundingShim.GetPsbtShim() == nil { + // Calculate an appropriate fee rate for this transaction. + feeRate, err = lnrpc.CalculateFeeRate( + uint64(in.SatPerByte), in.SatPerVbyte, + uint32(in.TargetConf), r.server.cc.FeeEstimator, + ) + if err != nil { + return nil, err + } + + rpcsLog.Debugf("[openchannel]: using fee of %v sat/kw for "+ + "funding tx", int64(feeRate)) + } script, err := chancloser.ParseUpfrontShutdownAddress( in.CloseAddress, r.cfg.ActiveNetParams.Params, diff --git a/server.go b/server.go index a089a22389..1eae274a56 100644 --- a/server.go +++ b/server.go @@ -1062,9 +1062,12 @@ func newServer(cfg *Config, listenAddrs []net.Addr, return nil, err } + aggregator := sweep.NewSimpleUtxoAggregator( + cc.FeeEstimator, cfg.Sweeper.MaxFeeRate.FeePerKWeight(), + ) + s.sweeper = sweep.New(&sweep.UtxoSweeperConfig{ FeeEstimator: cc.FeeEstimator, - DetermineFeePerKw: sweep.DetermineFeePerKw, GenSweepScript: newSweepPkScriptGen(cc.Wallet), Signer: cc.Wallet.Cfg.Signer, Wallet: newSweeperWallet(cc.Wallet), @@ -1075,7 +1078,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, MaxSweepAttempts: sweep.DefaultMaxSweepAttempts, NextAttemptDeltaFunc: sweep.DefaultNextAttemptDeltaFunc, MaxFeeRate: cfg.Sweeper.MaxFeeRate, - FeeRateBucketSize: sweep.DefaultFeeRateBucketSize, + Aggregator: aggregator, }) s.utxoNursery = contractcourt.NewUtxoNursery(&contractcourt.NurseryConfig{ diff --git a/sweep/aggregator.go b/sweep/aggregator.go new file mode 100644 index 0000000000..6797e3573d --- /dev/null +++ b/sweep/aggregator.go @@ -0,0 +1,351 @@ +package sweep + +import ( + "sort" + + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/lnwallet/chainfee" +) + +const ( + // DefaultFeeRateBucketSize is the default size of fee rate buckets + // we'll use when clustering inputs into buckets with similar fee rates + // within the SimpleAggregator. + // + // Given a minimum relay fee rate of 1 sat/vbyte, a multiplier of 10 + // would result in the following fee rate buckets up to the maximum fee + // rate: + // + // #1: min = 1 sat/vbyte, max = 10 sat/vbyte + // #2: min = 11 sat/vbyte, max = 20 sat/vbyte... + DefaultFeeRateBucketSize = 10 +) + +// UtxoAggregator defines an interface that takes a list of inputs and +// aggregate them into groups. Each group is used as the inputs to create a +// sweeping transaction. +type UtxoAggregator interface { + // ClusterInputs takes a list of inputs and groups them into clusters. + ClusterInputs(pendingInputs) []inputCluster +} + +// SimpleAggregator aggregates inputs known by the Sweeper based on each +// input's locktime and feerate. +type SimpleAggregator struct { + // FeeEstimator is used when crafting sweep transactions to estimate + // the necessary fee relative to the expected size of the sweep + // transaction. + FeeEstimator chainfee.Estimator + + // MaxFeeRate is the maximum fee rate allowed within the + // SimpleAggregator. + MaxFeeRate chainfee.SatPerKWeight + + // FeeRateBucketSize is the default size of fee rate buckets we'll use + // when clustering inputs into buckets with similar fee rates within + // the SimpleAggregator. + // + // Given a minimum relay fee rate of 1 sat/vbyte, a fee rate bucket + // size of 10 would result in the following fee rate buckets up to the + // maximum fee rate: + // + // #1: min = 1 sat/vbyte, max (exclusive) = 11 sat/vbyte + // #2: min = 11 sat/vbyte, max (exclusive) = 21 sat/vbyte... + FeeRateBucketSize int +} + +// Compile-time constraint to ensure SimpleAggregator implements UtxoAggregator. +var _ UtxoAggregator = (*SimpleAggregator)(nil) + +// NewSimpleUtxoAggregator creates a new instance of a SimpleAggregator. +func NewSimpleUtxoAggregator(estimator chainfee.Estimator, + max chainfee.SatPerKWeight) *SimpleAggregator { + + return &SimpleAggregator{ + FeeEstimator: estimator, + MaxFeeRate: max, + FeeRateBucketSize: DefaultFeeRateBucketSize, + } +} + +// ClusterInputs creates a list of input clusters from the set of pending +// inputs known by the UtxoSweeper. It clusters inputs by +// 1) Required tx locktime +// 2) Similar fee rates. +// +// TODO(yy): remove this nolint once done refactoring. +// +//nolint:revive +func (s *SimpleAggregator) ClusterInputs(inputs pendingInputs) []inputCluster { + // We start by getting the inputs clusters by locktime. Since the + // inputs commit to the locktime, they can only be clustered together + // if the locktime is equal. + lockTimeClusters, nonLockTimeInputs := s.clusterByLockTime(inputs) + + // Cluster the remaining inputs by sweep fee rate. + feeClusters := s.clusterBySweepFeeRate(nonLockTimeInputs) + + // Since the inputs that we clustered by fee rate don't commit to a + // specific locktime, we can try to merge a locktime cluster with a fee + // cluster. + return zipClusters(lockTimeClusters, feeClusters) +} + +// clusterByLockTime takes the given set of pending inputs and clusters those +// with equal locktime together. Each cluster contains a sweep fee rate, which +// is determined by calculating the average fee rate of all inputs within that +// cluster. In addition to the created clusters, inputs that did not specify a +// required locktime are returned. +func (s *SimpleAggregator) clusterByLockTime( + inputs pendingInputs) ([]inputCluster, pendingInputs) { + + locktimes := make(map[uint32]pendingInputs) + rem := make(pendingInputs) + + // Go through all inputs and check if they require a certain locktime. + for op, input := range inputs { + lt, ok := input.RequiredLockTime() + if !ok { + rem[op] = input + continue + } + + // Check if we already have inputs with this locktime. + cluster, ok := locktimes[lt] + if !ok { + cluster = make(pendingInputs) + } + + // Get the fee rate based on the fee preference. If an error is + // returned, we'll skip sweeping this input for this round of + // cluster creation and retry it when we create the clusters + // from the pending inputs again. + feeRate, err := input.params.Fee.Estimate( + s.FeeEstimator, s.MaxFeeRate, + ) + if err != nil { + log.Warnf("Skipping input %v: %v", op, err) + continue + } + + log.Debugf("Adding input %v to cluster with locktime=%v, "+ + "feeRate=%v", op, lt, feeRate) + + // Attach the fee rate to the input. + input.lastFeeRate = feeRate + + // Update the cluster about the updated input. + cluster[op] = input + locktimes[lt] = cluster + } + + // We'll then determine the sweep fee rate for each set of inputs by + // calculating the average fee rate of the inputs within each set. + inputClusters := make([]inputCluster, 0, len(locktimes)) + for lt, cluster := range locktimes { + lt := lt + + var sweepFeeRate chainfee.SatPerKWeight + for _, input := range cluster { + sweepFeeRate += input.lastFeeRate + } + + sweepFeeRate /= chainfee.SatPerKWeight(len(cluster)) + inputClusters = append(inputClusters, inputCluster{ + lockTime: <, + sweepFeeRate: sweepFeeRate, + inputs: cluster, + }) + } + + return inputClusters, rem +} + +// clusterBySweepFeeRate takes the set of pending inputs within the UtxoSweeper +// and clusters those together with similar fee rates. Each cluster contains a +// sweep fee rate, which is determined by calculating the average fee rate of +// all inputs within that cluster. +func (s *SimpleAggregator) clusterBySweepFeeRate( + inputs pendingInputs) []inputCluster { + + bucketInputs := make(map[int]*bucketList) + inputFeeRates := make(map[wire.OutPoint]chainfee.SatPerKWeight) + + // First, we'll group together all inputs with similar fee rates. This + // is done by determining the fee rate bucket they should belong in. + for op, input := range inputs { + feeRate, err := input.params.Fee.Estimate( + s.FeeEstimator, s.MaxFeeRate, + ) + if err != nil { + log.Warnf("Skipping input %v: %v", op, err) + continue + } + + // Only try to sweep inputs with an unconfirmed parent if the + // current sweep fee rate exceeds the parent tx fee rate. This + // assumes that such inputs are offered to the sweeper solely + // for the purpose of anchoring down the parent tx using cpfp. + parentTx := input.UnconfParent() + if parentTx != nil { + parentFeeRate := + chainfee.SatPerKWeight(parentTx.Fee*1000) / + chainfee.SatPerKWeight(parentTx.Weight) + + if parentFeeRate >= feeRate { + log.Debugf("Skipping cpfp input %v: "+ + "fee_rate=%v, parent_fee_rate=%v", op, + feeRate, parentFeeRate) + + continue + } + } + + feeGroup := s.bucketForFeeRate(feeRate) + + // Create a bucket list for this fee rate if there isn't one + // yet. + buckets, ok := bucketInputs[feeGroup] + if !ok { + buckets = &bucketList{} + bucketInputs[feeGroup] = buckets + } + + // Request the bucket list to add this input. The bucket list + // will take into account exclusive group constraints. + buckets.add(input) + + input.lastFeeRate = feeRate + inputFeeRates[op] = feeRate + } + + // We'll then determine the sweep fee rate for each set of inputs by + // calculating the average fee rate of the inputs within each set. + inputClusters := make([]inputCluster, 0, len(bucketInputs)) + for _, buckets := range bucketInputs { + for _, inputs := range buckets.buckets { + var sweepFeeRate chainfee.SatPerKWeight + for op := range inputs { + sweepFeeRate += inputFeeRates[op] + } + sweepFeeRate /= chainfee.SatPerKWeight(len(inputs)) + inputClusters = append(inputClusters, inputCluster{ + sweepFeeRate: sweepFeeRate, + inputs: inputs, + }) + } + } + + return inputClusters +} + +// bucketForFeeReate determines the proper bucket for a fee rate. This is done +// in order to batch inputs with similar fee rates together. +func (s *SimpleAggregator) bucketForFeeRate( + feeRate chainfee.SatPerKWeight) int { + + relayFeeRate := s.FeeEstimator.RelayFeePerKW() + + // Create an isolated bucket for sweeps at the minimum fee rate. This + // is to prevent very small outputs (anchors) from becoming + // uneconomical if their fee rate would be averaged with higher fee + // rate inputs in a regular bucket. + if feeRate == relayFeeRate { + return 0 + } + + return 1 + int(feeRate-relayFeeRate)/s.FeeRateBucketSize +} + +// mergeClusters attempts to merge cluster a and b if they are compatible. The +// new cluster will have the locktime set if a or b had a locktime set, and a +// sweep fee rate that is the maximum of a and b's. If the two clusters are not +// compatible, they will be returned unchanged. +func mergeClusters(a, b inputCluster) []inputCluster { + newCluster := inputCluster{} + + switch { + // Incompatible locktimes, return the sets without merging them. + case a.lockTime != nil && b.lockTime != nil && + *a.lockTime != *b.lockTime: + + return []inputCluster{a, b} + + case a.lockTime != nil: + newCluster.lockTime = a.lockTime + + case b.lockTime != nil: + newCluster.lockTime = b.lockTime + } + + if a.sweepFeeRate > b.sweepFeeRate { + newCluster.sweepFeeRate = a.sweepFeeRate + } else { + newCluster.sweepFeeRate = b.sweepFeeRate + } + + newCluster.inputs = make(pendingInputs) + + for op, in := range a.inputs { + newCluster.inputs[op] = in + } + + for op, in := range b.inputs { + newCluster.inputs[op] = in + } + + return []inputCluster{newCluster} +} + +// zipClusters merges pairwise clusters from as and bs such that cluster a from +// as is merged with a cluster from bs that has at least the fee rate of a. +// This to ensure we don't delay confirmation by decreasing the fee rate (the +// lock time inputs are typically second level HTLC transactions, that are time +// sensitive). +func zipClusters(as, bs []inputCluster) []inputCluster { + // Sort the clusters by decreasing fee rates. + sort.Slice(as, func(i, j int) bool { + return as[i].sweepFeeRate > + as[j].sweepFeeRate + }) + sort.Slice(bs, func(i, j int) bool { + return bs[i].sweepFeeRate > + bs[j].sweepFeeRate + }) + + var ( + finalClusters []inputCluster + j int + ) + + // Go through each cluster in as, and merge with the next one from bs + // if it has at least the fee rate needed. + for i := range as { + a := as[i] + + switch { + // If the fee rate for the next one from bs is at least a's, we + // merge. + case j < len(bs) && bs[j].sweepFeeRate >= a.sweepFeeRate: + merged := mergeClusters(a, bs[j]) + finalClusters = append(finalClusters, merged...) + + // Increment j for the next round. + j++ + + // We did not merge, meaning all the remaining clusters from bs + // have lower fee rate. Instead we add a directly to the final + // clusters. + default: + finalClusters = append(finalClusters, a) + } + } + + // Add any remaining clusters from bs. + for ; j < len(bs); j++ { + b := bs[j] + finalClusters = append(finalClusters, b) + } + + return finalClusters +} diff --git a/sweep/aggregator_test.go b/sweep/aggregator_test.go new file mode 100644 index 0000000000..f3bf2cd288 --- /dev/null +++ b/sweep/aggregator_test.go @@ -0,0 +1,423 @@ +package sweep + +import ( + "errors" + "reflect" + "sort" + "testing" + + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/lnwallet/chainfee" + "github.com/stretchr/testify/require" +) + +//nolint:lll +var ( + testInputsA = pendingInputs{ + wire.OutPoint{Hash: chainhash.Hash{}, Index: 0}: &pendingInput{}, + wire.OutPoint{Hash: chainhash.Hash{}, Index: 1}: &pendingInput{}, + wire.OutPoint{Hash: chainhash.Hash{}, Index: 2}: &pendingInput{}, + } + + testInputsB = pendingInputs{ + wire.OutPoint{Hash: chainhash.Hash{}, Index: 10}: &pendingInput{}, + wire.OutPoint{Hash: chainhash.Hash{}, Index: 11}: &pendingInput{}, + wire.OutPoint{Hash: chainhash.Hash{}, Index: 12}: &pendingInput{}, + } + + testInputsC = pendingInputs{ + wire.OutPoint{Hash: chainhash.Hash{}, Index: 0}: &pendingInput{}, + wire.OutPoint{Hash: chainhash.Hash{}, Index: 1}: &pendingInput{}, + wire.OutPoint{Hash: chainhash.Hash{}, Index: 2}: &pendingInput{}, + wire.OutPoint{Hash: chainhash.Hash{}, Index: 10}: &pendingInput{}, + wire.OutPoint{Hash: chainhash.Hash{}, Index: 11}: &pendingInput{}, + wire.OutPoint{Hash: chainhash.Hash{}, Index: 12}: &pendingInput{}, + } +) + +// TestMergeClusters check that we properly can merge clusters together, +// according to their required locktime. +func TestMergeClusters(t *testing.T) { + t.Parallel() + + lockTime1 := uint32(100) + lockTime2 := uint32(200) + + testCases := []struct { + name string + a inputCluster + b inputCluster + res []inputCluster + }{ + { + name: "max fee rate", + a: inputCluster{ + sweepFeeRate: 5000, + inputs: testInputsA, + }, + b: inputCluster{ + sweepFeeRate: 7000, + inputs: testInputsB, + }, + res: []inputCluster{ + { + sweepFeeRate: 7000, + inputs: testInputsC, + }, + }, + }, + { + name: "same locktime", + a: inputCluster{ + lockTime: &lockTime1, + sweepFeeRate: 5000, + inputs: testInputsA, + }, + b: inputCluster{ + lockTime: &lockTime1, + sweepFeeRate: 7000, + inputs: testInputsB, + }, + res: []inputCluster{ + { + lockTime: &lockTime1, + sweepFeeRate: 7000, + inputs: testInputsC, + }, + }, + }, + { + name: "diff locktime", + a: inputCluster{ + lockTime: &lockTime1, + sweepFeeRate: 5000, + inputs: testInputsA, + }, + b: inputCluster{ + lockTime: &lockTime2, + sweepFeeRate: 7000, + inputs: testInputsB, + }, + res: []inputCluster{ + { + lockTime: &lockTime1, + sweepFeeRate: 5000, + inputs: testInputsA, + }, + { + lockTime: &lockTime2, + sweepFeeRate: 7000, + inputs: testInputsB, + }, + }, + }, + } + + for _, test := range testCases { + merged := mergeClusters(test.a, test.b) + if !reflect.DeepEqual(merged, test.res) { + t.Fatalf("[%s] unexpected result: %v", + test.name, spew.Sdump(merged)) + } + } +} + +// TestZipClusters tests that we can merge lists of inputs clusters correctly. +func TestZipClusters(t *testing.T) { + t.Parallel() + + createCluster := func(inp pendingInputs, + f chainfee.SatPerKWeight) inputCluster { + + return inputCluster{ + sweepFeeRate: f, + inputs: inp, + } + } + + testCases := []struct { + name string + as []inputCluster + bs []inputCluster + res []inputCluster + }{ + { + name: "merge A into B", + as: []inputCluster{ + createCluster(testInputsA, 5000), + }, + bs: []inputCluster{ + createCluster(testInputsB, 7000), + }, + res: []inputCluster{ + createCluster(testInputsC, 7000), + }, + }, + { + name: "A can't merge with B", + as: []inputCluster{ + createCluster(testInputsA, 7000), + }, + bs: []inputCluster{ + createCluster(testInputsB, 5000), + }, + res: []inputCluster{ + createCluster(testInputsA, 7000), + createCluster(testInputsB, 5000), + }, + }, + { + name: "empty bs", + as: []inputCluster{ + createCluster(testInputsA, 7000), + }, + bs: []inputCluster{}, + res: []inputCluster{ + createCluster(testInputsA, 7000), + }, + }, + { + name: "empty as", + as: []inputCluster{}, + bs: []inputCluster{ + createCluster(testInputsB, 5000), + }, + res: []inputCluster{ + createCluster(testInputsB, 5000), + }, + }, + + { + name: "zip 3xA into 3xB", + as: []inputCluster{ + createCluster(testInputsA, 5000), + createCluster(testInputsA, 5000), + createCluster(testInputsA, 5000), + }, + bs: []inputCluster{ + createCluster(testInputsB, 7000), + createCluster(testInputsB, 7000), + createCluster(testInputsB, 7000), + }, + res: []inputCluster{ + createCluster(testInputsC, 7000), + createCluster(testInputsC, 7000), + createCluster(testInputsC, 7000), + }, + }, + { + name: "zip A into 3xB", + as: []inputCluster{ + createCluster(testInputsA, 2500), + }, + bs: []inputCluster{ + createCluster(testInputsB, 3000), + createCluster(testInputsB, 2000), + createCluster(testInputsB, 1000), + }, + res: []inputCluster{ + createCluster(testInputsC, 3000), + createCluster(testInputsB, 2000), + createCluster(testInputsB, 1000), + }, + }, + } + + for _, test := range testCases { + zipped := zipClusters(test.as, test.bs) + if !reflect.DeepEqual(zipped, test.res) { + t.Fatalf("[%s] unexpected result: %v", + test.name, spew.Sdump(zipped)) + } + } +} + +// TestClusterByLockTime tests the method clusterByLockTime works as expected. +func TestClusterByLockTime(t *testing.T) { + t.Parallel() + + // Create a mock FeePreference. + mockFeePref := &MockFeePreference{} + + // Create a test param with a dummy fee preference. This is needed so + // `feeRateForPreference` won't throw an error. + param := Params{Fee: mockFeePref} + + // We begin the test by creating three clusters of inputs, the first + // cluster has a locktime of 1, the second has a locktime of 2, and the + // final has no locktime. + lockTime1 := uint32(1) + lockTime2 := uint32(2) + + // Create cluster one, which has a locktime of 1. + input1LockTime1 := &input.MockInput{} + input2LockTime1 := &input.MockInput{} + input1LockTime1.On("RequiredLockTime").Return(lockTime1, true) + input2LockTime1.On("RequiredLockTime").Return(lockTime1, true) + + // Create cluster two, which has a locktime of 2. + input3LockTime2 := &input.MockInput{} + input4LockTime2 := &input.MockInput{} + input3LockTime2.On("RequiredLockTime").Return(lockTime2, true) + input4LockTime2.On("RequiredLockTime").Return(lockTime2, true) + + // Create cluster three, which has no locktime. + input5NoLockTime := &input.MockInput{} + input6NoLockTime := &input.MockInput{} + input5NoLockTime.On("RequiredLockTime").Return(uint32(0), false) + input6NoLockTime.On("RequiredLockTime").Return(uint32(0), false) + + // With the inner Input being mocked, we can now create the pending + // inputs. + input1 := &pendingInput{Input: input1LockTime1, params: param} + input2 := &pendingInput{Input: input2LockTime1, params: param} + input3 := &pendingInput{Input: input3LockTime2, params: param} + input4 := &pendingInput{Input: input4LockTime2, params: param} + input5 := &pendingInput{Input: input5NoLockTime, params: param} + input6 := &pendingInput{Input: input6NoLockTime, params: param} + + // Create the pending inputs map, which will be passed to the method + // under test. + // + // NOTE: we don't care the actual outpoint values as long as they are + // unique. + inputs := pendingInputs{ + wire.OutPoint{Index: 1}: input1, + wire.OutPoint{Index: 2}: input2, + wire.OutPoint{Index: 3}: input3, + wire.OutPoint{Index: 4}: input4, + wire.OutPoint{Index: 5}: input5, + wire.OutPoint{Index: 6}: input6, + } + + // Create expected clusters so we can shorten the line length in the + // test cases below. + cluster1 := pendingInputs{ + wire.OutPoint{Index: 1}: input1, + wire.OutPoint{Index: 2}: input2, + } + cluster2 := pendingInputs{ + wire.OutPoint{Index: 3}: input3, + wire.OutPoint{Index: 4}: input4, + } + + // cluster3 should be the remaining inputs since they don't have + // locktime. + cluster3 := pendingInputs{ + wire.OutPoint{Index: 5}: input5, + wire.OutPoint{Index: 6}: input6, + } + + const ( + // Set the min fee rate to be 1000 sat/kw. + minFeeRate = chainfee.SatPerKWeight(1000) + + // Set the max fee rate to be 10,000 sat/kw. + maxFeeRate = chainfee.SatPerKWeight(10_000) + ) + + // Create a test aggregator. + s := NewSimpleUtxoAggregator(nil, maxFeeRate) + + testCases := []struct { + name string + // setupMocker takes a testing fee rate and makes a mocker over + // `Estimate` that always return the testing fee rate. + setupMocker func() + testFeeRate chainfee.SatPerKWeight + expectedClusters []inputCluster + expectedRemainingInputs pendingInputs + }{ + { + // Test a successful case where the locktime clusters + // are created and the no-locktime cluster is returned + // as the remaining inputs. + name: "successfully create clusters", + setupMocker: func() { + // Expect the four inputs with locktime to call + // this method. + mockFeePref.On("Estimate", nil, maxFeeRate). + Return(minFeeRate+1, nil).Times(4) + }, + // Use a fee rate above the min value so we don't hit + // an error when performing fee estimation. + // + // TODO(yy): we should customize the returned fee rate + // for each input to further test the averaging logic. + // Or we can split the method into two, one for + // grouping the clusters and the other for averaging + // the fee rates so it's easier to be tested. + testFeeRate: minFeeRate + 1, + expectedClusters: []inputCluster{ + { + lockTime: &lockTime1, + sweepFeeRate: minFeeRate + 1, + inputs: cluster1, + }, + { + lockTime: &lockTime2, + sweepFeeRate: minFeeRate + 1, + inputs: cluster2, + }, + }, + expectedRemainingInputs: cluster3, + }, + { + // Test that when the input is skipped when the fee + // estimation returns an error. + name: "error from fee estimation", + setupMocker: func() { + mockFeePref.On("Estimate", nil, maxFeeRate). + Return(chainfee.SatPerKWeight(0), + errors.New("dummy")).Times(4) + }, + + // Use a fee rate below the min value so we hit an + // error when performing fee estimation. + testFeeRate: minFeeRate - 1, + expectedClusters: []inputCluster{}, + // Remaining inputs should stay untouched. + expectedRemainingInputs: cluster3, + }, + } + + //nolint:paralleltest + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + // Apply the test fee rate so `feeRateForPreference` is + // mocked to return the specified value. + tc.setupMocker() + + // Assert the mocked methods are called as expeceted. + defer mockFeePref.AssertExpectations(t) + + // Call the method under test. + clusters, remainingInputs := s.clusterByLockTime(inputs) + + // Sort by locktime as the order is not guaranteed. + sort.Slice(clusters, func(i, j int) bool { + return *clusters[i].lockTime < + *clusters[j].lockTime + }) + + // Validate the values are returned as expected. + require.Equal(t, tc.expectedClusters, clusters) + require.Equal(t, tc.expectedRemainingInputs, + remainingInputs, + ) + + // Assert the mocked methods are called as expeceted. + input1LockTime1.AssertExpectations(t) + input2LockTime1.AssertExpectations(t) + input3LockTime2.AssertExpectations(t) + input4LockTime2.AssertExpectations(t) + input5NoLockTime.AssertExpectations(t) + input6NoLockTime.AssertExpectations(t) + }) + } +} diff --git a/sweep/fee_estimator_mock_test.go b/sweep/fee_estimator_mock_test.go index 4ca89f0c5b..ab6dcdfd50 100644 --- a/sweep/fee_estimator_mock_test.go +++ b/sweep/fee_estimator_mock_test.go @@ -9,6 +9,8 @@ import ( // mockFeeEstimator implements a mock fee estimator. It closely resembles // lnwallet.StaticFeeEstimator with the addition that fees can be changed for // testing purposes in a thread safe manner. +// +// TODO(yy): replace it with chainfee.MockEstimator once it's merged. type mockFeeEstimator struct { feePerKW chainfee.SatPerKWeight diff --git a/sweep/mocks.go b/sweep/mocks.go new file mode 100644 index 0000000000..3c88823087 --- /dev/null +++ b/sweep/mocks.go @@ -0,0 +1,44 @@ +package sweep + +import ( + "github.com/lightningnetwork/lnd/lnwallet/chainfee" + "github.com/stretchr/testify/mock" +) + +type MockFeePreference struct { + mock.Mock +} + +// Compile-time constraint to ensure MockFeePreference implements FeePreference. +var _ FeePreference = (*MockFeePreference)(nil) + +func (m *MockFeePreference) String() string { + return "mock fee preference" +} + +func (m *MockFeePreference) Estimate(estimator chainfee.Estimator, + maxFeeRate chainfee.SatPerKWeight) (chainfee.SatPerKWeight, error) { + + args := m.Called(estimator, maxFeeRate) + + if args.Get(0) == nil { + return 0, args.Error(1) + } + + return args.Get(0).(chainfee.SatPerKWeight), args.Error(1) +} + +type mockUtxoAggregator struct { + mock.Mock +} + +// Compile-time constraint to ensure mockUtxoAggregator implements +// UtxoAggregator. +var _ UtxoAggregator = (*mockUtxoAggregator)(nil) + +// ClusterInputs takes a list of inputs and groups them into clusters. +func (m *mockUtxoAggregator) ClusterInputs(pendingInputs) []inputCluster { + args := m.Called(pendingInputs{}) + + return args.Get(0).([]inputCluster) +} diff --git a/sweep/sweeper.go b/sweep/sweeper.go index f3920a4148..69f3403e23 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -20,20 +20,6 @@ import ( "github.com/lightningnetwork/lnd/lnwallet/chainfee" ) -const ( - // DefaultFeeRateBucketSize is the default size of fee rate buckets - // we'll use when clustering inputs into buckets with similar fee rates - // within the UtxoSweeper. - // - // Given a minimum relay fee rate of 1 sat/vbyte, a multiplier of 10 - // would result in the following fee rate buckets up to the maximum fee - // rate: - // - // #1: min = 1 sat/vbyte, max = 10 sat/vbyte - // #2: min = 11 sat/vbyte, max = 20 sat/vbyte... - DefaultFeeRateBucketSize = 10 -) - var ( // ErrRemoteSpend is returned in case an output that we try to sweep is // confirmed in a tx of the remote party. @@ -43,10 +29,6 @@ var ( // for the configured max number of attempts. ErrTooManyAttempts = errors.New("sweep failed after max attempts") - // ErrNoFeePreference is returned when we attempt to satisfy a sweep - // request from a client whom did not specify a fee preference. - ErrNoFeePreference = errors.New("no fee preference specified") - // ErrFeePreferenceTooLow is returned when the fee preference gives a // fee rate that's below the relay fee rate. ErrFeePreferenceTooLow = errors.New("fee preference too low") @@ -236,12 +218,11 @@ type UtxoSweeper struct { quit chan struct{} wg sync.WaitGroup -} -// feeDeterminer defines an alias to the function signature of -// `DetermineFeePerKw`. -type feeDeterminer func(chainfee.Estimator, - FeePreference) (chainfee.SatPerKWeight, error) + // currentHeight is the best known height of the main chain. This is + // updated whenever a new block epoch is received. + currentHeight int32 +} // UtxoSweeperConfig contains dependencies of UtxoSweeper. type UtxoSweeperConfig struct { @@ -249,10 +230,6 @@ type UtxoSweeperConfig struct { // funds can be swept. GenSweepScript func() ([]byte, error) - // DetermineFeePerKw determines the fee in sat/kw based on the given - // estimator and fee preference. - DetermineFeePerKw feeDeterminer - // FeeEstimator is used when crafting sweep transactions to estimate // the necessary fee relative to the expected size of the sweep // transaction. @@ -296,17 +273,9 @@ type UtxoSweeperConfig struct { // UtxoSweeper. MaxFeeRate chainfee.SatPerVByte - // FeeRateBucketSize is the default size of fee rate buckets we'll use - // when clustering inputs into buckets with similar fee rates within the - // UtxoSweeper. - // - // Given a minimum relay fee rate of 1 sat/vbyte, a fee rate bucket size - // of 10 would result in the following fee rate buckets up to the - // maximum fee rate: - // - // #1: min = 1 sat/vbyte, max (exclusive) = 11 sat/vbyte - // #2: min = 11 sat/vbyte, max (exclusive) = 21 sat/vbyte... - FeeRateBucketSize int + // Aggregator is used to group inputs into clusters based on its + // implemention-specific strategy. + Aggregator UtxoAggregator } // Result is the struct that is pushed through the result channel. Callers can @@ -442,7 +411,10 @@ func (s *UtxoSweeper) SweepInput(input input.Input, } // Ensure the client provided a sane fee preference. - if _, err := s.feeRateForPreference(params.Fee); err != nil { + _, err := params.Fee.Estimate( + s.cfg.FeeEstimator, s.cfg.MaxFeeRate.FeePerKWeight(), + ) + if err != nil { return nil, err } @@ -470,42 +442,6 @@ func (s *UtxoSweeper) SweepInput(input input.Input, return sweeperInput.resultChan, nil } -// feeRateForPreference returns a fee rate for the given fee preference. It -// ensures that the fee rate respects the bounds of the UtxoSweeper. -func (s *UtxoSweeper) feeRateForPreference( - feePreference FeePreference) (chainfee.SatPerKWeight, error) { - - // Ensure a type of fee preference is specified to prevent using a - // default below. - if feePreference.FeeRate == 0 && feePreference.ConfTarget == 0 { - return 0, ErrNoFeePreference - } - - feeRate, err := s.cfg.DetermineFeePerKw( - s.cfg.FeeEstimator, feePreference, - ) - if err != nil { - return 0, err - } - - if feeRate < s.relayFeeRate { - return 0, fmt.Errorf("%w: got %v, minimum is %v", - ErrFeePreferenceTooLow, feeRate, s.relayFeeRate) - } - - // If the estimated fee rate is above the maximum allowed fee rate, - // default to the max fee rate. - if feeRate > s.cfg.MaxFeeRate.FeePerKWeight() { - log.Warnf("Estimated fee rate %v exceeds max allowed fee "+ - "rate %v, using max fee rate instead", feeRate, - s.cfg.MaxFeeRate.FeePerKWeight()) - - return s.cfg.MaxFeeRate.FeePerKWeight(), nil - } - - return feeRate, nil -} - // removeConflictSweepDescendants removes any transactions from the wallet that // spend outputs included in the passed outpoint set. This needs to be done in // cases where we're not the only ones that can sweep an output, but there may @@ -596,10 +532,9 @@ func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) { // We registered for the block epochs with a nil request. The notifier // should send us the current best block immediately. So we need to wait // for it here because we need to know the current best height. - var bestHeight int32 select { case bestBlock := <-blockEpochs: - bestHeight = bestBlock.Height + s.currentHeight = bestBlock.Height case <-s.quit: return @@ -617,7 +552,7 @@ func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) { // we are already trying to sweep this input and if not, set up // a listener to spend and schedule a sweep. case input := <-s.newInputs: - s.handleNewInput(input, bestHeight) + s.handleNewInput(input) // A spend of one of our inputs is detected. Signal sweep // results to the caller(s). @@ -632,7 +567,7 @@ func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) { // A new external request has been received to bump the fee rate // of a given input. case req := <-s.updateReqs: - resultChan, err := s.handleUpdateReq(req, bestHeight) + resultChan, err := s.handleUpdateReq(req) req.responseChan <- &updateResp{ resultChan: resultChan, err: err, @@ -641,7 +576,7 @@ func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) { // The timer expires and we are going to (re)sweep. case <-ticker.C: log.Debugf("Sweep ticker ticks, attempt sweeping...") - s.handleSweep(bestHeight) + s.handleSweep() // A new block comes in, update the bestHeight. case epoch, ok := <-blockEpochs: @@ -649,7 +584,7 @@ func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) { return } - bestHeight = epoch.Height + s.currentHeight = epoch.Height log.Debugf("New block: height=%v, sha=%v", epoch.Height, epoch.Hash) @@ -698,15 +633,13 @@ func (s *UtxoSweeper) removeExclusiveGroup(group uint64) { } // sweepCluster tries to sweep the given input cluster. -func (s *UtxoSweeper) sweepCluster(cluster inputCluster, - currentHeight int32) error { - +func (s *UtxoSweeper) sweepCluster(cluster inputCluster) error { // Execute the sweep within a coin select lock. Otherwise the coins // that we are going to spend may be selected for other transactions // like funding of a channel. return s.cfg.Wallet.WithCoinSelectLock(func() error { // Examine pending inputs and try to construct lists of inputs. - allSets, newSets, err := s.getInputLists(cluster, currentHeight) + allSets, newSets, err := s.getInputLists(cluster) if err != nil { return fmt.Errorf("examine pending inputs: %w", err) } @@ -719,9 +652,7 @@ func (s *UtxoSweeper) sweepCluster(cluster inputCluster, // creating an RBF for the new inputs, we'd sweep this set // first. for _, inputs := range allSets { - errAllSets = s.sweep( - inputs, cluster.sweepFeeRate, currentHeight, - ) + errAllSets = s.sweep(inputs, cluster.sweepFeeRate) // TODO(yy): we should also find out which set created // this error. If there are new inputs in this set, we // should give it a second chance by sweeping them @@ -754,9 +685,7 @@ func (s *UtxoSweeper) sweepCluster(cluster inputCluster, // when sweeping a given set, we'd log the error and sweep the // next set. for _, inputs := range newSets { - err := s.sweep( - inputs, cluster.sweepFeeRate, currentHeight, - ) + err := s.sweep(inputs, cluster.sweepFeeRate) if err != nil { log.Errorf("sweep new inputs: %w", err) } @@ -766,276 +695,6 @@ func (s *UtxoSweeper) sweepCluster(cluster inputCluster, }) } -// bucketForFeeReate determines the proper bucket for a fee rate. This is done -// in order to batch inputs with similar fee rates together. -func (s *UtxoSweeper) bucketForFeeRate( - feeRate chainfee.SatPerKWeight) int { - - // Create an isolated bucket for sweeps at the minimum fee rate. This is - // to prevent very small outputs (anchors) from becoming uneconomical if - // their fee rate would be averaged with higher fee rate inputs in a - // regular bucket. - if feeRate == s.relayFeeRate { - return 0 - } - - return 1 + int(feeRate-s.relayFeeRate)/s.cfg.FeeRateBucketSize -} - -// createInputClusters creates a list of input clusters from the set of pending -// inputs known by the UtxoSweeper. It clusters inputs by -// 1) Required tx locktime -// 2) Similar fee rates. -func (s *UtxoSweeper) createInputClusters() []inputCluster { - inputs := s.pendingInputs - - // We start by getting the inputs clusters by locktime. Since the - // inputs commit to the locktime, they can only be clustered together - // if the locktime is equal. - lockTimeClusters, nonLockTimeInputs := s.clusterByLockTime(inputs) - - // Cluster the remaining inputs by sweep fee rate. - feeClusters := s.clusterBySweepFeeRate(nonLockTimeInputs) - - // Since the inputs that we clustered by fee rate don't commit to a - // specific locktime, we can try to merge a locktime cluster with a fee - // cluster. - return zipClusters(lockTimeClusters, feeClusters) -} - -// clusterByLockTime takes the given set of pending inputs and clusters those -// with equal locktime together. Each cluster contains a sweep fee rate, which -// is determined by calculating the average fee rate of all inputs within that -// cluster. In addition to the created clusters, inputs that did not specify a -// required lock time are returned. -func (s *UtxoSweeper) clusterByLockTime(inputs pendingInputs) ([]inputCluster, - pendingInputs) { - - locktimes := make(map[uint32]pendingInputs) - rem := make(pendingInputs) - - // Go through all inputs and check if they require a certain locktime. - for op, input := range inputs { - lt, ok := input.RequiredLockTime() - if !ok { - rem[op] = input - continue - } - - // Check if we already have inputs with this locktime. - cluster, ok := locktimes[lt] - if !ok { - cluster = make(pendingInputs) - } - - // Get the fee rate based on the fee preference. If an error is - // returned, we'll skip sweeping this input for this round of - // cluster creation and retry it when we create the clusters - // from the pending inputs again. - feeRate, err := s.feeRateForPreference(input.params.Fee) - if err != nil { - log.Warnf("Skipping input %v: %v", op, err) - continue - } - - log.Debugf("Adding input %v to cluster with locktime=%v, "+ - "feeRate=%v", op, lt, feeRate) - - // Attach the fee rate to the input. - input.lastFeeRate = feeRate - - // Update the cluster about the updated input. - cluster[op] = input - locktimes[lt] = cluster - } - - // We'll then determine the sweep fee rate for each set of inputs by - // calculating the average fee rate of the inputs within each set. - inputClusters := make([]inputCluster, 0, len(locktimes)) - for lt, cluster := range locktimes { - lt := lt - - var sweepFeeRate chainfee.SatPerKWeight - for _, input := range cluster { - sweepFeeRate += input.lastFeeRate - } - - sweepFeeRate /= chainfee.SatPerKWeight(len(cluster)) - inputClusters = append(inputClusters, inputCluster{ - lockTime: <, - sweepFeeRate: sweepFeeRate, - inputs: cluster, - }) - } - - return inputClusters, rem -} - -// clusterBySweepFeeRate takes the set of pending inputs within the UtxoSweeper -// and clusters those together with similar fee rates. Each cluster contains a -// sweep fee rate, which is determined by calculating the average fee rate of -// all inputs within that cluster. -func (s *UtxoSweeper) clusterBySweepFeeRate(inputs pendingInputs) []inputCluster { - bucketInputs := make(map[int]*bucketList) - inputFeeRates := make(map[wire.OutPoint]chainfee.SatPerKWeight) - - // First, we'll group together all inputs with similar fee rates. This - // is done by determining the fee rate bucket they should belong in. - for op, input := range inputs { - feeRate, err := s.feeRateForPreference(input.params.Fee) - if err != nil { - log.Warnf("Skipping input %v: %v", op, err) - continue - } - - // Only try to sweep inputs with an unconfirmed parent if the - // current sweep fee rate exceeds the parent tx fee rate. This - // assumes that such inputs are offered to the sweeper solely - // for the purpose of anchoring down the parent tx using cpfp. - parentTx := input.UnconfParent() - if parentTx != nil { - parentFeeRate := - chainfee.SatPerKWeight(parentTx.Fee*1000) / - chainfee.SatPerKWeight(parentTx.Weight) - - if parentFeeRate >= feeRate { - log.Debugf("Skipping cpfp input %v: fee_rate=%v, "+ - "parent_fee_rate=%v", op, feeRate, - parentFeeRate) - - continue - } - } - - feeGroup := s.bucketForFeeRate(feeRate) - - // Create a bucket list for this fee rate if there isn't one - // yet. - buckets, ok := bucketInputs[feeGroup] - if !ok { - buckets = &bucketList{} - bucketInputs[feeGroup] = buckets - } - - // Request the bucket list to add this input. The bucket list - // will take into account exclusive group constraints. - buckets.add(input) - - input.lastFeeRate = feeRate - inputFeeRates[op] = feeRate - } - - // We'll then determine the sweep fee rate for each set of inputs by - // calculating the average fee rate of the inputs within each set. - inputClusters := make([]inputCluster, 0, len(bucketInputs)) - for _, buckets := range bucketInputs { - for _, inputs := range buckets.buckets { - var sweepFeeRate chainfee.SatPerKWeight - for op := range inputs { - sweepFeeRate += inputFeeRates[op] - } - sweepFeeRate /= chainfee.SatPerKWeight(len(inputs)) - inputClusters = append(inputClusters, inputCluster{ - sweepFeeRate: sweepFeeRate, - inputs: inputs, - }) - } - } - - return inputClusters -} - -// zipClusters merges pairwise clusters from as and bs such that cluster a from -// as is merged with a cluster from bs that has at least the fee rate of a. -// This to ensure we don't delay confirmation by decreasing the fee rate (the -// lock time inputs are typically second level HTLC transactions, that are time -// sensitive). -func zipClusters(as, bs []inputCluster) []inputCluster { - // Sort the clusters by decreasing fee rates. - sort.Slice(as, func(i, j int) bool { - return as[i].sweepFeeRate > - as[j].sweepFeeRate - }) - sort.Slice(bs, func(i, j int) bool { - return bs[i].sweepFeeRate > - bs[j].sweepFeeRate - }) - - var ( - finalClusters []inputCluster - j int - ) - - // Go through each cluster in as, and merge with the next one from bs - // if it has at least the fee rate needed. - for i := range as { - a := as[i] - - switch { - // If the fee rate for the next one from bs is at least a's, we - // merge. - case j < len(bs) && bs[j].sweepFeeRate >= a.sweepFeeRate: - merged := mergeClusters(a, bs[j]) - finalClusters = append(finalClusters, merged...) - - // Increment j for the next round. - j++ - - // We did not merge, meaning all the remaining clusters from bs - // have lower fee rate. Instead we add a directly to the final - // clusters. - default: - finalClusters = append(finalClusters, a) - } - } - - // Add any remaining clusters from bs. - for ; j < len(bs); j++ { - b := bs[j] - finalClusters = append(finalClusters, b) - } - - return finalClusters -} - -// mergeClusters attempts to merge cluster a and b if they are compatible. The -// new cluster will have the locktime set if a or b had a locktime set, and a -// sweep fee rate that is the maximum of a and b's. If the two clusters are not -// compatible, they will be returned unchanged. -func mergeClusters(a, b inputCluster) []inputCluster { - newCluster := inputCluster{} - - switch { - // Incompatible locktimes, return the sets without merging them. - case a.lockTime != nil && b.lockTime != nil && *a.lockTime != *b.lockTime: - return []inputCluster{a, b} - - case a.lockTime != nil: - newCluster.lockTime = a.lockTime - - case b.lockTime != nil: - newCluster.lockTime = b.lockTime - } - - if a.sweepFeeRate > b.sweepFeeRate { - newCluster.sweepFeeRate = a.sweepFeeRate - } else { - newCluster.sweepFeeRate = b.sweepFeeRate - } - - newCluster.inputs = make(pendingInputs) - - for op, in := range a.inputs { - newCluster.inputs[op] = in - } - - for op, in := range b.inputs { - newCluster.inputs[op] = in - } - - return []inputCluster{newCluster} -} - // signalAndRemove notifies the listeners of the final result of the input // sweep. It cancels any pending spend notification and removes the input from // the list of pending inputs. When this function returns, the sweeper has @@ -1079,8 +738,8 @@ func (s *UtxoSweeper) signalAndRemove(outpoint *wire.OutPoint, result Result) { // and will be bundled with future inputs if possible. It returns two list - // one containing all inputs and the other containing only the new inputs. If // there's no retried inputs, the first set returned will be empty. -func (s *UtxoSweeper) getInputLists(cluster inputCluster, - currentHeight int32) ([]inputSet, []inputSet, error) { +func (s *UtxoSweeper) getInputLists( + cluster inputCluster) ([]inputSet, []inputSet, error) { // Filter for inputs that need to be swept. Create two lists: all // sweepable inputs and a list containing only the new, never tried @@ -1102,7 +761,7 @@ func (s *UtxoSweeper) getInputLists(cluster inputCluster, for _, input := range cluster.inputs { // Skip inputs that have a minimum publish height that is not // yet reached. - if input.minPublishHeight > currentHeight { + if input.minPublishHeight > s.currentHeight { continue } @@ -1143,15 +802,15 @@ func (s *UtxoSweeper) getInputLists(cluster inputCluster, } log.Debugf("Sweep candidates at height=%v: total_num_pending=%v, "+ - "total_num_new=%v", currentHeight, len(allSets), len(newSets)) + "total_num_new=%v", s.currentHeight, len(allSets), len(newSets)) return allSets, newSets, nil } // sweep takes a set of preselected inputs, creates a sweep tx and publishes the // tx. The output address is only marked as used if the publish succeeds. -func (s *UtxoSweeper) sweep(inputs inputSet, feeRate chainfee.SatPerKWeight, - currentHeight int32) error { +func (s *UtxoSweeper) sweep(inputs inputSet, + feeRate chainfee.SatPerKWeight) error { // Generate an output script if there isn't an unused script available. if s.currentOutputScript == nil { @@ -1164,7 +823,7 @@ func (s *UtxoSweeper) sweep(inputs inputSet, feeRate chainfee.SatPerKWeight, // Create sweep tx. tx, fee, err := createSweepTx( - inputs, nil, s.currentOutputScript, uint32(currentHeight), + inputs, nil, s.currentOutputScript, uint32(s.currentHeight), feeRate, s.cfg.MaxFeeRate.FeePerKWeight(), s.cfg.Signer, ) if err != nil { @@ -1190,10 +849,10 @@ func (s *UtxoSweeper) sweep(inputs inputSet, feeRate chainfee.SatPerKWeight, // Reschedule the inputs that we just tried to sweep. This is done in // case the following publish fails, we'd like to update the inputs' // publish attempts and rescue them in the next sweep. - s.rescheduleInputs(tx.TxIn, currentHeight) + s.rescheduleInputs(tx.TxIn) log.Debugf("Publishing sweep tx %v, num_inputs=%v, height=%v", - tx.TxHash(), len(tx.TxIn), currentHeight) + tx.TxHash(), len(tx.TxIn), s.currentHeight) // Publish the sweeping tx with customized label. err = s.cfg.Wallet.PublishTransaction( @@ -1225,9 +884,7 @@ func (s *UtxoSweeper) sweep(inputs inputSet, feeRate chainfee.SatPerKWeight, // increments the `publishAttempts` and calculates the next broadcast height // for each input. When the publishAttempts exceeds MaxSweepAttemps(10), this // input will be removed. -func (s *UtxoSweeper) rescheduleInputs(inputs []*wire.TxIn, - currentHeight int32) { - +func (s *UtxoSweeper) rescheduleInputs(inputs []*wire.TxIn) { // Reschedule sweep. for _, input := range inputs { pi, ok := s.pendingInputs[input.PreviousOutPoint] @@ -1251,7 +908,7 @@ func (s *UtxoSweeper) rescheduleInputs(inputs []*wire.TxIn, pi.publishAttempts, ) - pi.minPublishHeight = currentHeight + nextAttemptDelta + pi.minPublishHeight = s.currentHeight + nextAttemptDelta log.Debugf("Rescheduling input %v after %v attempts at "+ "height %v (delta %v)", input.PreviousOutPoint, @@ -1378,7 +1035,10 @@ func (s *UtxoSweeper) UpdateParams(input wire.OutPoint, params ParamsUpdate) (chan Result, error) { // Ensure the client provided a sane fee preference. - if _, err := s.feeRateForPreference(params.Fee); err != nil { + _, err := params.Fee.Estimate( + s.cfg.FeeEstimator, s.cfg.MaxFeeRate.FeePerKWeight(), + ) + if err != nil { return nil, err } @@ -1412,7 +1072,7 @@ func (s *UtxoSweeper) UpdateParams(input wire.OutPoint, // - Ensure we don't combine this input with any other unconfirmed inputs that // did not exist in the original sweep transaction, resulting in an invalid // replacement transaction. -func (s *UtxoSweeper) handleUpdateReq(req *updateReq, bestHeight int32) ( +func (s *UtxoSweeper) handleUpdateReq(req *updateReq) ( chan Result, error) { // If the UtxoSweeper is already trying to sweep this input, then we can @@ -1445,7 +1105,7 @@ func (s *UtxoSweeper) handleUpdateReq(req *updateReq, bestHeight int32) ( // NOTE: The UtxoSweeper is not yet offered time-locked inputs, so the // check for broadcast attempts is redundant at the moment. if pendingInput.publishAttempts > 0 { - pendingInput.minPublishHeight = bestHeight + pendingInput.minPublishHeight = s.currentHeight } resultChan := make(chan Result, 1) @@ -1469,10 +1129,12 @@ func (s *UtxoSweeper) handleUpdateReq(req *updateReq, bestHeight int32) ( // - Make handling re-orgs easier. // - Thwart future possible fee sniping attempts. // - Make us blend in with the bitcoind wallet. -func (s *UtxoSweeper) CreateSweepTx(inputs []input.Input, feePref FeePreference, - currentBlockHeight uint32) (*wire.MsgTx, error) { +func (s *UtxoSweeper) CreateSweepTx(inputs []input.Input, + feePref FeeEstimateInfo) (*wire.MsgTx, error) { - feePerKw, err := s.cfg.DetermineFeePerKw(s.cfg.FeeEstimator, feePref) + feePerKw, err := feePref.Estimate( + s.cfg.FeeEstimator, s.cfg.MaxFeeRate.FeePerKWeight(), + ) if err != nil { return nil, err } @@ -1484,7 +1146,7 @@ func (s *UtxoSweeper) CreateSweepTx(inputs []input.Input, feePref FeePreference, } tx, _, err := createSweepTx( - inputs, nil, pkScript, currentBlockHeight, feePerKw, + inputs, nil, pkScript, uint32(s.currentHeight), feePerKw, s.cfg.MaxFeeRate.FeePerKWeight(), s.cfg.Signer, ) @@ -1506,9 +1168,7 @@ func (s *UtxoSweeper) ListSweeps() ([]chainhash.Hash, error) { // handleNewInput processes a new input by registering spend notification and // scheduling sweeping for it. -func (s *UtxoSweeper) handleNewInput(input *sweepInputMessage, - bestHeight int32) { - +func (s *UtxoSweeper) handleNewInput(input *sweepInputMessage) { outpoint := *input.input.OutPoint() pendInput, pending := s.pendingInputs[outpoint] if pending { @@ -1525,7 +1185,7 @@ func (s *UtxoSweeper) handleNewInput(input *sweepInputMessage, pendInput = &pendingInput{ listeners: []chan Result{input.resultChan}, Input: input.input, - minPublishHeight: bestHeight, + minPublishHeight: s.currentHeight, params: input.params, } s.pendingInputs[outpoint] = pendInput @@ -1668,19 +1328,19 @@ func (s *UtxoSweeper) handleInputSpent(spend *chainntnfs.SpendDetail) { // handleSweep is called when the ticker fires. It will create clusters and // attempt to create and publish the sweeping transactions. -func (s *UtxoSweeper) handleSweep(bestHeight int32) { +func (s *UtxoSweeper) handleSweep() { // We'll attempt to cluster all of our inputs with similar fee rates. // Before attempting to sweep them, we'll sort them in descending fee // rate order. We do this to ensure any inputs which have had their fee // rate bumped are broadcast first in order enforce the RBF policy. - inputClusters := s.createInputClusters() + inputClusters := s.cfg.Aggregator.ClusterInputs(s.pendingInputs) sort.Slice(inputClusters, func(i, j int) bool { return inputClusters[i].sweepFeeRate > inputClusters[j].sweepFeeRate }) for _, cluster := range inputClusters { - err := s.sweepCluster(cluster, bestHeight) + err := s.sweepCluster(cluster) if err != nil { log.Errorf("input cluster sweep: %v", err) } diff --git a/sweep/sweeper_test.go b/sweep/sweeper_test.go index 2003254cd6..c12b04aae5 100644 --- a/sweep/sweeper_test.go +++ b/sweep/sweeper_test.go @@ -1,11 +1,8 @@ package sweep import ( - "errors" "os" - "reflect" "runtime/pprof" - "sort" "testing" "time" @@ -14,7 +11,6 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" - "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/build" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/input" @@ -32,7 +28,7 @@ var ( testMaxInputsPerTx = 3 - defaultFeePref = Params{Fee: FeePreference{ConfTarget: 1}} + defaultFeePref = Params{Fee: FeeEstimateInfo{ConfTarget: 1}} ) type sweeperTestContext struct { @@ -121,6 +117,10 @@ func createSweeperTestContext(t *testing.T) *sweeperTestContext { estimator := newMockFeeEstimator(10000, chainfee.FeePerKwFloor) + aggregator := NewSimpleUtxoAggregator( + estimator, DefaultMaxFeeRate.FeePerKWeight(), + ) + ctx := &sweeperTestContext{ notifier: notifier, publishChan: backend.publishChan, @@ -149,9 +149,8 @@ func createSweeperTestContext(t *testing.T) *sweeperTestContext { // Use delta func without random factor. return 1 << uint(attempts-1) }, - MaxFeeRate: DefaultMaxFeeRate, - FeeRateBucketSize: DefaultFeeRateBucketSize, - DetermineFeePerKw: DetermineFeePerKw, + MaxFeeRate: DefaultMaxFeeRate, + Aggregator: aggregator, }) ctx.sweeper.Start() @@ -337,7 +336,9 @@ func TestSuccess(t *testing.T) { ctx := createSweeperTestContext(t) // Sweeping an input without a fee preference should result in an error. - _, err := ctx.sweeper.SweepInput(spendableInputs[0], Params{}) + _, err := ctx.sweeper.SweepInput(spendableInputs[0], Params{ + Fee: &FeeEstimateInfo{}, + }) if err != ErrNoFeePreference { t.Fatalf("expected ErrNoFeePreference, got %v", err) } @@ -383,9 +384,7 @@ func TestDust(t *testing.T) { dustInput := createTestInput(5260, input.CommitmentTimeLock) _, err := ctx.sweeper.SweepInput(&dustInput, defaultFeePref) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // No sweep transaction is expected now. The sweeper should recognize // that the sweep output will not be relayed and not generate the tx. It @@ -397,18 +396,13 @@ func TestDust(t *testing.T) { largeInput := createTestInput(100000, input.CommitmentTimeLock) _, err = ctx.sweeper.SweepInput(&largeInput, defaultFeePref) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // The second input brings the sweep output above the dust limit. We // expect a sweep tx now. sweepTx := ctx.receiveTx() - if len(sweepTx.TxIn) != 2 { - t.Fatalf("Expected tx to sweep 2 inputs, but contains %v "+ - "inputs instead", len(sweepTx.TxIn)) - } + require.Len(t, sweepTx.TxIn, 2, "unexpected num of tx inputs") ctx.backend.mine() @@ -434,7 +428,7 @@ func TestWalletUtxo(t *testing.T) { _, err := ctx.sweeper.SweepInput( &dustInput, - Params{Fee: FeePreference{FeeRate: chainfee.FeePerKwFloor}}, + Params{Fee: FeeEstimateInfo{FeeRate: chainfee.FeePerKwFloor}}, ) if err != nil { t.Fatal(err) @@ -928,11 +922,11 @@ func TestDifferentFeePreferences(t *testing.T) { // with the higher fee preference, and the last with the lower. We do // this to ensure the sweeper can broadcast distinct transactions for // each sweep with a different fee preference. - lowFeePref := FeePreference{ConfTarget: 12} + lowFeePref := FeeEstimateInfo{ConfTarget: 12} lowFeeRate := chainfee.SatPerKWeight(5000) ctx.estimator.blocksToFee[lowFeePref.ConfTarget] = lowFeeRate - highFeePref := FeePreference{ConfTarget: 6} + highFeePref := FeeEstimateInfo{ConfTarget: 6} highFeeRate := chainfee.SatPerKWeight(10000) ctx.estimator.blocksToFee[highFeePref.ConfTarget] = highFeeRate @@ -997,12 +991,12 @@ func TestPendingInputs(t *testing.T) { highFeeRate = 10000 ) - lowFeePref := FeePreference{ + lowFeePref := FeeEstimateInfo{ ConfTarget: 12, } ctx.estimator.blocksToFee[lowFeePref.ConfTarget] = lowFeeRate - highFeePref := FeePreference{ + highFeePref := FeeEstimateInfo{ ConfTarget: 6, } ctx.estimator.blocksToFee[highFeePref.ConfTarget] = highFeeRate @@ -1062,7 +1056,7 @@ func TestPendingInputs(t *testing.T) { func TestBumpFeeRBF(t *testing.T) { ctx := createSweeperTestContext(t) - lowFeePref := FeePreference{ConfTarget: 144} + lowFeePref := FeeEstimateInfo{ConfTarget: 144} lowFeeRate := chainfee.FeePerKwFloor ctx.estimator.blocksToFee[lowFeePref.ConfTarget] = lowFeeRate @@ -1097,12 +1091,14 @@ func TestBumpFeeRBF(t *testing.T) { assertTxFeeRate(t, &lowFeeTx, lowFeeRate, changePk, &input) // We'll then attempt to bump its fee rate. - highFeePref := FeePreference{ConfTarget: 6} + highFeePref := FeeEstimateInfo{ConfTarget: 6} highFeeRate := DefaultMaxFeeRate.FeePerKWeight() ctx.estimator.blocksToFee[highFeePref.ConfTarget] = highFeeRate // We should expect to see an error if a fee preference isn't provided. - _, err = ctx.sweeper.UpdateParams(*input.OutPoint(), ParamsUpdate{}) + _, err = ctx.sweeper.UpdateParams(*input.OutPoint(), ParamsUpdate{ + Fee: &FeeEstimateInfo{}, + }) if err != ErrNoFeePreference { t.Fatalf("expected ErrNoFeePreference, got %v", err) } @@ -1134,7 +1130,7 @@ func TestExclusiveGroup(t *testing.T) { exclusiveGroup := uint64(1) result, err := ctx.sweeper.SweepInput( spendableInputs[i], Params{ - Fee: FeePreference{ConfTarget: 6}, + Fee: FeeEstimateInfo{ConfTarget: 6}, ExclusiveGroup: &exclusiveGroup, }, ) @@ -1211,7 +1207,7 @@ func TestCpfp(t *testing.T) { }, ) - feePref := FeePreference{ConfTarget: 6} + feePref := FeeEstimateInfo{ConfTarget: 6} result, err := ctx.sweeper.SweepInput( &input, Params{Fee: feePref, Force: true}, ) @@ -1246,224 +1242,6 @@ func TestCpfp(t *testing.T) { ctx.finish(1) } -var ( - testInputsA = pendingInputs{ - wire.OutPoint{Hash: chainhash.Hash{}, Index: 0}: &pendingInput{}, - wire.OutPoint{Hash: chainhash.Hash{}, Index: 1}: &pendingInput{}, - wire.OutPoint{Hash: chainhash.Hash{}, Index: 2}: &pendingInput{}, - } - - testInputsB = pendingInputs{ - wire.OutPoint{Hash: chainhash.Hash{}, Index: 10}: &pendingInput{}, - wire.OutPoint{Hash: chainhash.Hash{}, Index: 11}: &pendingInput{}, - wire.OutPoint{Hash: chainhash.Hash{}, Index: 12}: &pendingInput{}, - } - - testInputsC = pendingInputs{ - wire.OutPoint{Hash: chainhash.Hash{}, Index: 0}: &pendingInput{}, - wire.OutPoint{Hash: chainhash.Hash{}, Index: 1}: &pendingInput{}, - wire.OutPoint{Hash: chainhash.Hash{}, Index: 2}: &pendingInput{}, - wire.OutPoint{Hash: chainhash.Hash{}, Index: 10}: &pendingInput{}, - wire.OutPoint{Hash: chainhash.Hash{}, Index: 11}: &pendingInput{}, - wire.OutPoint{Hash: chainhash.Hash{}, Index: 12}: &pendingInput{}, - } -) - -// TestMergeClusters check that we properly can merge clusters together, -// according to their required locktime. -func TestMergeClusters(t *testing.T) { - t.Parallel() - - lockTime1 := uint32(100) - lockTime2 := uint32(200) - - testCases := []struct { - name string - a inputCluster - b inputCluster - res []inputCluster - }{ - { - name: "max fee rate", - a: inputCluster{ - sweepFeeRate: 5000, - inputs: testInputsA, - }, - b: inputCluster{ - sweepFeeRate: 7000, - inputs: testInputsB, - }, - res: []inputCluster{ - { - sweepFeeRate: 7000, - inputs: testInputsC, - }, - }, - }, - { - name: "same locktime", - a: inputCluster{ - lockTime: &lockTime1, - sweepFeeRate: 5000, - inputs: testInputsA, - }, - b: inputCluster{ - lockTime: &lockTime1, - sweepFeeRate: 7000, - inputs: testInputsB, - }, - res: []inputCluster{ - { - lockTime: &lockTime1, - sweepFeeRate: 7000, - inputs: testInputsC, - }, - }, - }, - { - name: "diff locktime", - a: inputCluster{ - lockTime: &lockTime1, - sweepFeeRate: 5000, - inputs: testInputsA, - }, - b: inputCluster{ - lockTime: &lockTime2, - sweepFeeRate: 7000, - inputs: testInputsB, - }, - res: []inputCluster{ - { - lockTime: &lockTime1, - sweepFeeRate: 5000, - inputs: testInputsA, - }, - { - lockTime: &lockTime2, - sweepFeeRate: 7000, - inputs: testInputsB, - }, - }, - }, - } - - for _, test := range testCases { - merged := mergeClusters(test.a, test.b) - if !reflect.DeepEqual(merged, test.res) { - t.Fatalf("[%s] unexpected result: %v", - test.name, spew.Sdump(merged)) - } - } -} - -// TestZipClusters tests that we can merge lists of inputs clusters correctly. -func TestZipClusters(t *testing.T) { - t.Parallel() - - createCluster := func(inp pendingInputs, f chainfee.SatPerKWeight) inputCluster { - return inputCluster{ - sweepFeeRate: f, - inputs: inp, - } - } - - testCases := []struct { - name string - as []inputCluster - bs []inputCluster - res []inputCluster - }{ - { - name: "merge A into B", - as: []inputCluster{ - createCluster(testInputsA, 5000), - }, - bs: []inputCluster{ - createCluster(testInputsB, 7000), - }, - res: []inputCluster{ - createCluster(testInputsC, 7000), - }, - }, - { - name: "A can't merge with B", - as: []inputCluster{ - createCluster(testInputsA, 7000), - }, - bs: []inputCluster{ - createCluster(testInputsB, 5000), - }, - res: []inputCluster{ - createCluster(testInputsA, 7000), - createCluster(testInputsB, 5000), - }, - }, - { - name: "empty bs", - as: []inputCluster{ - createCluster(testInputsA, 7000), - }, - bs: []inputCluster{}, - res: []inputCluster{ - createCluster(testInputsA, 7000), - }, - }, - { - name: "empty as", - as: []inputCluster{}, - bs: []inputCluster{ - createCluster(testInputsB, 5000), - }, - res: []inputCluster{ - createCluster(testInputsB, 5000), - }, - }, - - { - name: "zip 3xA into 3xB", - as: []inputCluster{ - createCluster(testInputsA, 5000), - createCluster(testInputsA, 5000), - createCluster(testInputsA, 5000), - }, - bs: []inputCluster{ - createCluster(testInputsB, 7000), - createCluster(testInputsB, 7000), - createCluster(testInputsB, 7000), - }, - res: []inputCluster{ - createCluster(testInputsC, 7000), - createCluster(testInputsC, 7000), - createCluster(testInputsC, 7000), - }, - }, - { - name: "zip A into 3xB", - as: []inputCluster{ - createCluster(testInputsA, 2500), - }, - bs: []inputCluster{ - createCluster(testInputsB, 3000), - createCluster(testInputsB, 2000), - createCluster(testInputsB, 1000), - }, - res: []inputCluster{ - createCluster(testInputsC, 3000), - createCluster(testInputsB, 2000), - createCluster(testInputsB, 1000), - }, - }, - } - - for _, test := range testCases { - zipped := zipClusters(test.as, test.bs) - if !reflect.DeepEqual(zipped, test.res) { - t.Fatalf("[%s] unexpected result: %v", - test.name, spew.Sdump(zipped)) - } - } -} - type testInput struct { *input.BaseInput @@ -1566,7 +1344,7 @@ func TestLockTimes(t *testing.T) { result, err := ctx.sweeper.SweepInput( inp, Params{ - Fee: FeePreference{ConfTarget: 6}, + Fee: FeeEstimateInfo{ConfTarget: 6}, }, ) if err != nil { @@ -1584,7 +1362,7 @@ func TestLockTimes(t *testing.T) { inp := spendableInputs[i+numSweeps*2] result, err := ctx.sweeper.SweepInput( inp, Params{ - Fee: FeePreference{ConfTarget: 6}, + Fee: FeeEstimateInfo{ConfTarget: 6}, }, ) if err != nil { @@ -2029,7 +1807,9 @@ func TestRequiredTxOuts(t *testing.T) { for _, inp := range testCase.inputs { result, err := ctx.sweeper.SweepInput( inp, Params{ - Fee: FeePreference{ConfTarget: 6}, + Fee: FeeEstimateInfo{ + ConfTarget: 6, + }, }, ) if err != nil { @@ -2137,304 +1917,6 @@ func TestSweeperShutdownHandling(t *testing.T) { require.Error(t, err) } -// TestFeeRateForPreference checks `feeRateForPreference` works as expected. -func TestFeeRateForPreference(t *testing.T) { - t.Parallel() - - dummyErr := errors.New("dummy") - - // Create a test sweeper. - s := New(&UtxoSweeperConfig{}) - - // errFeeFunc is a mock over DetermineFeePerKw that always return the - // above dummy error. - errFeeFunc := func(_ chainfee.Estimator, _ FeePreference) ( - chainfee.SatPerKWeight, error) { - - return 0, dummyErr - } - - // Set the relay fee rate to be 1 sat/kw. - s.relayFeeRate = 1 - - // smallFeeFunc is a mock over DetermineFeePerKw that always return a - // fee rate that's below the relayFeeRate. - smallFeeFunc := func(_ chainfee.Estimator, _ FeePreference) ( - chainfee.SatPerKWeight, error) { - - return s.relayFeeRate - 1, nil - } - - // Set the max fee rate to be 1000 sat/vb. - s.cfg.MaxFeeRate = 1000 - - // largeFeeFunc is a mock over DetermineFeePerKw that always return a - // fee rate that's larger than the MaxFeeRate. - largeFeeFunc := func(_ chainfee.Estimator, _ FeePreference) ( - chainfee.SatPerKWeight, error) { - - return s.cfg.MaxFeeRate.FeePerKWeight() + 1, nil - } - - // validFeeRate is used to test the success case. - validFeeRate := (s.cfg.MaxFeeRate.FeePerKWeight() + s.relayFeeRate) / 2 - - // normalFeeFunc is a mock over DetermineFeePerKw that always return a - // fee rate that's within the range. - normalFeeFunc := func(_ chainfee.Estimator, _ FeePreference) ( - chainfee.SatPerKWeight, error) { - - return validFeeRate, nil - } - - testCases := []struct { - name string - feePref FeePreference - determineFeePerKw feeDeterminer - expectedFeeRate chainfee.SatPerKWeight - expectedErr error - }{ - { - // When the fee preference is empty, we should see an - // error. - name: "empty fee preference", - feePref: FeePreference{}, - expectedErr: ErrNoFeePreference, - }, - { - // When an error is returned from the fee determiner, - // we should return it. - name: "error from DetermineFeePerKw", - feePref: FeePreference{FeeRate: 1}, - determineFeePerKw: errFeeFunc, - expectedErr: dummyErr, - }, - { - // When DetermineFeePerKw gives a too small value, we - // should return an error. - name: "fee rate below relay fee rate", - feePref: FeePreference{FeeRate: 1}, - determineFeePerKw: smallFeeFunc, - expectedErr: ErrFeePreferenceTooLow, - }, - { - // When DetermineFeePerKw gives a too large value, we - // should cap it at the max fee rate. - name: "fee rate above max fee rate", - feePref: FeePreference{FeeRate: 1}, - determineFeePerKw: largeFeeFunc, - expectedFeeRate: s.cfg.MaxFeeRate.FeePerKWeight(), - }, - { - // When DetermineFeePerKw gives a sane fee rate, we - // should return it without any error. - name: "success", - feePref: FeePreference{FeeRate: 1}, - determineFeePerKw: normalFeeFunc, - expectedFeeRate: validFeeRate, - }, - } - - //nolint:paralleltest - for _, tc := range testCases { - tc := tc - - t.Run(tc.name, func(t *testing.T) { - // Attach the mocked method. - s.cfg.DetermineFeePerKw = tc.determineFeePerKw - - // Call the function under test. - feerate, err := s.feeRateForPreference(tc.feePref) - - // Assert the expected feerate. - require.Equal(t, tc.expectedFeeRate, feerate) - - // Assert the expected error. - require.ErrorIs(t, err, tc.expectedErr) - }) - } -} - -// TestClusterByLockTime tests the method clusterByLockTime works as expected. -func TestClusterByLockTime(t *testing.T) { - t.Parallel() - - // Create a test param with a dummy fee preference. This is needed so - // `feeRateForPreference` won't throw an error. - param := Params{Fee: FeePreference{ConfTarget: 1}} - - // We begin the test by creating three clusters of inputs, the first - // cluster has a locktime of 1, the second has a locktime of 2, and the - // final has no locktime. - lockTime1 := uint32(1) - lockTime2 := uint32(2) - - // Create cluster one, which has a locktime of 1. - input1LockTime1 := &input.MockInput{} - input2LockTime1 := &input.MockInput{} - input1LockTime1.On("RequiredLockTime").Return(lockTime1, true) - input2LockTime1.On("RequiredLockTime").Return(lockTime1, true) - - // Create cluster two, which has a locktime of 2. - input3LockTime2 := &input.MockInput{} - input4LockTime2 := &input.MockInput{} - input3LockTime2.On("RequiredLockTime").Return(lockTime2, true) - input4LockTime2.On("RequiredLockTime").Return(lockTime2, true) - - // Create cluster three, which has no locktime. - input5NoLockTime := &input.MockInput{} - input6NoLockTime := &input.MockInput{} - input5NoLockTime.On("RequiredLockTime").Return(uint32(0), false) - input6NoLockTime.On("RequiredLockTime").Return(uint32(0), false) - - // With the inner Input being mocked, we can now create the pending - // inputs. - input1 := &pendingInput{Input: input1LockTime1, params: param} - input2 := &pendingInput{Input: input2LockTime1, params: param} - input3 := &pendingInput{Input: input3LockTime2, params: param} - input4 := &pendingInput{Input: input4LockTime2, params: param} - input5 := &pendingInput{Input: input5NoLockTime, params: param} - input6 := &pendingInput{Input: input6NoLockTime, params: param} - - // Create the pending inputs map, which will be passed to the method - // under test. - // - // NOTE: we don't care the actual outpoint values as long as they are - // unique. - inputs := pendingInputs{ - wire.OutPoint{Index: 1}: input1, - wire.OutPoint{Index: 2}: input2, - wire.OutPoint{Index: 3}: input3, - wire.OutPoint{Index: 4}: input4, - wire.OutPoint{Index: 5}: input5, - wire.OutPoint{Index: 6}: input6, - } - - // Create expected clusters so we can shorten the line length in the - // test cases below. - cluster1 := pendingInputs{ - wire.OutPoint{Index: 1}: input1, - wire.OutPoint{Index: 2}: input2, - } - cluster2 := pendingInputs{ - wire.OutPoint{Index: 3}: input3, - wire.OutPoint{Index: 4}: input4, - } - - // cluster3 should be the remaining inputs since they don't have - // locktime. - cluster3 := pendingInputs{ - wire.OutPoint{Index: 5}: input5, - wire.OutPoint{Index: 6}: input6, - } - - // Set the min fee rate to be 1000 sat/kw. - const minFeeRate = chainfee.SatPerKWeight(1000) - - // Create a test sweeper. - s := New(&UtxoSweeperConfig{ - MaxFeeRate: minFeeRate.FeePerVByte() * 10, - }) - - // Set the relay fee to be the minFeeRate. Any fee rate below the - // minFeeRate will cause an error to be returned. - s.relayFeeRate = minFeeRate - - // applyFeeRate takes a testing fee rate and makes a mocker over - // DetermineFeePerKw that always return the testing fee rate. This - // mocked method is then attached to the sweeper. - applyFeeRate := func(feeRate chainfee.SatPerKWeight) { - mockFeeFunc := func(_ chainfee.Estimator, _ FeePreference) ( - chainfee.SatPerKWeight, error) { - - return feeRate, nil - } - - s.cfg.DetermineFeePerKw = mockFeeFunc - } - - testCases := []struct { - name string - testFeeRate chainfee.SatPerKWeight - expectedClusters []inputCluster - expectedRemainingInputs pendingInputs - }{ - { - // Test a successful case where the locktime clusters - // are created and the no-locktime cluster is returned - // as the remaining inputs. - name: "successfully create clusters", - // Use a fee rate above the min value so we don't hit - // an error when performing fee estimation. - // - // TODO(yy): we should customize the returned fee rate - // for each input to further test the averaging logic. - // Or we can split the method into two, one for - // grouping the clusters and the other for averaging - // the fee rates so it's easier to be tested. - testFeeRate: minFeeRate + 1, - expectedClusters: []inputCluster{ - { - lockTime: &lockTime1, - sweepFeeRate: minFeeRate + 1, - inputs: cluster1, - }, - { - lockTime: &lockTime2, - sweepFeeRate: minFeeRate + 1, - inputs: cluster2, - }, - }, - expectedRemainingInputs: cluster3, - }, - { - // Test that when the input is skipped when the fee - // estimation returns an error. - name: "error from fee estimation", - // Use a fee rate below the min value so we hit an - // error when performing fee estimation. - testFeeRate: minFeeRate - 1, - expectedClusters: []inputCluster{}, - // Remaining inputs should stay untouched. - expectedRemainingInputs: cluster3, - }, - } - - //nolint:paralleltest - for _, tc := range testCases { - tc := tc - - t.Run(tc.name, func(t *testing.T) { - // Apply the test fee rate so `feeRateForPreference` is - // mocked to return the specified value. - applyFeeRate(tc.testFeeRate) - - // Call the method under test. - clusters, remainingInputs := s.clusterByLockTime(inputs) - - // Sort by locktime as the order is not guaranteed. - sort.Slice(clusters, func(i, j int) bool { - return *clusters[i].lockTime < - *clusters[j].lockTime - }) - - // Validate the values are returned as expected. - require.Equal(t, tc.expectedClusters, clusters) - require.Equal(t, tc.expectedRemainingInputs, - remainingInputs, - ) - - // Assert the mocked methods are called as expected. - input1LockTime1.AssertExpectations(t) - input2LockTime1.AssertExpectations(t) - input3LockTime2.AssertExpectations(t) - input4LockTime2.AssertExpectations(t) - input5NoLockTime.AssertExpectations(t) - input6NoLockTime.AssertExpectations(t) - }) - } -} - // TestGetInputLists checks that the expected input sets are returned based on // whether there are retried inputs or not. func TestGetInputLists(t *testing.T) { @@ -2442,7 +1924,7 @@ func TestGetInputLists(t *testing.T) { // Create a test param with a dummy fee preference. This is needed so // `feeRateForPreference` won't throw an error. - param := Params{Fee: FeePreference{ConfTarget: 1}} + param := Params{Fee: FeeEstimateInfo{ConfTarget: 1}} // Create a mock input and mock all the methods used in this test. testInput := &input.MockInput{} @@ -2530,7 +2012,7 @@ func TestGetInputLists(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - allSets, newSets, err := s.getInputLists(tc.cluster, 0) + allSets, newSets, err := s.getInputLists(tc.cluster) require.NoError(t, err) if tc.expectNilNewSet { diff --git a/sweep/walletsweep.go b/sweep/walletsweep.go index d3c5365c0a..90d9e351d1 100644 --- a/sweep/walletsweep.go +++ b/sweep/walletsweep.go @@ -1,6 +1,7 @@ package sweep import ( + "errors" "fmt" "math" @@ -19,9 +20,37 @@ const ( defaultNumBlocksEstimate = 6 ) -// FeePreference allows callers to express their time value for inclusion of a -// transaction into a block via either a confirmation target, or a fee rate. -type FeePreference struct { +var ( + // ErrNoFeePreference is returned when we attempt to satisfy a sweep + // request from a client whom did not specify a fee preference. + ErrNoFeePreference = errors.New("no fee preference specified") + + // ErrFeePreferenceConflict is returned when both a fee rate and a conf + // target is set for a fee preference. + ErrFeePreferenceConflict = errors.New("fee preference conflict") +) + +// FeePreference defines an interface that allows the caller to specify how the +// fee rate should be handled. Depending on the implementation, the fee rate +// can either be specified directly, or via a conf target which relies on the +// chain backend(`bitcoind`) to give a fee estimation, or a customized fee +// function which handles fee calculation based on the specified +// urgency(deadline). +type FeePreference interface { + // String returns a human-readable string of the fee preference. + String() string + + // Estimate takes a fee estimator and a max allowed fee rate and + // returns a fee rate for the given fee preference. It ensures that the + // fee rate respects the bounds of the relay fee and the specified max + // fee rates. + Estimate(chainfee.Estimator, + chainfee.SatPerKWeight) (chainfee.SatPerKWeight, error) +} + +// FeeEstimateInfo allows callers to express their time value for inclusion of +// a transaction into a block via either a confirmation target, or a fee rate. +type FeeEstimateInfo struct { // ConfTarget if non-zero, signals a fee preference expressed in the // number of desired blocks between first broadcast, and confirmation. ConfTarget uint32 @@ -31,84 +60,91 @@ type FeePreference struct { FeeRate chainfee.SatPerKWeight } +// Compile-time constraint to ensure FeeEstimateInfo implements FeePreference. +var _ FeePreference = (*FeeEstimateInfo)(nil) + // String returns a human-readable string of the fee preference. -func (p FeePreference) String() string { - if p.ConfTarget != 0 { - return fmt.Sprintf("%v blocks", p.ConfTarget) +func (f FeeEstimateInfo) String() string { + if f.ConfTarget != 0 { + return fmt.Sprintf("%v blocks", f.ConfTarget) } - return p.FeeRate.String() + + return f.FeeRate.String() } -// DetermineFeePerKw will determine the fee in sat/kw that should be paid given -// an estimator, a confirmation target, and a manual value for sat/byte. A -// value is chosen based on the two free parameters as one, or both of them can -// be zero. -func DetermineFeePerKw(feeEstimator chainfee.Estimator, - feePref FeePreference) (chainfee.SatPerKWeight, error) { +// Estimate returns a fee rate for the given fee preference. It ensures that +// the fee rate respects the bounds of the relay fee and the max fee rates, if +// specified. +func (f FeeEstimateInfo) Estimate(estimator chainfee.Estimator, + maxFeeRate chainfee.SatPerKWeight) (chainfee.SatPerKWeight, error) { + + var ( + feeRate chainfee.SatPerKWeight + err error + ) switch { + // Ensure a type of fee preference is specified to prevent using a + // default below. + case f.FeeRate == 0 && f.ConfTarget == 0: + return 0, ErrNoFeePreference + // If both values are set, then we'll return an error as we require a // strict directive. - case feePref.FeeRate != 0 && feePref.ConfTarget != 0: - return 0, fmt.Errorf("only FeeRate or ConfTarget should " + - "be set for FeePreferences") + case f.FeeRate != 0 && f.ConfTarget != 0: + return 0, ErrFeePreferenceConflict // If the target number of confirmations is set, then we'll use that to // consult our fee estimator for an adequate fee. - case feePref.ConfTarget != 0: - feePerKw, err := feeEstimator.EstimateFeePerKW( - uint32(feePref.ConfTarget), - ) + case f.ConfTarget != 0: + feeRate, err = estimator.EstimateFeePerKW((f.ConfTarget)) if err != nil { return 0, fmt.Errorf("unable to query fee "+ - "estimator: %v", err) + "estimator: %w", err) } - return feePerKw, nil - - // If a manual sat/byte fee rate is set, then we'll use that directly. + // If a manual sat/kw fee rate is set, then we'll use that directly. // We'll need to convert it to sat/kw as this is what we use // internally. - case feePref.FeeRate != 0: - feePerKW := feePref.FeeRate + case f.FeeRate != 0: + feeRate = f.FeeRate // Because the user can specify 1 sat/vByte on the RPC // interface, which corresponds to 250 sat/kw, we need to bump // that to the minimum "safe" fee rate which is 253 sat/kw. - if feePerKW == chainfee.AbsoluteFeePerKwFloor { + if feeRate == chainfee.AbsoluteFeePerKwFloor { log.Infof("Manual fee rate input of %d sat/kw is "+ - "too low, using %d sat/kw instead", feePerKW, + "too low, using %d sat/kw instead", feeRate, chainfee.FeePerKwFloor) - feePerKW = chainfee.FeePerKwFloor - } - // If that bumped fee rate of at least 253 sat/kw is still lower - // than the relay fee rate, we return an error to let the user - // know. Note that "Relay fee rate" may mean slightly different - // things depending on the backend. For bitcoind, it is - // effectively max(relay fee, min mempool fee). - minFeePerKW := feeEstimator.RelayFeePerKW() - if feePerKW < minFeePerKW { - return 0, fmt.Errorf("manual fee rate input of %d "+ - "sat/kw is too low to be accepted into the "+ - "mempool or relayed to the network", feePerKW) + feeRate = chainfee.FeePerKwFloor } + } - return feePerKW, nil + // Get the relay fee as the min fee rate. + minFeeRate := estimator.RelayFeePerKW() + + // If that bumped fee rate of at least 253 sat/kw is still lower than + // the relay fee rate, we return an error to let the user know. Note + // that "Relay fee rate" may mean slightly different things depending + // on the backend. For bitcoind, it is effectively max(relay fee, min + // mempool fee). + if feeRate < minFeeRate { + return 0, fmt.Errorf("%w: got %v, minimum is %v", + ErrFeePreferenceTooLow, feeRate, minFeeRate) + } - // Otherwise, we'll attempt a relaxed confirmation target for the - // transaction - default: - feePerKw, err := feeEstimator.EstimateFeePerKW( - defaultNumBlocksEstimate, - ) - if err != nil { - return 0, fmt.Errorf("unable to query fee estimator: "+ - "%v", err) - } + // If a maxFeeRate is specified and the estimated fee rate is above the + // maximum allowed fee rate, default to the max fee rate. + if maxFeeRate != 0 && feeRate > maxFeeRate { + log.Warnf("Estimated fee rate %v exceeds max allowed fee "+ + "rate %v, using max fee rate instead", feeRate, + maxFeeRate) - return feePerKw, nil + return maxFeeRate, nil } + + return feeRate, nil } // UtxoSource is an interface that allows a caller to access a source of UTXOs diff --git a/sweep/walletsweep_test.go b/sweep/walletsweep_test.go index 4a7ceec7a7..67c390fcac 100644 --- a/sweep/walletsweep_test.go +++ b/sweep/walletsweep_test.go @@ -2,6 +2,7 @@ package sweep import ( "bytes" + "errors" "fmt" "testing" @@ -15,106 +16,135 @@ import ( "github.com/stretchr/testify/require" ) -// TestDetermineFeePerKw tests that given a fee preference, the -// DetermineFeePerKw will properly map it to a concrete fee in sat/kw. -func TestDetermineFeePerKw(t *testing.T) { +// TestFeeEstimateInfo checks `Estimate` method works as expected. +func TestFeeEstimateInfo(t *testing.T) { t.Parallel() - defaultFee := chainfee.SatPerKWeight(999) - relayFee := chainfee.SatPerKWeight(300) + dummyErr := errors.New("dummy") - feeEstimator := newMockFeeEstimator(defaultFee, relayFee) + const ( + // Set the relay fee rate to be 10 sat/kw. + relayFeeRate = 10 - // We'll populate two items in the internal map which is used to query - // a fee based on a confirmation target: the default conf target, and - // an arbitrary conf target. We'll ensure below that both of these are - // properly - feeEstimator.blocksToFee[50] = 300 - feeEstimator.blocksToFee[defaultNumBlocksEstimate] = 1000 + // Set the max fee rate to be 1000 sat/vb. + maxFeeRate = 1000 - testCases := []struct { - // feePref is the target fee preference for this case. - feePref FeePreference + // Create a valid fee rate to test the success case. + validFeeRate = (relayFeeRate + maxFeeRate) / 2 + + // Set the test conf target to be 1. + conf uint32 = 1 + ) - // fee is the value the DetermineFeePerKw should return given - // the FeePreference above - fee chainfee.SatPerKWeight + // Create a mock fee estimator. + estimator := &chainfee.MockEstimator{} - // fail determines if this test case should fail or not. - fail bool + testCases := []struct { + name string + setupMocker func() + feePref FeeEstimateInfo + expectedFeeRate chainfee.SatPerKWeight + expectedErr error }{ - // A fee rate below the floor should error out. { - feePref: FeePreference{ - FeeRate: chainfee.SatPerKWeight(99), - }, - fail: true, + // When the fee preference is empty, we should see an + // error. + name: "empty fee preference", + feePref: FeeEstimateInfo{}, + expectedErr: ErrNoFeePreference, }, - - // A fee rate below the relay fee should error out. { - feePref: FeePreference{ - FeeRate: chainfee.SatPerKWeight(299), + // When the fee preference has conflicts, we should see + // an error. + name: "conflict fee preference", + feePref: FeeEstimateInfo{ + FeeRate: validFeeRate, + ConfTarget: conf, }, - fail: true, + expectedErr: ErrFeePreferenceConflict, }, - - // A fee rate above the floor, should pass through and return - // the target fee rate. { - feePref: FeePreference{ - FeeRate: 900, + // When an error is returned from the fee estimator, we + // should return it. + name: "error from Estimator", + setupMocker: func() { + estimator.On("EstimateFeePerKW", conf).Return( + chainfee.SatPerKWeight(0), dummyErr, + ).Once() }, - fee: 900, + feePref: FeeEstimateInfo{ConfTarget: conf}, + expectedErr: dummyErr, }, - - // A specified confirmation target should cause the function to - // query the estimator which will return our value specified - // above. { - feePref: FeePreference{ - ConfTarget: 50, + // When FeeEstimateInfo uses a too small value, we + // should return an error. + name: "fee rate below relay fee rate", + setupMocker: func() { + // Mock the relay fee rate. + estimator.On("RelayFeePerKW").Return( + chainfee.SatPerKWeight(relayFeeRate), + ).Once() }, - fee: 300, + feePref: FeeEstimateInfo{FeeRate: relayFeeRate - 1}, + expectedErr: ErrFeePreferenceTooLow, }, - - // If the caller doesn't specify any values at all, then we - // should query for the default conf target. { - feePref: FeePreference{}, - fee: 1000, + // When FeeEstimateInfo gives a too large value, we + // should cap it at the max fee rate. + name: "fee rate above max fee rate", + setupMocker: func() { + // Mock the relay fee rate. + estimator.On("RelayFeePerKW").Return( + chainfee.SatPerKWeight(relayFeeRate), + ).Once() + }, + feePref: FeeEstimateInfo{ + FeeRate: maxFeeRate + 1, + }, + expectedFeeRate: maxFeeRate, }, - - // Both conf target and fee rate are set, we should return with - // an error. { - feePref: FeePreference{ - ConfTarget: 50, - FeeRate: 90000, + // When Estimator gives a sane fee rate, we should + // return it without any error. + name: "success", + setupMocker: func() { + estimator.On("EstimateFeePerKW", conf).Return( + chainfee.SatPerKWeight(validFeeRate), + nil).Once() + + // Mock the relay fee rate. + estimator.On("RelayFeePerKW").Return( + chainfee.SatPerKWeight(relayFeeRate), + ).Once() }, - fee: 300, - fail: true, + feePref: FeeEstimateInfo{ConfTarget: conf}, + expectedFeeRate: validFeeRate, }, } - for i, testCase := range testCases { - targetFee, err := DetermineFeePerKw( - feeEstimator, testCase.feePref, - ) - switch { - case testCase.fail && err != nil: - continue - - case testCase.fail && err == nil: - t.Fatalf("expected failure for #%v", i) - - case !testCase.fail && err != nil: - t.Fatalf("unable to estimate fee; %v", err) - } - if targetFee != testCase.fee { - t.Fatalf("#%v: wrong fee: expected %v got %v", i, - testCase.fee, targetFee) - } + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + // Setup the mockers if specified. + if tc.setupMocker != nil { + tc.setupMocker() + } + + // Call the function under test. + feerate, err := tc.feePref.Estimate( + estimator, maxFeeRate, + ) + + // Assert the expected error. + require.ErrorIs(t, err, tc.expectedErr) + + // Assert the expected feerate. + require.Equal(t, tc.expectedFeeRate, feerate) + + // Assert the mockers. + estimator.AssertExpectations(t) + }) } }