Skip to content

Commit 9f589b0

Browse files
committed
txbatcher: use lock, rename private methods, add type hints
1 parent cebc5af commit 9f589b0

File tree

2 files changed

+27
-27
lines changed

2 files changed

+27
-27
lines changed

electrum/txbatcher.py

+27-25
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import asyncio
2+
import threading
23
import copy
34

5+
from typing import Dict, Sequence
46
from . import util
57
from .logging import Logger
68
from .util import log_exceptions
@@ -74,6 +76,8 @@ def __init__(self, wallet):
7476
self.wallet = wallet
7577
self.config = wallet.config
7678
self.db = wallet.db
79+
self.lock = threading.Lock()
80+
# fixme: not robust to client restart, because we do not persist batch_payments
7781
self.batch_payments = [] # list of payments we need to make
7882
self.batch_inputs = {} # list of inputs we need to sweep
7983

@@ -96,7 +100,8 @@ def __init__(self, wallet):
96100

97101
def add_batch_payment(self, output: 'PartialTxOutput'):
98102
# todo: maybe we should raise NotEnoughFunds here
99-
self.batch_payments.append(output)
103+
with self.lock:
104+
self.batch_payments.append(output)
100105

101106
def add_sweep_info(self, sweep_info: 'SweepInfo'):
102107
txin = sweep_info.txin
@@ -123,7 +128,6 @@ def add_sweep_info(self, sweep_info: 'SweepInfo'):
123128
base_txin.witness_script = txin.witness_script
124129
base_txin.script_sig = txin.script_sig
125130

126-
127131
def get_base_tx(self) -> Optional[Transaction]:
128132
if self._base_tx:
129133
return self._base_tx
@@ -140,7 +144,7 @@ def get_base_tx(self) -> Optional[Transaction]:
140144
base_tx.add_info_from_wallet(self.wallet) # needed for txid
141145
return base_tx
142146

143-
def find_confirmed_base_tx(self) -> Optional[Transaction]:
147+
def _find_confirmed_base_tx(self) -> Optional[Transaction]:
144148
for txid in self.batch_txids:
145149
tx_mined_status = self.wallet.adb.get_tx_height(txid)
146150
if tx_mined_status.conf > 0:
@@ -149,8 +153,7 @@ def find_confirmed_base_tx(self) -> Optional[Transaction]:
149153
tx.add_info_from_wallet(self.wallet) # needed for txid
150154
return tx
151155

152-
def to_pay_after(self, tx):
153-
# fixme: not robust to client restart, because we do not persist batch_payments
156+
def _to_pay_after(self, tx) -> Sequence[PartialTxOutput]:
154157
if not tx:
155158
return self.batch_payments
156159
to_pay = []
@@ -162,7 +165,7 @@ def to_pay_after(self, tx):
162165
outputs.remove(x)
163166
return to_pay
164167

165-
def to_sweep_after(self, tx):
168+
def _to_sweep_after(self, tx) -> Dict[str, SweepInfo]:
166169
tx_prevouts = set(txin.prevout for txin in tx.inputs()) if tx else set()
167170
result = []
168171
for k,v in self.batch_inputs.items():
@@ -179,7 +182,7 @@ def to_sweep_after(self, tx):
179182
result.append((k,v))
180183
return dict(result)
181184

182-
def should_bump_fee(self, base_tx):
185+
def _should_bump_fee(self, base_tx) -> bool:
183186
if base_tx is None:
184187
return False
185188
base_tx_fee = base_tx.get_fee()
@@ -196,30 +199,30 @@ async def run(self):
196199
password = self.wallet.get_unlocked_password()
197200
if self.wallet.has_keystore_encryption() and not password:
198201
continue
199-
await self.maybe_broadcast_legacy_htlc_txs()
200-
tx = self.find_confirmed_base_tx()
202+
await self._maybe_broadcast_legacy_htlc_txs()
203+
tx = self._find_confirmed_base_tx()
201204
if tx:
202205
self.logger.info(f'base tx confirmed {tx.txid()}')
203-
self.clear_batch_processing(tx)
204-
self.start_new_batch(tx)
206+
self._clear_batch_processing(tx)
207+
self._start_new_batch(tx)
205208
base_tx = self.get_base_tx()
206-
to_pay = self.to_pay_after(base_tx)
207-
to_sweep = self.to_sweep_after(base_tx)
209+
to_pay = self._to_pay_after(base_tx)
210+
to_sweep = self._to_sweep_after(base_tx)
208211
to_sweep_now = {}
209212
for k, v in to_sweep.items():
210-
can_broadcast, wanted_height = self.can_broadcast(v, base_tx)
213+
can_broadcast, wanted_height = self._can_broadcast(v, base_tx)
211214
if can_broadcast:
212215
to_sweep_now[k] = v
213216
else:
214217
self.wallet.add_future_tx(v, wanted_height)
215-
if not to_pay and not to_sweep_now and not self.should_bump_fee(base_tx):
218+
if not to_pay and not to_sweep_now and not self._should_bump_fee(base_tx):
216219
continue
217220
try:
218-
tx = self.create_batch_tx(base_tx, to_sweep_now, to_pay, password)
221+
tx = self._create_batch_tx(base_tx, to_sweep_now, to_pay, password)
219222
except Exception as e:
220223
self.logger.exception(f'Cannot create batch transaction: {repr(e)}')
221224
if base_tx:
222-
self.start_new_batch(base_tx)
225+
self._start_new_batch(base_tx)
223226
continue
224227
await asyncio.sleep(self.RETRY_DELAY)
225228
continue
@@ -242,7 +245,7 @@ async def run(self):
242245
self.logger.info(f'starting new batch because could not broadcast')
243246
self.start_new_batch(base_tx)
244247

245-
def create_batch_tx(self, base_tx, to_sweep, to_pay, password):
248+
def _create_batch_tx(self, base_tx, to_sweep, to_pay, password):
246249
self.logger.info(f'to_sweep: {list(to_sweep.keys())}')
247250
self.logger.info(f'to_pay: {to_pay}')
248251
inputs = []
@@ -260,7 +263,6 @@ def create_batch_tx(self, base_tx, to_sweep, to_pay, password):
260263
self.logger.info(f'locktime: {locktime}')
261264
outputs += to_pay
262265
inputs += self.get_change_inputs(self._parent_tx) if self._parent_tx else []
263-
264266
tx = self.wallet.create_transaction(
265267
base_tx=base_tx,
266268
inputs=inputs,
@@ -275,18 +277,18 @@ def create_batch_tx(self, base_tx, to_sweep, to_pay, password):
275277
assert tx.is_complete()
276278
return tx
277279

278-
def clear_batch_processing(self, tx):
280+
def _clear_batch_processing(self, tx):
279281
# this ensure that we can accept an input again
280282
# if the spending tx is removed from the blockchain
281283
# fixme: what if there are several batches?
282284
for txin in tx.inputs():
283285
if txin.prevout in self.batch_processing:
284286
self.batch_processing.remove(txin.prevout)
285287

286-
def start_new_batch(self, tx):
288+
def _start_new_batch(self, tx):
287289
use_change = tx and tx.has_change() and any([txout in self.batch_payments for txout in tx.outputs()])
288-
self.batch_payments = self.to_pay_after(tx)
289-
self.batch_inputs = self.to_sweep_after(tx)
290+
self.batch_payments = self._to_pay_after(tx)
291+
self.batch_inputs = self._to_sweep_after(tx)
290292
self.batch_txids.clear()
291293
self._base_tx = None
292294
self._parent_tx = tx if use_change else None
@@ -301,7 +303,7 @@ def get_change_inputs(self, parent_tx):
301303
txin.nsequence = 0xffffffff - 2
302304
return inputs
303305

304-
def can_broadcast(self, sweep_info: 'SweepInfo', base_tx):
306+
def _can_broadcast(self, sweep_info: 'SweepInfo', base_tx):
305307
prevout = sweep_info.txin.prevout.to_str()
306308
name = sweep_info.name
307309
prev_txid, index = prevout.split(':')
@@ -331,7 +333,7 @@ def can_broadcast(self, sweep_info: 'SweepInfo', base_tx):
331333
wanted_height = prev_height
332334
return can_broadcast, wanted_height
333335

334-
async def maybe_broadcast_legacy_htlc_txs(self):
336+
async def _maybe_broadcast_legacy_htlc_txs(self):
335337
""" pre-anchor htlc txs cannot be batched """
336338
for sweep_info in list(self.batch_inputs.values()):
337339
if sweep_info.name == 'first-stage-htlc':

tests/test_txbatcher.py

-2
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,6 @@ async def test_batch_payments(self, mock_save_db):
118118
assert wallet.adb.get_transaction(tx1.txid()) is not None
119119
assert wallet.adb.get_transaction(tx1_prime.txid()) is None
120120
# txbatcher creates tx2
121-
self.logger.info(f'to pay after {wallet.txbatcher.to_pay_after(tx1)}')
122-
self.logger.info(f'{tx_mined_status}')
123121
await self.network._tx_event.wait()
124122
tx2 = wallet.txbatcher.get_base_tx()
125123
assert output1 in tx1.outputs()

0 commit comments

Comments
 (0)