diff --git a/bittensor_cli/src/commands/stake/add.py b/bittensor_cli/src/commands/stake/add.py index f165bf3d6..097637eae 100644 --- a/bittensor_cli/src/commands/stake/add.py +++ b/bittensor_cli/src/commands/stake/add.py @@ -1,5 +1,6 @@ import asyncio from collections import defaultdict +from dataclasses import dataclass from functools import partial from typing import TYPE_CHECKING, Optional @@ -30,9 +31,26 @@ from bittensor_wallet import Wallet if TYPE_CHECKING: + from bittensor_cli.src.bittensor.chain_data import DynamicInfo from bittensor_cli.src.bittensor.subtensor_interface import SubtensorInterface +@dataclass(frozen=True) +class StakeOperationTarget: + staking_address: str + netuid: int + subnet_info: "DynamicInfo" + current_stake_balance: Balance + + +@dataclass(frozen=True) +class StakeOperationPlan: + target: StakeOperationTarget + amount_to_stake: Balance + extrinsic_fee: Balance + price_with_tolerance: Optional[Balance] + + # Command async def stake_add( wallet: Wallet, @@ -105,7 +123,7 @@ async def get_stake_extrinsic_fee( if safe_staking_: call_params.update( { - "limit_price": price_limit, + "limit_price": price_limit.rao if price_limit else None, "allow_partial": allow_partial_stake, } ) @@ -116,6 +134,118 @@ async def get_stake_extrinsic_fee( ) return await subtensor.get_extrinsic_fee(call, wallet.coldkeypub, proxy=proxy) + def _build_price_with_tolerance( + operation_target: StakeOperationTarget, + ) -> Optional[Balance]: + if not safe_staking: + return None + + current_price_float = float(operation_target.subnet_info.price.tao) + if operation_target.subnet_info.is_dynamic: + return Balance.from_tao(current_price_float * (1 + rate_tolerance)) + return Balance.from_rao(1) + + def _validate_stake_plan_affordability( + operation_plans: list[StakeOperationPlan], + current_balance: Balance, + ) -> bool: + if proxy: + return True + + total_required = Balance.from_rao(0) + for operation_plan in operation_plans: + total_required += ( + operation_plan.amount_to_stake + operation_plan.extrinsic_fee + ) + + if total_required > current_balance: + print_error( + f"Not enough balance to cover stake operations and extrinsic fees:\n" + f" required: {total_required} > balance: {current_balance}" + ) + return False + + return True + + async def build_fee_aware_stake_plan( + operation_targets: list[StakeOperationTarget], + requested_amounts: list[Balance], + current_balance: Balance, + ) -> Optional[list[StakeOperationPlan]]: + if not operation_targets: + return [] + + price_limits = [ + _build_price_with_tolerance(operation_target) + for operation_target in operation_targets + ] + planned_amounts = list(requested_amounts) + + if stake_all: + per_operation_amount = current_balance / len(operation_targets) + + if not proxy: + for _ in range(3): + fee_estimates = await asyncio.gather( + *[ + get_stake_extrinsic_fee( + netuid_=operation_target.netuid, + amount_=per_operation_amount, + staking_address_=operation_target.staking_address, + safe_staking_=safe_staking, + price_limit=price_limits[idx], + ) + for idx, operation_target in enumerate(operation_targets) + ] + ) + + total_fee = Balance.from_rao(0) + for fee_estimate in fee_estimates: + total_fee += fee_estimate + + available_for_staking = current_balance - total_fee + if available_for_staking <= Balance.from_rao(0): + print_error( + "Not enough balance to cover extrinsic fees for stake-all operations." + ) + return None + + adjusted_amount = available_for_staking / len(operation_targets) + if adjusted_amount == per_operation_amount: + break + + per_operation_amount = adjusted_amount + + planned_amounts = [per_operation_amount for _ in operation_targets] + + fee_estimates = await asyncio.gather( + *[ + get_stake_extrinsic_fee( + netuid_=operation_target.netuid, + amount_=planned_amounts[idx], + staking_address_=operation_target.staking_address, + safe_staking_=safe_staking, + price_limit=price_limits[idx], + ) + for idx, operation_target in enumerate(operation_targets) + ] + ) + + operation_plans = [ + StakeOperationPlan( + target=operation_target, + amount_to_stake=planned_amounts[idx], + extrinsic_fee=fee_estimates[idx], + price_with_tolerance=price_limits[idx], + ) + for idx, operation_target in enumerate(operation_targets) + ] + + if not _validate_stake_plan_affordability(operation_plans, current_balance): + return None + + return operation_plans + async def safe_stake_extrinsic( netuid_: int, amount_: Balance, @@ -339,7 +469,7 @@ async def stake_extrinsic( ) # Determine the amount we are staking. - operation_targets = [] + operation_targets: list[StakeOperationTarget] = [] for hotkey in hotkeys_to_stake_to: for netuid in netuids: # Check that the subnet exists. @@ -348,7 +478,12 @@ async def stake_extrinsic( print_error(f"Subnet with netuid: {netuid} does not exist.") continue operation_targets.append( - (hotkey, netuid, subnet_info, hotkey_stake_map[hotkey[1]][netuid]) + StakeOperationTarget( + staking_address=hotkey[1], + netuid=netuid, + subnet_info=subnet_info, + current_stake_balance=hotkey_stake_map[hotkey[1]][netuid], + ) ) if stake_all and not operation_targets: @@ -359,31 +494,52 @@ async def stake_extrinsic( operations = [] remaining_wallet_balance = current_wallet_balance max_slippage = 0.0 + requested_amounts: list[Balance] = [] - for hotkey, netuid, subnet_info, current_stake_balance in operation_targets: - staking_address = hotkey[1] - + for operation_target in operation_targets: # Get the amount. amount_to_stake = Balance(0) if amount: amount_to_stake = Balance.from_tao(amount) elif stake_all: - amount_to_stake = current_wallet_balance / len(operation_targets) + amount_to_stake = Balance.from_rao(0) elif not amount: amount_to_stake, _ = _prompt_stake_amount( current_balance=remaining_wallet_balance, - netuid=netuid, + netuid=operation_target.netuid, action_name="stake", ) - # Check enough to stake. - if amount_to_stake > remaining_wallet_balance: + # Check enough to stake, excluding extrinsic fees (validated after fee-aware planning). + if not stake_all and amount_to_stake > remaining_wallet_balance: print_error( f"Not enough stake:[bold white]\n wallet balance:{remaining_wallet_balance} < " f"staking amount: {amount_to_stake}[/bold white]" ) return - remaining_wallet_balance -= amount_to_stake + + if not stake_all: + remaining_wallet_balance -= amount_to_stake + + requested_amounts.append(amount_to_stake) + + operation_plans = await build_fee_aware_stake_plan( + operation_targets=operation_targets, + requested_amounts=requested_amounts, + current_balance=current_wallet_balance, + ) + if operation_plans is None: + return + + for operation_plan in operation_plans: + operation_target = operation_plan.target + netuid = operation_target.netuid + subnet_info = operation_target.subnet_info + staking_address = operation_target.staking_address + current_stake_balance = operation_target.current_stake_balance + amount_to_stake = operation_plan.amount_to_stake + price_with_tolerance = operation_plan.price_with_tolerance + extrinsic_fee = operation_plan.extrinsic_fee # Calculate slippage # TODO: Update for V3, slippage calculation is significantly different in v3 @@ -399,29 +555,17 @@ async def stake_extrinsic( # Temporary workaround - calculations without slippage current_price_float = float(subnet_info.price.tao) rate = _safe_inverse_rate(current_price_float) - price_with_tolerance = None # If we are staking safe, add price tolerance if safe_staking: if subnet_info.is_dynamic: - price_with_tolerance = current_price_float * (1 + rate_tolerance) + safe_price = price_with_tolerance or Balance.from_rao(0) _rate_with_tolerance = _safe_inverse_rate( - price_with_tolerance + float(safe_price.tao) ) # Rate only for display rate_with_tolerance = f"{_rate_with_tolerance:.4f}" - price_with_tolerance = Balance.from_tao( - price_with_tolerance - ) # Actual price to pass to extrinsic else: rate_with_tolerance = "1" - price_with_tolerance = Balance.from_rao(1) - extrinsic_fee = await get_stake_extrinsic_fee( - netuid_=netuid, - amount_=amount_to_stake, - staking_address_=staking_address, - safe_staking_=safe_staking, - price_limit=price_with_tolerance, - ) row_extension = [ f"{rate_with_tolerance} {Balance.get_unit(netuid)}/{Balance.get_unit(0)} ", f"[{'dark_sea_green3' if allow_partial_stake else 'red'}]" @@ -429,16 +573,13 @@ async def stake_extrinsic( f"{allow_partial_stake}[/{'dark_sea_green3' if allow_partial_stake else 'red'}]", ] else: - extrinsic_fee = await get_stake_extrinsic_fee( - netuid_=netuid, - amount_=amount_to_stake, - staking_address_=staking_address, - safe_staking_=safe_staking, - ) row_extension = [] + # TODO this should be asyncio gathered before the for loop amount_minus_fee = ( - (amount_to_stake - extrinsic_fee) if not proxy else amount_to_stake + amount_to_stake + if proxy + else max(amount_to_stake - extrinsic_fee, Balance.from_rao(0)) ) sim_swap = await subtensor.sim_swap( origin_netuid=0, diff --git a/tests/unit_tests/test_stake_add.py b/tests/unit_tests/test_stake_add.py index 1df70fc83..4b9849b06 100644 --- a/tests/unit_tests/test_stake_add.py +++ b/tests/unit_tests/test_stake_add.py @@ -236,10 +236,166 @@ async def test_stake_add_stake_all_distributes_across_all_operations( for call in mock_subtensor.substrate.compose_call.await_args_list if call.kwargs.get("block_hash") == "0xabc123" ] - expected_amount = (Balance.from_tao(100) / 4).rao + expected_amount = ((Balance.from_tao(100) - (Balance.from_tao(0.01) * 4)) / 4).rao assert len(batched_stake_calls) == 4 assert all( call.kwargs["call_params"]["amount_staked"] == expected_amount for call in batched_stake_calls ) + + +@pytest.mark.asyncio +async def test_stake_add_stake_all_reserves_extrinsic_fees_across_operations( + mock_wallet, + mock_subtensor, +): + mock_subtensor.sim_swap = _sim_swap_side_effect() + mock_subtensor.all_subnets.return_value = [ + MockSubnetInfo(netuid=427, price_tao=1.5), + MockSubnetInfo(netuid=1, price_tao=2.0), + ] + mock_subtensor.get_extrinsic_fee = AsyncMock(return_value=Balance.from_tao(1.0)) + mock_subtensor.sign_and_send_batch_extrinsic = AsyncMock( + return_value=( + True, + "", + MagicMock(get_extrinsic_identifier=AsyncMock(return_value="0x1")), + ) + ) + + with patch( + "bittensor_cli.src.commands.stake.add.unlock_key", + return_value=MagicMock(success=True), + ): + await stake_add( + wallet=mock_wallet, + subtensor=mock_subtensor, + netuids=[427, 1], + stake_all=True, + amount=0, + prompt=False, + decline=False, + quiet=True, + all_hotkeys=False, + include_hotkeys=[TEST_SS58, ALT_HOTKEY_SS58], + exclude_hotkeys=[], + safe_staking=False, + rate_tolerance=0.05, + allow_partial_stake=True, + json_output=True, + era=16, + mev_protection=False, + proxy=None, + ) + + batched_stake_calls = [ + call + for call in mock_subtensor.substrate.compose_call.await_args_list + if call.kwargs.get("block_hash") == "0xabc123" + ] + expected_amount = ((Balance.from_tao(100) - (Balance.from_tao(1.0) * 4)) / 4).rao + + assert len(batched_stake_calls) == 4 + assert all( + call.kwargs["call_params"]["amount_staked"] == expected_amount + for call in batched_stake_calls + ) + + +@pytest.mark.asyncio +async def test_stake_add_multi_target_aborts_if_fees_make_plan_unaffordable( + mock_wallet, + mock_subtensor, +): + mock_subtensor.sim_swap = _sim_swap_side_effect() + mock_subtensor.all_subnets.return_value = [ + MockSubnetInfo(netuid=427, price_tao=1.5), + MockSubnetInfo(netuid=1, price_tao=2.0), + ] + mock_subtensor.get_balance = AsyncMock(return_value=Balance.from_tao(1.0)) + mock_subtensor.get_extrinsic_fee = AsyncMock(return_value=Balance.from_tao(0.1)) + mock_subtensor.sign_and_send_batch_extrinsic = AsyncMock( + return_value=( + True, + "", + MagicMock(get_extrinsic_identifier=AsyncMock(return_value="0x1")), + ) + ) + + with patch( + "bittensor_cli.src.commands.stake.add.unlock_key", + return_value=MagicMock(success=True), + ) as mock_unlock: + await stake_add( + wallet=mock_wallet, + subtensor=mock_subtensor, + netuids=[427, 1], + stake_all=False, + amount=0.5, + prompt=False, + decline=False, + quiet=True, + all_hotkeys=False, + include_hotkeys=[TEST_SS58], + exclude_hotkeys=[], + safe_staking=False, + rate_tolerance=0.05, + allow_partial_stake=True, + json_output=True, + era=16, + mev_protection=False, + proxy=None, + ) + + mock_unlock.assert_not_called() + mock_subtensor.sign_and_send_batch_extrinsic.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_stake_add_stake_all_aborts_if_fees_exceed_balance( + mock_wallet, + mock_subtensor, +): + mock_subtensor.sim_swap = _sim_swap_side_effect() + mock_subtensor.all_subnets.return_value = [ + MockSubnetInfo(netuid=427, price_tao=1.5), + MockSubnetInfo(netuid=1, price_tao=2.0), + ] + mock_subtensor.get_balance = AsyncMock(return_value=Balance.from_tao(1.0)) + mock_subtensor.get_extrinsic_fee = AsyncMock(return_value=Balance.from_tao(1.0)) + mock_subtensor.sign_and_send_batch_extrinsic = AsyncMock( + return_value=( + True, + "", + MagicMock(get_extrinsic_identifier=AsyncMock(return_value="0x1")), + ) + ) + + with patch( + "bittensor_cli.src.commands.stake.add.unlock_key", + return_value=MagicMock(success=True), + ) as mock_unlock: + await stake_add( + wallet=mock_wallet, + subtensor=mock_subtensor, + netuids=[427, 1], + stake_all=True, + amount=0, + prompt=False, + decline=False, + quiet=True, + all_hotkeys=False, + include_hotkeys=[TEST_SS58], + exclude_hotkeys=[], + safe_staking=False, + rate_tolerance=0.05, + allow_partial_stake=True, + json_output=True, + era=16, + mev_protection=False, + proxy=None, + ) + + mock_unlock.assert_not_called() + mock_subtensor.sign_and_send_batch_extrinsic.assert_not_awaited()