Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions matrix/bot.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import time
import inspect
import asyncio
import logging

Expand Down Expand Up @@ -113,7 +114,7 @@ def check(self, func: Callback) -> None:

:raises TypeError: If the function is not a coroutine.
"""
if not asyncio.iscoroutinefunction(func):
if not inspect.iscoroutinefunction(func):
Copy link
Contributor Author

@chrisdedman chrisdedman Feb 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reason of the change
Image

raise TypeError("Checks must be coroutine")

self.checks.append(func)
Expand Down Expand Up @@ -161,7 +162,7 @@ async def welcome(room, event):
"""

def wrapper(f: Callback) -> Callback:
if not asyncio.iscoroutinefunction(f):
if not inspect.iscoroutinefunction(f):
raise TypeError("Event handlers must be coroutines")

if event_spec:
Expand Down Expand Up @@ -285,7 +286,7 @@ def schedule(self, cron: str) -> Callable[..., Callback]:
"""

def wrapper(f: Callback) -> Callback:
if not asyncio.iscoroutinefunction(f):
if not inspect.iscoroutinefunction(f):
raise TypeError("Scheduled tasks must be coroutines")

self.scheduler.schedule(cron, f)
Expand Down Expand Up @@ -324,7 +325,7 @@ def error(self, exception: Optional[type[Exception]] = None) -> Callable:
"""

def wrapper(func: ErrorCallback) -> Callable:
if not asyncio.iscoroutinefunction(func):
if not inspect.iscoroutinefunction(func):
raise TypeError("The error handler must be a coroutine.")

if exception:
Expand All @@ -351,7 +352,7 @@ def _auto_register_events(self) -> None:
if not attr.startswith("on_"):
continue
coro = getattr(self, attr, None)
if asyncio.iscoroutinefunction(coro):
if inspect.iscoroutinefunction(coro):
try:
self.event(coro)
except ValueError: # ignore unknown name
Expand Down
25 changes: 14 additions & 11 deletions matrix/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def callback(self, func: Callback) -> None:
:type func: Callback
:raises TypeError: If the provided function is not a coroutine.
"""
if not asyncio.iscoroutinefunction(func):
if not inspect.iscoroutinefunction(func):
raise TypeError("Commands must be coroutines")

self._callback = func
Expand Down Expand Up @@ -134,19 +134,22 @@ def _build_usage(self) -> str:
return f"{self.prefix}{command_name} {params}"

def _parse_arguments(self, ctx: "Context") -> list[Any]:
args = ctx.args
parsed_args = []

for i, param in enumerate(self.params):
param_type = self.type_hints.get(param.name, str)

if i >= len(ctx.args):
if i >= len(args):
if param.default is not inspect.Parameter.empty:
parsed_args.append(param.default)
else:
raise MissingArgumentError(param)
continue
continue
raise MissingArgumentError(param)

if param.kind is inspect.Parameter.VAR_POSITIONAL:
parsed_args.extend(param_type(arg) for arg in args[i:])
return parsed_args

converted_arg = param_type(ctx.args[i])
converted_arg = param_type(args[i])
parsed_args.append(converted_arg)

return parsed_args
Expand All @@ -160,7 +163,7 @@ def check(self, func: Callback) -> None:

:raises TypeError: If the function is not a coroutine.
"""
if not asyncio.iscoroutinefunction(func):
if not inspect.iscoroutinefunction(func):
raise TypeError("Checks must be coroutine")

self.checks.append(func)
Expand Down Expand Up @@ -202,7 +205,7 @@ def before_invoke(self, func: Callback) -> None:
:raises TypeError: If the function is not a coroutine.
"""

if not asyncio.iscoroutinefunction(func):
if not inspect.iscoroutinefunction(func):
raise TypeError("The hook must be a coroutine.")

self._before_invoke_callback = func
Expand All @@ -217,7 +220,7 @@ def after_invoke(self, func: Callback) -> None:
:raises TypeError: If the function is not a coroutine.
"""

if not asyncio.iscoroutinefunction(func):
if not inspect.iscoroutinefunction(func):
raise TypeError("The hook must be a coroutine.")

self._after_invoke_callback = func
Expand All @@ -234,7 +237,7 @@ def error(self, exception: Optional[type[Exception]] = None) -> Callable:
"""

def wrapper(func: ErrorCallback) -> Callable:
if not asyncio.iscoroutinefunction(func):
if not inspect.iscoroutinefunction(func):
raise TypeError("The error handler must be a coroutine.")

if exception:
Expand Down
16 changes: 16 additions & 0 deletions tests/test_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,22 @@ async def my_command(ctx, a: int, b: str = "default"):
cmd._parse_arguments(ctx3)


def test_parse_var_positional_arguments():
async def my_command(ctx, *words: str):
pass

cmd = Command(my_command)
ctx = DummyContext(args=["hello", "matrix", "world"])

args = cmd._parse_arguments(ctx)

assert args == ["hello", "matrix", "world"]

ctx2 = DummyContext(args=[])
with pytest.raises(MissingArgumentError):
cmd._parse_arguments(ctx2)


@pytest.mark.asyncio
async def test_command_call():
called = False
Expand Down