Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 173 additions & 32 deletions bittensor_cli/src/commands/stake/add.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
from collections import defaultdict
from dataclasses import dataclass
from functools import partial

from typing import TYPE_CHECKING, Optional
Expand Down Expand Up @@ -30,9 +31,26 @@
from bittensor_wallet import Wallet

if TYPE_CHECKING:
from bittensor_cli.src.bittensor.chain_data import DynamicInfo
from bittensor_cli.src.bittensor.subtensor_interface import SubtensorInterface


@dataclass(frozen=True)
class StakeOperationTarget:
staking_address: str
netuid: int
subnet_info: "DynamicInfo"
current_stake_balance: Balance


@dataclass(frozen=True)
class StakeOperationPlan:
target: StakeOperationTarget
amount_to_stake: Balance
extrinsic_fee: Balance
price_with_tolerance: Optional[Balance]


# Command
async def stake_add(
wallet: Wallet,
Expand Down Expand Up @@ -105,7 +123,7 @@ async def get_stake_extrinsic_fee(
if safe_staking_:
call_params.update(
{
"limit_price": price_limit,
"limit_price": price_limit.rao if price_limit else None,
"allow_partial": allow_partial_stake,
}
)
Expand All @@ -116,6 +134,118 @@ async def get_stake_extrinsic_fee(
)
return await subtensor.get_extrinsic_fee(call, wallet.coldkeypub, proxy=proxy)

def _build_price_with_tolerance(
operation_target: StakeOperationTarget,
) -> Optional[Balance]:
if not safe_staking:
return None

current_price_float = float(operation_target.subnet_info.price.tao)
if operation_target.subnet_info.is_dynamic:
return Balance.from_tao(current_price_float * (1 + rate_tolerance))
return Balance.from_rao(1)

def _validate_stake_plan_affordability(
operation_plans: list[StakeOperationPlan],
current_balance: Balance,
) -> bool:
if proxy:
return True

total_required = Balance.from_rao(0)
for operation_plan in operation_plans:
total_required += (
operation_plan.amount_to_stake + operation_plan.extrinsic_fee
)

if total_required > current_balance:
print_error(
f"Not enough balance to cover stake operations and extrinsic fees:\n"
f" required: {total_required} > balance: {current_balance}"
)
return False

return True

async def build_fee_aware_stake_plan(
operation_targets: list[StakeOperationTarget],
requested_amounts: list[Balance],
current_balance: Balance,
) -> Optional[list[StakeOperationPlan]]:
if not operation_targets:
return []

price_limits = [
_build_price_with_tolerance(operation_target)
for operation_target in operation_targets
]
planned_amounts = list(requested_amounts)

if stake_all:
per_operation_amount = current_balance / len(operation_targets)

if not proxy:
for _ in range(3):
fee_estimates = await asyncio.gather(
*[
get_stake_extrinsic_fee(
netuid_=operation_target.netuid,
amount_=per_operation_amount,
staking_address_=operation_target.staking_address,
safe_staking_=safe_staking,
price_limit=price_limits[idx],
)
for idx, operation_target in enumerate(operation_targets)
]
)

total_fee = Balance.from_rao(0)
for fee_estimate in fee_estimates:
total_fee += fee_estimate

available_for_staking = current_balance - total_fee
if available_for_staking <= Balance.from_rao(0):
print_error(
"Not enough balance to cover extrinsic fees for stake-all operations."
)
return None

adjusted_amount = available_for_staking / len(operation_targets)
if adjusted_amount == per_operation_amount:
break

per_operation_amount = adjusted_amount

planned_amounts = [per_operation_amount for _ in operation_targets]

fee_estimates = await asyncio.gather(
*[
get_stake_extrinsic_fee(
netuid_=operation_target.netuid,
amount_=planned_amounts[idx],
staking_address_=operation_target.staking_address,
safe_staking_=safe_staking,
price_limit=price_limits[idx],
)
for idx, operation_target in enumerate(operation_targets)
]
)

operation_plans = [
StakeOperationPlan(
target=operation_target,
amount_to_stake=planned_amounts[idx],
extrinsic_fee=fee_estimates[idx],
price_with_tolerance=price_limits[idx],
)
for idx, operation_target in enumerate(operation_targets)
]

if not _validate_stake_plan_affordability(operation_plans, current_balance):
return None

return operation_plans

async def safe_stake_extrinsic(
netuid_: int,
amount_: Balance,
Expand Down Expand Up @@ -339,7 +469,7 @@ async def stake_extrinsic(
)

# Determine the amount we are staking.
operation_targets = []
operation_targets: list[StakeOperationTarget] = []
for hotkey in hotkeys_to_stake_to:
for netuid in netuids:
# Check that the subnet exists.
Expand All @@ -348,7 +478,12 @@ async def stake_extrinsic(
print_error(f"Subnet with netuid: {netuid} does not exist.")
continue
operation_targets.append(
(hotkey, netuid, subnet_info, hotkey_stake_map[hotkey[1]][netuid])
StakeOperationTarget(
staking_address=hotkey[1],
netuid=netuid,
subnet_info=subnet_info,
current_stake_balance=hotkey_stake_map[hotkey[1]][netuid],
)
)

if stake_all and not operation_targets:
Expand All @@ -359,31 +494,52 @@ async def stake_extrinsic(
operations = []
remaining_wallet_balance = current_wallet_balance
max_slippage = 0.0
requested_amounts: list[Balance] = []

for hotkey, netuid, subnet_info, current_stake_balance in operation_targets:
staking_address = hotkey[1]

for operation_target in operation_targets:
# Get the amount.
amount_to_stake = Balance(0)
if amount:
amount_to_stake = Balance.from_tao(amount)
elif stake_all:
amount_to_stake = current_wallet_balance / len(operation_targets)
amount_to_stake = Balance.from_rao(0)
elif not amount:
amount_to_stake, _ = _prompt_stake_amount(
current_balance=remaining_wallet_balance,
netuid=netuid,
netuid=operation_target.netuid,
action_name="stake",
)

# Check enough to stake.
if amount_to_stake > remaining_wallet_balance:
# Check enough to stake, excluding extrinsic fees (validated after fee-aware planning).
if not stake_all and amount_to_stake > remaining_wallet_balance:
print_error(
f"Not enough stake:[bold white]\n wallet balance:{remaining_wallet_balance} < "
f"staking amount: {amount_to_stake}[/bold white]"
)
return
remaining_wallet_balance -= amount_to_stake

if not stake_all:
remaining_wallet_balance -= amount_to_stake

requested_amounts.append(amount_to_stake)

operation_plans = await build_fee_aware_stake_plan(
operation_targets=operation_targets,
requested_amounts=requested_amounts,
current_balance=current_wallet_balance,
)
if operation_plans is None:
return

for operation_plan in operation_plans:
operation_target = operation_plan.target
netuid = operation_target.netuid
subnet_info = operation_target.subnet_info
staking_address = operation_target.staking_address
current_stake_balance = operation_target.current_stake_balance
amount_to_stake = operation_plan.amount_to_stake
price_with_tolerance = operation_plan.price_with_tolerance
extrinsic_fee = operation_plan.extrinsic_fee

# Calculate slippage
# TODO: Update for V3, slippage calculation is significantly different in v3
Expand All @@ -399,46 +555,31 @@ async def stake_extrinsic(
# Temporary workaround - calculations without slippage
current_price_float = float(subnet_info.price.tao)
rate = _safe_inverse_rate(current_price_float)
price_with_tolerance = None

# If we are staking safe, add price tolerance
if safe_staking:
if subnet_info.is_dynamic:
price_with_tolerance = current_price_float * (1 + rate_tolerance)
safe_price = price_with_tolerance or Balance.from_rao(0)
_rate_with_tolerance = _safe_inverse_rate(
price_with_tolerance
float(safe_price.tao)
) # Rate only for display
rate_with_tolerance = f"{_rate_with_tolerance:.4f}"
price_with_tolerance = Balance.from_tao(
price_with_tolerance
) # Actual price to pass to extrinsic
else:
rate_with_tolerance = "1"
price_with_tolerance = Balance.from_rao(1)
extrinsic_fee = await get_stake_extrinsic_fee(
netuid_=netuid,
amount_=amount_to_stake,
staking_address_=staking_address,
safe_staking_=safe_staking,
price_limit=price_with_tolerance,
)
row_extension = [
f"{rate_with_tolerance} {Balance.get_unit(netuid)}/{Balance.get_unit(0)} ",
f"[{'dark_sea_green3' if allow_partial_stake else 'red'}]"
# safe staking
f"{allow_partial_stake}[/{'dark_sea_green3' if allow_partial_stake else 'red'}]",
]
else:
extrinsic_fee = await get_stake_extrinsic_fee(
netuid_=netuid,
amount_=amount_to_stake,
staking_address_=staking_address,
safe_staking_=safe_staking,
)
row_extension = []

# TODO this should be asyncio gathered before the for loop
amount_minus_fee = (
(amount_to_stake - extrinsic_fee) if not proxy else amount_to_stake
amount_to_stake
if proxy
else max(amount_to_stake - extrinsic_fee, Balance.from_rao(0))
)
sim_swap = await subtensor.sim_swap(
origin_netuid=0,
Expand Down
Loading
Loading