Skip to content
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 42 additions & 11 deletions bittensor_cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4750,7 +4750,6 @@ def stake_add(
else:
exclude_hotkeys = []

# TODO: Ask amount for each subnet explicitly if more than one
if not stake_all and not amount:
free_balance = self._run_command(
wallets.wallet_balance(
Expand All @@ -4762,23 +4761,55 @@ def stake_add(
if free_balance == Balance.from_tao(0):
print_error("You dont have any balance to stake.")
return
if netuids:

# If netuids is provided and has multiple subnets, ask for amount per netuid
if netuids and len(netuids) > 1:
amounts = []
remaining_balance = free_balance
for netuid in netuids:
netuid_amount = FloatPrompt.ask(
f"Amount to [{COLORS.G.SUBHEAD_MAIN}]stake to netuid {netuid} (TAO τ)[/] "
f"[dim](remaining balance: {remaining_balance})[/dim]"
)
if netuid_amount <= 0:
print_error(
f"You entered an incorrect stake amount: {netuid_amount}"
)
raise typer.Exit()
if Balance.from_tao(netuid_amount) > remaining_balance:
print_error(
f"You dont have enough balance to stake. Remaining balance: {remaining_balance}."
)
raise typer.Exit()
amounts.append(netuid_amount)
remaining_balance -= Balance.from_tao(netuid_amount)
amount = amounts
elif netuids:
# Single netuid
amount = FloatPrompt.ask(
f"Amount to [{COLORS.G.SUBHEAD_MAIN}]stake (TAO τ)"
)
if amount <= 0:
print_error(f"You entered an incorrect stake amount: {amount}")
raise typer.Exit()
if Balance.from_tao(amount) > free_balance:
print_error(
f"You dont have enough balance to stake. Current free Balance: {free_balance}."
)
raise typer.Exit()
else:
# netuids is empty list or None (all subnets) - ask for amount per netuid
amount = FloatPrompt.ask(
f"Amount to [{COLORS.G.SUBHEAD_MAIN}]stake to each netuid (TAO τ)"
)

if amount <= 0:
print_error(f"You entered an incorrect stake amount: {amount}")
raise typer.Exit()
if Balance.from_tao(amount) > free_balance:
print_error(
f"You dont have enough balance to stake. Current free Balance: {free_balance}."
)
raise typer.Exit()
if amount <= 0:
print_error(f"You entered an incorrect stake amount: {amount}")
raise typer.Exit()
if Balance.from_tao(amount) > free_balance:
print_error(
f"You dont have enough balance to stake. Current free Balance: {free_balance}."
)
raise typer.Exit()
logger.debug(
"args:\n"
f"network: {network}\n"
Expand Down
20 changes: 15 additions & 5 deletions bittensor_cli/src/commands/stake/add.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections import defaultdict
from functools import partial

from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Union

from async_substrate_interface import AsyncExtrinsicReceipt
from rich.table import Table
Expand Down Expand Up @@ -39,7 +39,7 @@ async def stake_add(
subtensor: "SubtensorInterface",
netuids: Optional[list[int]],
stake_all: bool,
amount: float,
amount: Union[float, list[float]],
prompt: bool,
decline: bool,
quiet: bool,
Expand All @@ -60,7 +60,7 @@ async def stake_add(
subtensor: SubtensorInterface object
netuids: the netuids to stake to (None indicates all subnets)
stake_all: whether to stake all available balance
amount: specified amount of balance to stake
amount: specified amount of balance to stake (float for single amount, list[float] for per-netuid amounts)
prompt: whether to prompt the user
all_hotkeys: whether to stake all hotkeys
include_hotkeys: list of hotkeys to include in staking process (if not specifying `--all`)
Expand Down Expand Up @@ -351,8 +351,14 @@ async def stake_extrinsic(
remaining_wallet_balance = current_wallet_balance
max_slippage = 0.0

# Convert amount to a list if it's a list, otherwise use single amount for all netuids
amount_list = None
if isinstance(amount, list):
# amount is a list of amounts per netuid
amount_list = amount

for hotkey in hotkeys_to_stake_to:
for netuid in netuids:
for netuid_idx, netuid in enumerate(netuids):
# Check that the subnet exists.
subnet_info = all_subnets.get(netuid)
if not subnet_info:
Expand All @@ -362,7 +368,11 @@ async def stake_extrinsic(

# Get the amount.
amount_to_stake = Balance(0)
if amount:
if amount_list:
# Use the amount from the list for this specific netuid
amount_to_stake = Balance.from_tao(amount_list[netuid_idx])
elif amount:
# Single amount for all netuids
amount_to_stake = Balance.from_tao(amount)
elif stake_all:
amount_to_stake = current_wallet_balance / len(netuids)
Expand Down
51 changes: 51 additions & 0 deletions tests/e2e_tests/test_staking_sudo.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,57 @@ def line(key: str) -> Union[str, bool]:
assert line("error_messages") == ""
assert isinstance(line("extrinsic_ids"), str)

# Test staking with prompted amounts for each netuid
add_stake_prompted = exec_command_alice(
command="stake",
sub_command="add",
extra_args=[
"--netuids",
",".join(str(x) for x in multiple_netuids),
"--wallet-path",
wallet_path_alice,
"--wallet-name",
wallet_alice.name,
"--hotkey",
wallet_alice.hotkey_str,
"--chain",
"ws://127.0.0.1:9945",
"--tolerance",
"0.1",
"--partial",
"--era",
"32",
"--json-output",
"--no-prompt",
# Note: No --amount flag, will trigger prompts
],
inputs=["50", "30"], # 50 TAO for netuid 2, 30 TAO for netuid 3
)

# Verify prompts appeared in output
assert "stake to netuid 2" in add_stake_prompted.stdout
assert "stake to netuid 3" in add_stake_prompted.stdout
assert "remaining balance" in add_stake_prompted.stdout

# TODO: Parse and verify the final staking json output
# add_stake_prompted_output = json.loads(add_stake_prompted.stdout)
# for netuid_ in multiple_netuids:

# def line_prompted(key: str) -> Union[str, bool]:
# return add_stake_prompted_output[key][str(netuid_)][
# wallet_alice.hotkey.ss58_address
# ]

# assert line_prompted("staking_success") is True, (
# f"Staking to netuid {netuid_} should succeed"
# )
# assert line_prompted("error_messages") == "", (
# f"No error messages expected for netuid {netuid_}"
# )
# assert isinstance(line_prompted("extrinsic_ids"), str), (
# f"Extrinsic ID should be a string for netuid {netuid_}"
# )

# Fetch the hyperparameters of the subnet
hyperparams = exec_command_alice(
command="sudo",
Expand Down