Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions matrix/command.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
import asyncio
import inspect
import types

from typing import (
TYPE_CHECKING,
Any,
Union,
Optional,
Callable,
Coroutine,
List,
get_type_hints,
DefaultDict,
get_args,
get_origin,
)

from .errors import MissingArgumentError, CheckError, CooldownError
Expand Down Expand Up @@ -139,21 +142,46 @@ def _parse_arguments(self, ctx: "Context") -> list[Any]:

for i, param in enumerate(self.params):
param_type = self.type_hints.get(param.name, str)

if i >= len(args):
if param.default is not inspect.Parameter.empty:
parsed_args.append(param.default)
continue
raise MissingArgumentError(param)

if param.kind is inspect.Parameter.VAR_POSITIONAL:
parsed_args.extend(param_type(arg) for arg in args[i:])
parsed_args.extend(
self._convert_type(param_type, arg) for arg in args[i:]
)
return parsed_args

converted_arg = param_type(args[i])
converted_arg = self._convert_type(param_type, args[i])
parsed_args.append(converted_arg)

return parsed_args

def _convert_type(self, param_type: type, value: str) -> Any:
origin = get_origin(param_type)

if origin is Union or isinstance(param_type, types.UnionType):
union_types = get_args(param_type)

for union_type in union_types:
if union_type is type(None):
continue

try:
return union_type(value)
except (ValueError, TypeError):
continue

return value

try:
return param_type(value)
except (ValueError, TypeError):
return value

def check(self, func: Callback) -> None:
"""
Register a check callback
Expand Down
42 changes: 24 additions & 18 deletions matrix/help/help_command.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, List
from typing import Union, Optional, List
from abc import ABC, abstractmethod

from matrix.context import Context
Expand Down Expand Up @@ -187,43 +187,49 @@ async def show_help_page(self, ctx: Context, page_number: int = 1) -> None:
await ctx.reply(help_message)

def parse_help_arguments(
self, args: List[str]
self, args: List[str | int]
) -> tuple[Optional[str], Optional[str], int]:
"""Parse help command arguments to determine what to show.

:param args: List of arguments passed to help command
:return: Tuple of (command_name, subcommand_name, page_number)
"""
"""Parse help command arguments to determine what to show."""
command_name = None
subcommand_name = None
page_number = 1

if not args:
return command_name, subcommand_name, page_number

# Check if first argument is a page number
if len(args) == 1 and args[0].isdigit():
page_number = int(args[0])
first_arg = args[0]
if len(args) == 1 and (
isinstance(first_arg, int)
or (isinstance(first_arg, str) and first_arg.isdigit())
):
page_number = int(first_arg)
return command_name, subcommand_name, page_number

command_name = args[0]
command_name = str(first_arg)

if len(args) >= 2:
if args[1].isdigit():
page_number = int(args[1])
second_arg = args[1]
if isinstance(second_arg, int) or (
isinstance(second_arg, str) and second_arg.isdigit()
):
page_number = int(second_arg)
else:
subcommand_name = args[1]
subcommand_name = str(second_arg)

if len(args) >= 3 and args[2].isdigit():
page_number = int(args[2])
if len(args) >= 3:
third_arg = args[2]
if isinstance(third_arg, int) or (
isinstance(third_arg, str) and third_arg.isdigit()
):
page_number = int(third_arg)

return command_name, subcommand_name, page_number

async def execute(
self,
ctx: Context,
cmd_or_page: str | None = None,
subcommand: str | None = None,
cmd_or_page: Union[str, int, None] = None,
subcommand: Union[str | None] = None,
) -> None:
"""
Execute the help command using show_command_help and show_group_help.
Expand Down
87 changes: 87 additions & 0 deletions tests/test_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,3 +251,90 @@ async def always_fails(ctx):
with pytest.raises(Exception):
await cmd(ctx)
assert called is False


def test_parse_arguments_with_union_type__expect_successful_conversion():
async def my_command(ctx, value: str | int):
pass

cmd = Command(my_command)

ctx = DummyContext(args=["123"])
args = cmd._parse_arguments(ctx)
assert args[0] in [123, "123"] # Accept either, depending on Union order

ctx2 = DummyContext(args=["hello"])
args2 = cmd._parse_arguments(ctx2)
assert args2 == ["hello"]


def test_parse_arguments_with_optional_union__expect_default_none():
async def my_command(ctx, count: int | None = None):
pass

cmd = Command(my_command)

ctx = DummyContext(args=["42"])
args = cmd._parse_arguments(ctx)
assert args == [42]

ctx2 = DummyContext(args=[])
args2 = cmd._parse_arguments(ctx2)
assert args2 == [None]


def test_parse_arguments_with_multiple_union_types__expect_first_successful():
async def my_command(ctx, value: int | str):
pass

cmd = Command(my_command)

ctx = DummyContext(args=["42"])
args = cmd._parse_arguments(ctx)
assert args == [42]

ctx2 = DummyContext(args=["not-a-number"])
args2 = cmd._parse_arguments(ctx2)
assert args2 == ["not-a-number"]


def test_parse_arguments_with_union_and_default__expect_typed_conversion():
"""Test Union types with default values."""

async def my_command(ctx, port: int | str = 8080):
pass

cmd = Command(my_command)

ctx = DummyContext(args=["3000"])
args = cmd._parse_arguments(ctx)
assert args[0] in [3000, "3000"]

ctx2 = DummyContext(args=[])
args2 = cmd._parse_arguments(ctx2)
assert args2 == [8080]


def test_parse_arguments_with_union_var_positional__expect_all_converted():
async def my_command(ctx, *values: int | str):
pass

cmd = Command(my_command)

ctx = DummyContext(args=["1", "hello", "3"])
args = cmd._parse_arguments(ctx)

assert len(args) == 3
assert args[1] == "hello"


def test_parse_arguments_with_union_conversion_failure__expect_string_fallback():
async def my_command(ctx, value: int | float):
pass

cmd = Command(my_command)

ctx = DummyContext(args=["hello"])
args = cmd._parse_arguments(ctx)

assert args == ["hello"]