diff --git a/bittensor_cli/cli.py b/bittensor_cli/cli.py index 624793887..35543a7ca 100755 --- a/bittensor_cli/cli.py +++ b/bittensor_cli/cli.py @@ -4516,6 +4516,11 @@ def stake_add( amount: float = typer.Option( 0.0, "--amount", help="The amount of TAO to stake" ), + amounts: str = typer.Option( + "", + "--amounts", + help="Comma-separated amounts of TAO to stake for each netuid. Must be used with --netuids and the number of amounts must match the number of netuids. Example: --netuids 1,2,3 --amounts 0.1,0.2,0.3", + ), include_hotkeys: str = typer.Option( "", "--include-hotkeys", @@ -4586,7 +4591,10 @@ def stake_add( 7. Stake the same amount to multiple subnets: [green]$[/green] btcli stake add --amount 100 --netuids 4,5,6 - 8. Stake without MEV protection: + 8. Stake different amounts to multiple subnets: + [green]$[/green] btcli stake add --netuids 1,2,3 --amounts 0.1,0.2,0.3 + + 9. Stake without MEV protection: [green]$[/green] btcli stake add --amount 100 --netuid 1 --no-mev-protection [bold]Safe Staking Parameters:[/bold] @@ -4616,12 +4624,57 @@ def stake_add( # ensure no negative netuids make it into our list validate_netuid(netuid_) + # Validate mutually exclusive options + if amount and amounts: + print_error( + "Cannot specify both --amount and --amounts. Use --amount for single amount or --amounts for per-netuid amounts." + ) + return + if stake_all and amount: print_error( "Cannot specify an amount and 'stake-all'. Choose one or the other." ) return + if stake_all and amounts: + print_error( + "Cannot specify --amounts and 'stake-all'. Choose one or the other." + ) + return + + # Parse and validate --amounts if provided + amounts_list = None + if amounts: + if not netuids or len(netuids) == 0: + print_error( + "--amounts can only be used with --netuids. Please specify netuids." + ) + return + try: + amounts_list = parse_to_list( + amounts, + float, + "Amounts must be numbers separated by commas, e.g., `--amounts 0.1,0.2,0.3`.", + False, + ) + if len(amounts_list) != len(netuids): + print_error( + f"Number of amounts ({len(amounts_list)}) must match number of netuids ({len(netuids)}). " + f"Netuids: {netuids}, Amounts: {amounts_list}" + ) + return + # Validate all amounts are positive + for amt in amounts_list: + if amt <= 0: + print_error( + f"All amounts must be positive. Invalid amount: {amt}" + ) + return + except Exception as e: + print_error(f"Failed to parse amounts: {e}") + return + if stake_all and not amount: if not confirm_action( "Stake all the available TAO tokens?", @@ -4747,8 +4800,10 @@ def stake_add( else: exclude_hotkeys = [] - # TODO: Ask amount for each subnet explicitly if more than one - if not stake_all and not amount: + # Use amounts_list if provided via --amounts flag + if amounts_list: + amount = amounts_list + elif not stake_all and not amount: free_balance = self._run_command( wallets.wallet_balance( wallet, self.initialize_chain(network), False, None @@ -4759,23 +4814,55 @@ def stake_add( if free_balance == Balance.from_tao(0): print_error("You dont have any balance to stake.") return - if netuids: + + # If netuids is provided and has multiple subnets, ask for amount per netuid + if netuids and len(netuids) > 1: + amounts_prompted = [] + remaining_balance = free_balance + for netuid in netuids: + netuid_amount = FloatPrompt.ask( + f"Amount to [{COLORS.G.SUBHEAD_MAIN}]stake to netuid {netuid} (TAO τ)[/] " + f"[dim](remaining balance: {remaining_balance})[/dim]" + ) + if netuid_amount <= 0: + print_error( + f"You entered an incorrect stake amount: {netuid_amount}" + ) + raise typer.Exit() + if Balance.from_tao(netuid_amount) > remaining_balance: + print_error( + f"You dont have enough balance to stake. Remaining balance: {remaining_balance}." + ) + raise typer.Exit() + amounts_prompted.append(netuid_amount) + remaining_balance -= Balance.from_tao(netuid_amount) + amount = amounts_prompted + elif netuids: + # Single netuid amount = FloatPrompt.ask( f"Amount to [{COLORS.G.SUBHEAD_MAIN}]stake (TAO τ)" ) + if amount <= 0: + print_error(f"You entered an incorrect stake amount: {amount}") + raise typer.Exit() + if Balance.from_tao(amount) > free_balance: + print_error( + f"You dont have enough balance to stake. Current free Balance: {free_balance}." + ) + raise typer.Exit() else: + # netuids is empty list or None (all subnets) - ask for amount per netuid amount = FloatPrompt.ask( f"Amount to [{COLORS.G.SUBHEAD_MAIN}]stake to each netuid (TAO τ)" ) - - if amount <= 0: - print_error(f"You entered an incorrect stake amount: {amount}") - raise typer.Exit() - if Balance.from_tao(amount) > free_balance: - print_error( - f"You dont have enough balance to stake. Current free Balance: {free_balance}." - ) - raise typer.Exit() + if amount <= 0: + print_error(f"You entered an incorrect stake amount: {amount}") + raise typer.Exit() + if Balance.from_tao(amount) > free_balance: + print_error( + f"You dont have enough balance to stake. Current free Balance: {free_balance}." + ) + raise typer.Exit() logger.debug( "args:\n" f"network: {network}\n" diff --git a/bittensor_cli/src/commands/stake/add.py b/bittensor_cli/src/commands/stake/add.py index 18c6578eb..958bd6fee 100644 --- a/bittensor_cli/src/commands/stake/add.py +++ b/bittensor_cli/src/commands/stake/add.py @@ -2,7 +2,7 @@ from collections import defaultdict from functools import partial -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Union from async_substrate_interface import AsyncExtrinsicReceipt from rich.table import Table @@ -38,7 +38,7 @@ async def stake_add( subtensor: "SubtensorInterface", netuids: Optional[list[int]], stake_all: bool, - amount: float, + amount: Union[float, list[float]], prompt: bool, decline: bool, quiet: bool, @@ -59,7 +59,7 @@ async def stake_add( subtensor: SubtensorInterface object netuids: the netuids to stake to (None indicates all subnets) stake_all: whether to stake all available balance - amount: specified amount of balance to stake + amount: specified amount of balance to stake (float for single amount, list[float] for per-netuid amounts) prompt: whether to prompt the user all_hotkeys: whether to stake all hotkeys include_hotkeys: list of hotkeys to include in staking process (if not specifying `--all`) @@ -350,8 +350,12 @@ async def stake_extrinsic( remaining_wallet_balance = current_wallet_balance max_slippage = 0.0 + amount_list = [] + if isinstance(amount, list): + amount_list = amount + for hotkey in hotkeys_to_stake_to: - for netuid in netuids: + for netuid_idx, netuid in enumerate(netuids): # Check that the subnet exists. subnet_info = all_subnets.get(netuid) if not subnet_info: @@ -361,7 +365,11 @@ async def stake_extrinsic( # Get the amount. amount_to_stake = Balance(0) - if amount: + if amount_list: + # Use the amount from the list for this specific netuid + amount_to_stake = Balance.from_tao(amount_list[netuid_idx]) + elif amount: + # Single amount for all netuids amount_to_stake = Balance.from_tao(amount) elif stake_all: amount_to_stake = current_wallet_balance / len(netuids) @@ -373,15 +381,6 @@ async def stake_extrinsic( ) amounts_to_stake.append(amount_to_stake) - # Check enough to stake. - if amount_to_stake > remaining_wallet_balance: - print_error( - f"Not enough stake:[bold white]\n wallet balance:{remaining_wallet_balance} < " - f"staking amount: {amount_to_stake}[/bold white]" - ) - return - remaining_wallet_balance -= amount_to_stake - # Calculate slippage # TODO: Update for V3, slippage calculation is significantly different in v3 # try: @@ -433,6 +432,20 @@ async def stake_extrinsic( safe_staking_=safe_staking, ) row_extension = [] + + # Check enough balance to cover stake amount and extrinsic fee + total_cost = ( + amount_to_stake + extrinsic_fee if not proxy else amount_to_stake + ) + if total_cost > remaining_wallet_balance: + print_error( + f"[red]Not enough stake[/red]:[bold white]\n wallet balance: {remaining_wallet_balance} < " + f"staking amount: {amount_to_stake}[/bold white]" + ) + return + + # Deduct stake amount and extrinsic fee from remaining balance + remaining_wallet_balance -= total_cost # TODO this should be asyncio gathered before the for loop amount_minus_fee = ( (amount_to_stake - extrinsic_fee) if not proxy else amount_to_stake diff --git a/tests/e2e_tests/test_staking_sudo.py b/tests/e2e_tests/test_staking_sudo.py index e76ff1627..d4c44fa94 100644 --- a/tests/e2e_tests/test_staking_sudo.py +++ b/tests/e2e_tests/test_staking_sudo.py @@ -490,6 +490,107 @@ def line(key: str) -> Union[str, bool]: assert line("error_messages") == "" assert isinstance(line("extrinsic_ids"), str) + # Test staking with prompted amounts for each netuid + add_stake_prompted = exec_command_alice( + command="stake", + sub_command="add", + extra_args=[ + "--netuids", + ",".join(str(x) for x in multiple_netuids), + "--wallet-path", + wallet_path_alice, + "--wallet-name", + wallet_alice.name, + "--hotkey", + wallet_alice.hotkey_str, + "--chain", + "ws://127.0.0.1:9945", + "--tolerance", + "0.1", + "--partial", + "--era", + "32", + "--json-output", + "--no-prompt", + # Note: No --amount or --amounts flag, will trigger prompts + ], + inputs=["50", "30"], # 50 TAO for netuid 2, 30 TAO for netuid 3 + ) + + # Verify prompts appeared in output + assert "stake to netuid 2" in add_stake_prompted.stdout + assert "stake to netuid 3" in add_stake_prompted.stdout + assert "remaining balance" in add_stake_prompted.stdout + + # Extract JSON from stdout (prompts are mixed with JSON output) + json_match = re.search(r"\{.*\}", add_stake_prompted.stdout, re.DOTALL) + if json_match: + json_str = json_match.group(0) + add_stake_prompted_output = json.loads(json_str) + + for netuid_ in multiple_netuids: + + def line_prompted(key: str) -> Union[str, bool]: + return add_stake_prompted_output[key][str(netuid_)][ + wallet_alice.hotkey.ss58_address + ] + + assert line_prompted("staking_success") is True, ( + f"Staking to netuid {netuid_} should succeed" + ) + assert line_prompted("error_messages") == "", ( + f"No error messages expected for netuid {netuid_}" + ) + assert isinstance(line_prompted("extrinsic_ids"), str), ( + f"Extrinsic ID should be a string for netuid {netuid_}" + ) + + # Test staking with --amounts option for different amounts per netuid + add_stake_amounts = exec_command_alice( + command="stake", + sub_command="add", + extra_args=[ + "--netuids", + ",".join(str(x) for x in multiple_netuids), + "--amounts", + "25,15", # 25 TAO for netuid 2, 15 TAO for netuid 3 + "--wallet-path", + wallet_path_alice, + "--wallet-name", + wallet_alice.name, + "--hotkey", + wallet_alice.hotkey_str, + "--chain", + "ws://127.0.0.1:9945", + "--tolerance", + "0.1", + "--partial", + "--era", + "32", + "--json-output", + "--no-prompt", + ], + ) + + # Parse and verify the staking results for --amounts + add_stake_amounts_output = json.loads(add_stake_amounts.stdout) + for netuid_ in multiple_netuids: + + def line_amounts(key: str) -> Union[str, bool]: + return add_stake_amounts_output[key][str(netuid_)][ + wallet_alice.hotkey.ss58_address + ] + + assert line_amounts("staking_success") is True, ( + f"Staking with --amounts to netuid {netuid_} should succeed" + ) + assert line_amounts("error_messages") == "", ( + f"No error messages expected for netuid {netuid_} with --amounts" + ) + assert isinstance(line_amounts("extrinsic_ids"), str), ( + f"Extrinsic ID should be a string for netuid {netuid_} with --amounts" + ) + # Fetch the hyperparameters of the subnet hyperparams = exec_command_alice( command="sudo",