diff --git a/matrix/command.py b/matrix/command.py index 5c66237..20c0e62 100644 --- a/matrix/command.py +++ b/matrix/command.py @@ -1,15 +1,18 @@ -import asyncio import inspect +import types from typing import ( TYPE_CHECKING, Any, + Union, Optional, Callable, Coroutine, List, get_type_hints, DefaultDict, + get_args, + get_origin, ) from .errors import MissingArgumentError, CheckError, CooldownError @@ -139,6 +142,7 @@ def _parse_arguments(self, ctx: "Context") -> list[Any]: for i, param in enumerate(self.params): param_type = self.type_hints.get(param.name, str) + if i >= len(args): if param.default is not inspect.Parameter.empty: parsed_args.append(param.default) @@ -146,14 +150,38 @@ def _parse_arguments(self, ctx: "Context") -> list[Any]: raise MissingArgumentError(param) if param.kind is inspect.Parameter.VAR_POSITIONAL: - parsed_args.extend(param_type(arg) for arg in args[i:]) + parsed_args.extend( + self._convert_type(param_type, arg) for arg in args[i:] + ) return parsed_args - converted_arg = param_type(args[i]) + converted_arg = self._convert_type(param_type, args[i]) parsed_args.append(converted_arg) return parsed_args + def _convert_type(self, param_type: type, value: str) -> Any: + origin = get_origin(param_type) + + if origin is Union or isinstance(param_type, types.UnionType): + union_types = get_args(param_type) + + for union_type in union_types: + if union_type is type(None): + continue + + try: + return union_type(value) + except (ValueError, TypeError): + continue + + return value + + try: + return param_type(value) + except (ValueError, TypeError): + return value + def check(self, func: Callback) -> None: """ Register a check callback diff --git a/matrix/help/help_command.py b/matrix/help/help_command.py index 231a333..e44521a 100644 --- a/matrix/help/help_command.py +++ b/matrix/help/help_command.py @@ -1,4 +1,4 @@ -from typing import Optional, List +from typing import Union, Optional, List from abc import ABC, abstractmethod from matrix.context import Context @@ -187,13 +187,9 @@ async def show_help_page(self, ctx: Context, page_number: int = 1) -> None: await ctx.reply(help_message) def parse_help_arguments( - self, args: List[str] + self, args: List[str | int] ) -> tuple[Optional[str], Optional[str], int]: - """Parse help command arguments to determine what to show. - - :param args: List of arguments passed to help command - :return: Tuple of (command_name, subcommand_name, page_number) - """ + """Parse help command arguments to determine what to show.""" command_name = None subcommand_name = None page_number = 1 @@ -201,29 +197,39 @@ def parse_help_arguments( if not args: return command_name, subcommand_name, page_number - # Check if first argument is a page number - if len(args) == 1 and args[0].isdigit(): - page_number = int(args[0]) + first_arg = args[0] + if len(args) == 1 and ( + isinstance(first_arg, int) + or (isinstance(first_arg, str) and first_arg.isdigit()) + ): + page_number = int(first_arg) return command_name, subcommand_name, page_number - command_name = args[0] + command_name = str(first_arg) if len(args) >= 2: - if args[1].isdigit(): - page_number = int(args[1]) + second_arg = args[1] + if isinstance(second_arg, int) or ( + isinstance(second_arg, str) and second_arg.isdigit() + ): + page_number = int(second_arg) else: - subcommand_name = args[1] + subcommand_name = str(second_arg) - if len(args) >= 3 and args[2].isdigit(): - page_number = int(args[2]) + if len(args) >= 3: + third_arg = args[2] + if isinstance(third_arg, int) or ( + isinstance(third_arg, str) and third_arg.isdigit() + ): + page_number = int(third_arg) return command_name, subcommand_name, page_number async def execute( self, ctx: Context, - cmd_or_page: str | None = None, - subcommand: str | None = None, + cmd_or_page: Union[str, int, None] = None, + subcommand: Union[str | None] = None, ) -> None: """ Execute the help command using show_command_help and show_group_help. diff --git a/tests/test_command.py b/tests/test_command.py index ac9c4ba..5e729a8 100644 --- a/tests/test_command.py +++ b/tests/test_command.py @@ -251,3 +251,90 @@ async def always_fails(ctx): with pytest.raises(Exception): await cmd(ctx) assert called is False + + +def test_parse_arguments_with_union_type__expect_successful_conversion(): + async def my_command(ctx, value: str | int): + pass + + cmd = Command(my_command) + + ctx = DummyContext(args=["123"]) + args = cmd._parse_arguments(ctx) + assert args[0] in [123, "123"] # Accept either, depending on Union order + + ctx2 = DummyContext(args=["hello"]) + args2 = cmd._parse_arguments(ctx2) + assert args2 == ["hello"] + + +def test_parse_arguments_with_optional_union__expect_default_none(): + async def my_command(ctx, count: int | None = None): + pass + + cmd = Command(my_command) + + ctx = DummyContext(args=["42"]) + args = cmd._parse_arguments(ctx) + assert args == [42] + + ctx2 = DummyContext(args=[]) + args2 = cmd._parse_arguments(ctx2) + assert args2 == [None] + + +def test_parse_arguments_with_multiple_union_types__expect_first_successful(): + async def my_command(ctx, value: int | str): + pass + + cmd = Command(my_command) + + ctx = DummyContext(args=["42"]) + args = cmd._parse_arguments(ctx) + assert args == [42] + + ctx2 = DummyContext(args=["not-a-number"]) + args2 = cmd._parse_arguments(ctx2) + assert args2 == ["not-a-number"] + + +def test_parse_arguments_with_union_and_default__expect_typed_conversion(): + """Test Union types with default values.""" + + async def my_command(ctx, port: int | str = 8080): + pass + + cmd = Command(my_command) + + ctx = DummyContext(args=["3000"]) + args = cmd._parse_arguments(ctx) + assert args[0] in [3000, "3000"] + + ctx2 = DummyContext(args=[]) + args2 = cmd._parse_arguments(ctx2) + assert args2 == [8080] + + +def test_parse_arguments_with_union_var_positional__expect_all_converted(): + async def my_command(ctx, *values: int | str): + pass + + cmd = Command(my_command) + + ctx = DummyContext(args=["1", "hello", "3"]) + args = cmd._parse_arguments(ctx) + + assert len(args) == 3 + assert args[1] == "hello" + + +def test_parse_arguments_with_union_conversion_failure__expect_string_fallback(): + async def my_command(ctx, value: int | float): + pass + + cmd = Command(my_command) + + ctx = DummyContext(args=["hello"]) + args = cmd._parse_arguments(ctx) + + assert args == ["hello"]