diff --git a/matrix/bot.py b/matrix/bot.py index ab768d4..2dfb7cd 100644 --- a/matrix/bot.py +++ b/matrix/bot.py @@ -1,4 +1,5 @@ import time +import inspect import asyncio import logging @@ -113,7 +114,7 @@ def check(self, func: Callback) -> None: :raises TypeError: If the function is not a coroutine. """ - if not asyncio.iscoroutinefunction(func): + if not inspect.iscoroutinefunction(func): raise TypeError("Checks must be coroutine") self.checks.append(func) @@ -161,7 +162,7 @@ async def welcome(room, event): """ def wrapper(f: Callback) -> Callback: - if not asyncio.iscoroutinefunction(f): + if not inspect.iscoroutinefunction(f): raise TypeError("Event handlers must be coroutines") if event_spec: @@ -285,7 +286,7 @@ def schedule(self, cron: str) -> Callable[..., Callback]: """ def wrapper(f: Callback) -> Callback: - if not asyncio.iscoroutinefunction(f): + if not inspect.iscoroutinefunction(f): raise TypeError("Scheduled tasks must be coroutines") self.scheduler.schedule(cron, f) @@ -324,7 +325,7 @@ def error(self, exception: Optional[type[Exception]] = None) -> Callable: """ def wrapper(func: ErrorCallback) -> Callable: - if not asyncio.iscoroutinefunction(func): + if not inspect.iscoroutinefunction(func): raise TypeError("The error handler must be a coroutine.") if exception: @@ -351,7 +352,7 @@ def _auto_register_events(self) -> None: if not attr.startswith("on_"): continue coro = getattr(self, attr, None) - if asyncio.iscoroutinefunction(coro): + if inspect.iscoroutinefunction(coro): try: self.event(coro) except ValueError: # ignore unknown name diff --git a/matrix/command.py b/matrix/command.py index f8809a6..5c66237 100644 --- a/matrix/command.py +++ b/matrix/command.py @@ -98,7 +98,7 @@ def callback(self, func: Callback) -> None: :type func: Callback :raises TypeError: If the provided function is not a coroutine. """ - if not asyncio.iscoroutinefunction(func): + if not inspect.iscoroutinefunction(func): raise TypeError("Commands must be coroutines") self._callback = func @@ -134,19 +134,22 @@ def _build_usage(self) -> str: return f"{self.prefix}{command_name} {params}" def _parse_arguments(self, ctx: "Context") -> list[Any]: + args = ctx.args parsed_args = [] for i, param in enumerate(self.params): param_type = self.type_hints.get(param.name, str) - - if i >= len(ctx.args): + if i >= len(args): if param.default is not inspect.Parameter.empty: parsed_args.append(param.default) - else: - raise MissingArgumentError(param) - continue + continue + raise MissingArgumentError(param) + + if param.kind is inspect.Parameter.VAR_POSITIONAL: + parsed_args.extend(param_type(arg) for arg in args[i:]) + return parsed_args - converted_arg = param_type(ctx.args[i]) + converted_arg = param_type(args[i]) parsed_args.append(converted_arg) return parsed_args @@ -160,7 +163,7 @@ def check(self, func: Callback) -> None: :raises TypeError: If the function is not a coroutine. """ - if not asyncio.iscoroutinefunction(func): + if not inspect.iscoroutinefunction(func): raise TypeError("Checks must be coroutine") self.checks.append(func) @@ -202,7 +205,7 @@ def before_invoke(self, func: Callback) -> None: :raises TypeError: If the function is not a coroutine. """ - if not asyncio.iscoroutinefunction(func): + if not inspect.iscoroutinefunction(func): raise TypeError("The hook must be a coroutine.") self._before_invoke_callback = func @@ -217,7 +220,7 @@ def after_invoke(self, func: Callback) -> None: :raises TypeError: If the function is not a coroutine. """ - if not asyncio.iscoroutinefunction(func): + if not inspect.iscoroutinefunction(func): raise TypeError("The hook must be a coroutine.") self._after_invoke_callback = func @@ -234,7 +237,7 @@ def error(self, exception: Optional[type[Exception]] = None) -> Callable: """ def wrapper(func: ErrorCallback) -> Callable: - if not asyncio.iscoroutinefunction(func): + if not inspect.iscoroutinefunction(func): raise TypeError("The error handler must be a coroutine.") if exception: diff --git a/tests/test_command.py b/tests/test_command.py index 4a9326c..ac9c4ba 100644 --- a/tests/test_command.py +++ b/tests/test_command.py @@ -88,6 +88,22 @@ async def my_command(ctx, a: int, b: str = "default"): cmd._parse_arguments(ctx3) +def test_parse_var_positional_arguments(): + async def my_command(ctx, *words: str): + pass + + cmd = Command(my_command) + ctx = DummyContext(args=["hello", "matrix", "world"]) + + args = cmd._parse_arguments(ctx) + + assert args == ["hello", "matrix", "world"] + + ctx2 = DummyContext(args=[]) + with pytest.raises(MissingArgumentError): + cmd._parse_arguments(ctx2) + + @pytest.mark.asyncio async def test_command_call(): called = False