diff --git a/.gitignore b/.gitignore index 2549060e7..b28dfa9ed 100644 --- a/.gitignore +++ b/.gitignore @@ -17,6 +17,9 @@ venv/ .venv* .env/ +# claude +.claude/*.local.json + # codecov / coverage .coverage cov_* diff --git a/slack_bolt/__init__.py b/slack_bolt/__init__.py index 6331925f8..4e43252fd 100644 --- a/slack_bolt/__init__.py +++ b/slack_bolt/__init__.py @@ -21,6 +21,7 @@ from .response import BoltResponse # AI Agents & Assistants +from .agent import BoltAgent from .middleware.assistant.assistant import ( Assistant, ) @@ -46,6 +47,7 @@ "CustomListenerMatcher", "BoltRequest", "BoltResponse", + "BoltAgent", "Assistant", "AssistantThreadContext", "AssistantThreadContextStore", diff --git a/slack_bolt/adapter/__init__.py b/slack_bolt/adapter/__init__.py index f339226bc..9ca556e52 100644 --- a/slack_bolt/adapter/__init__.py +++ b/slack_bolt/adapter/__init__.py @@ -1,2 +1 @@ -"""Adapter modules for running Bolt apps along with Web frameworks or Socket Mode. -""" +"""Adapter modules for running Bolt apps along with Web frameworks or Socket Mode.""" diff --git a/slack_bolt/agent/__init__.py b/slack_bolt/agent/__init__.py new file mode 100644 index 000000000..4d751f27f --- /dev/null +++ b/slack_bolt/agent/__init__.py @@ -0,0 +1,5 @@ +from .agent import BoltAgent + +__all__ = [ + "BoltAgent", +] diff --git a/slack_bolt/agent/agent.py b/slack_bolt/agent/agent.py new file mode 100644 index 000000000..b6b3deeeb --- /dev/null +++ b/slack_bolt/agent/agent.py @@ -0,0 +1,82 @@ +from typing import Optional + +from slack_sdk import WebClient +from slack_sdk.web.chat_stream import ChatStream + + +class BoltAgent: + """Agent listener argument for building AI-powered Slack agents. + + Experimental: + This API is experimental and may change in future releases. + + FIXME: chat_stream() only works when thread_ts is available (DMs and threaded replies). + It does not work on channel messages because ts is not provided to BoltAgent yet. + + @app.event("app_mention") + def handle_mention(agent): + stream = agent.chat_stream() + stream.append(markdown_text="Hello!") + stream.stop() + """ + + def __init__( + self, + *, + client: WebClient, + channel_id: Optional[str] = None, + thread_ts: Optional[str] = None, + team_id: Optional[str] = None, + user_id: Optional[str] = None, + ): + self._client = client + self._channel_id = channel_id + self._thread_ts = thread_ts + self._team_id = team_id + self._user_id = user_id + + def chat_stream( + self, + *, + channel: Optional[str] = None, + thread_ts: Optional[str] = None, + recipient_team_id: Optional[str] = None, + recipient_user_id: Optional[str] = None, + **kwargs, + ) -> ChatStream: + """Creates a ChatStream with defaults from event context. + + Each call creates a new instance. Create multiple for parallel streams. + + Args: + channel: Channel ID. Defaults to the channel from the event context. + thread_ts: Thread timestamp. Defaults to the thread_ts from the event context. + recipient_team_id: Team ID of the recipient. Defaults to the team from the event context. + recipient_user_id: User ID of the recipient. Defaults to the user from the event context. + **kwargs: Additional arguments passed to ``WebClient.chat_stream()``. + + Returns: + A new ``ChatStream`` instance. + """ + provided = [arg for arg in (channel, thread_ts, recipient_team_id, recipient_user_id) if arg is not None] + if provided and len(provided) < 4: + raise ValueError( + "Either provide all of channel, thread_ts, recipient_team_id, and recipient_user_id, or none of them" + ) + resolved_channel = channel or self._channel_id + resolved_thread_ts = thread_ts or self._thread_ts + if resolved_channel is None: + raise ValueError( + "channel is required: provide it as an argument or ensure channel_id is set in the event context" + ) + if resolved_thread_ts is None: + raise ValueError( + "thread_ts is required: provide it as an argument or ensure thread_ts is set in the event context" + ) + return self._client.chat_stream( + channel=resolved_channel, + thread_ts=resolved_thread_ts, + recipient_team_id=recipient_team_id or self._team_id, + recipient_user_id=recipient_user_id or self._user_id, + **kwargs, + ) diff --git a/slack_bolt/agent/async_agent.py b/slack_bolt/agent/async_agent.py new file mode 100644 index 000000000..425f8dff4 --- /dev/null +++ b/slack_bolt/agent/async_agent.py @@ -0,0 +1,79 @@ +from typing import Optional + +from slack_sdk.web.async_client import AsyncWebClient +from slack_sdk.web.async_chat_stream import AsyncChatStream + + +class AsyncBoltAgent: + """Async agent listener argument for building AI-powered Slack agents. + + Experimental: + This API is experimental and may change in future releases. + + @app.event("app_mention") + async def handle_mention(agent): + stream = await agent.chat_stream() + await stream.append(markdown_text="Hello!") + await stream.stop() + """ + + def __init__( + self, + *, + client: AsyncWebClient, + channel_id: Optional[str] = None, + thread_ts: Optional[str] = None, + team_id: Optional[str] = None, + user_id: Optional[str] = None, + ): + self._client = client + self._channel_id = channel_id + self._thread_ts = thread_ts + self._team_id = team_id + self._user_id = user_id + + async def chat_stream( + self, + *, + channel: Optional[str] = None, + thread_ts: Optional[str] = None, + recipient_team_id: Optional[str] = None, + recipient_user_id: Optional[str] = None, + **kwargs, + ) -> AsyncChatStream: + """Creates an AsyncChatStream with defaults from event context. + + Each call creates a new instance. Create multiple for parallel streams. + + Args: + channel: Channel ID. Defaults to the channel from the event context. + thread_ts: Thread timestamp. Defaults to the thread_ts from the event context. + recipient_team_id: Team ID of the recipient. Defaults to the team from the event context. + recipient_user_id: User ID of the recipient. Defaults to the user from the event context. + **kwargs: Additional arguments passed to ``AsyncWebClient.chat_stream()``. + + Returns: + A new ``AsyncChatStream`` instance. + """ + provided = [arg for arg in (channel, thread_ts, recipient_team_id, recipient_user_id) if arg is not None] + if provided and len(provided) < 4: + raise ValueError( + "Either provide all of channel, thread_ts, recipient_team_id, and recipient_user_id, or none of them" + ) + resolved_channel = channel or self._channel_id + resolved_thread_ts = thread_ts or self._thread_ts + if resolved_channel is None: + raise ValueError( + "channel is required: provide it as an argument or ensure channel_id is set in the event context" + ) + if resolved_thread_ts is None: + raise ValueError( + "thread_ts is required: provide it as an argument or ensure thread_ts is set in the event context" + ) + return await self._client.chat_stream( + channel=resolved_channel, + thread_ts=resolved_thread_ts, + recipient_team_id=recipient_team_id or self._team_id, + recipient_user_id=recipient_user_id or self._user_id, + **kwargs, + ) diff --git a/slack_bolt/context/async_context.py b/slack_bolt/context/async_context.py index 47eb4744e..3e373e55f 100644 --- a/slack_bolt/context/async_context.py +++ b/slack_bolt/context/async_context.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import TYPE_CHECKING, Optional from slack_sdk.web.async_client import AsyncWebClient @@ -15,6 +15,9 @@ from slack_bolt.context.set_title.async_set_title import AsyncSetTitle from slack_bolt.util.utils import create_copy +if TYPE_CHECKING: + from slack_bolt.agent.async_agent import AsyncBoltAgent + class AsyncBoltContext(BaseContext): """Context object associated with a request from Slack.""" @@ -187,6 +190,34 @@ async def handle_button_clicks(context): self["fail"] = AsyncFail(client=self.client, function_execution_id=self.function_execution_id) return self["fail"] + @property + def agent(self) -> "AsyncBoltAgent": + """`agent` listener argument for building AI-powered Slack agents. + + Experimental: + This API is experimental and may change in future releases. + + @app.event("app_mention") + async def handle_mention(agent): + stream = await agent.chat_stream() + await stream.append(markdown_text="Hello!") + await stream.stop() + + Returns: + `AsyncBoltAgent` instance + """ + if "agent" not in self: + from slack_bolt.agent.async_agent import AsyncBoltAgent + + self["agent"] = AsyncBoltAgent( + client=self.client, + channel_id=self.channel_id, + thread_ts=self.thread_ts, + team_id=self.team_id, + user_id=self.user_id, + ) + return self["agent"] + @property def set_title(self) -> Optional[AsyncSetTitle]: return self.get("set_title") diff --git a/slack_bolt/context/base_context.py b/slack_bolt/context/base_context.py index 843d5ef60..85105b783 100644 --- a/slack_bolt/context/base_context.py +++ b/slack_bolt/context/base_context.py @@ -38,6 +38,7 @@ class BaseContext(dict): "set_status", "set_title", "set_suggested_prompts", + "agent", ] # Note that these items are not copyable, so when you add new items to this list, # you must modify ThreadListenerRunner/AsyncioListenerRunner's _build_lazy_request method to pass the values. diff --git a/slack_bolt/context/context.py b/slack_bolt/context/context.py index 31edf2891..bbd001482 100644 --- a/slack_bolt/context/context.py +++ b/slack_bolt/context/context.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import TYPE_CHECKING, Optional from slack_sdk import WebClient @@ -15,6 +15,9 @@ from slack_bolt.context.set_title import SetTitle from slack_bolt.util.utils import create_copy +if TYPE_CHECKING: + from slack_bolt.agent.agent import BoltAgent + class BoltContext(BaseContext): """Context object associated with a request from Slack.""" @@ -188,6 +191,34 @@ def handle_button_clicks(context): self["fail"] = Fail(client=self.client, function_execution_id=self.function_execution_id) return self["fail"] + @property + def agent(self) -> "BoltAgent": + """`agent` listener argument for building AI-powered Slack agents. + + Experimental: + This API is experimental and may change in future releases. + + @app.event("app_mention") + def handle_mention(agent): + stream = agent.chat_stream() + stream.append(markdown_text="Hello!") + stream.stop() + + Returns: + `BoltAgent` instance + """ + if "agent" not in self: + from slack_bolt.agent.agent import BoltAgent + + self["agent"] = BoltAgent( + client=self.client, + channel_id=self.channel_id, + thread_ts=self.thread_ts, + team_id=self.team_id, + user_id=self.user_id, + ) + return self["agent"] + @property def set_title(self) -> Optional[SetTitle]: return self.get("set_title") diff --git a/slack_bolt/kwargs_injection/args.py b/slack_bolt/kwargs_injection/args.py index 1a0ec3ca8..113e39c08 100644 --- a/slack_bolt/kwargs_injection/args.py +++ b/slack_bolt/kwargs_injection/args.py @@ -8,6 +8,7 @@ from slack_bolt.context.fail import Fail from slack_bolt.context.get_thread_context.get_thread_context import GetThreadContext from slack_bolt.context.respond import Respond +from slack_bolt.agent.agent import BoltAgent from slack_bolt.context.save_thread_context import SaveThreadContext from slack_bolt.context.say import Say from slack_bolt.context.set_status import SetStatus @@ -102,6 +103,8 @@ def handle_buttons(args): """`get_thread_context()` utility function for AI Agents & Assistants""" save_thread_context: Optional[SaveThreadContext] """`save_thread_context()` utility function for AI Agents & Assistants""" + agent: Optional[BoltAgent] + """`agent` listener argument for AI Agents & Assistants""" # middleware next: Callable[[], None] """`next()` utility function, which tells the middleware chain that it can continue with the next one""" @@ -135,6 +138,7 @@ def __init__( set_suggested_prompts: Optional[SetSuggestedPrompts] = None, get_thread_context: Optional[GetThreadContext] = None, save_thread_context: Optional[SaveThreadContext] = None, + agent: Optional[BoltAgent] = None, # As this method is not supposed to be invoked by bolt-python users, # the naming conflict with the built-in one affects # only the internals of this method @@ -168,6 +172,7 @@ def __init__( self.set_suggested_prompts = set_suggested_prompts self.get_thread_context = get_thread_context self.save_thread_context = save_thread_context + self.agent = agent self.next: Callable[[], None] = next self.next_: Callable[[], None] = next diff --git a/slack_bolt/kwargs_injection/async_args.py b/slack_bolt/kwargs_injection/async_args.py index 4953f2167..1f1dde024 100644 --- a/slack_bolt/kwargs_injection/async_args.py +++ b/slack_bolt/kwargs_injection/async_args.py @@ -1,6 +1,7 @@ from logging import Logger from typing import Callable, Awaitable, Dict, Any, Optional +from slack_bolt.agent.async_agent import AsyncBoltAgent from slack_bolt.context.ack.async_ack import AsyncAck from slack_bolt.context.async_context import AsyncBoltContext from slack_bolt.context.complete.async_complete import AsyncComplete @@ -101,6 +102,8 @@ async def handle_buttons(args): """`get_thread_context()` utility function for AI Agents & Assistants""" save_thread_context: Optional[AsyncSaveThreadContext] """`save_thread_context()` utility function for AI Agents & Assistants""" + agent: Optional[AsyncBoltAgent] + """`agent` listener argument for AI Agents & Assistants""" # middleware next: Callable[[], Awaitable[None]] """`next()` utility function, which tells the middleware chain that it can continue with the next one""" @@ -134,6 +137,7 @@ def __init__( set_suggested_prompts: Optional[AsyncSetSuggestedPrompts] = None, get_thread_context: Optional[AsyncGetThreadContext] = None, save_thread_context: Optional[AsyncSaveThreadContext] = None, + agent: Optional[AsyncBoltAgent] = None, next: Callable[[], Awaitable[None]], **kwargs, # noqa ): @@ -164,6 +168,7 @@ def __init__( self.set_suggested_prompts = set_suggested_prompts self.get_thread_context = get_thread_context self.save_thread_context = save_thread_context + self.agent = agent self.next: Callable[[], Awaitable[None]] = next self.next_: Callable[[], Awaitable[None]] = next diff --git a/slack_bolt/kwargs_injection/async_utils.py b/slack_bolt/kwargs_injection/async_utils.py index c8870c3cc..35ffacf45 100644 --- a/slack_bolt/kwargs_injection/async_utils.py +++ b/slack_bolt/kwargs_injection/async_utils.py @@ -1,9 +1,11 @@ import inspect import logging +import warnings from typing import Callable, Dict, MutableSequence, Optional, Any from slack_bolt.request.async_request import AsyncBoltRequest from slack_bolt.response import BoltResponse +from slack_bolt.warning import ExperimentalWarning from .async_args import AsyncArgs from slack_bolt.request.payload_utils import ( to_options, @@ -29,7 +31,7 @@ def build_async_required_kwargs( error: Optional[Exception] = None, # for error handlers next_keys_required: bool = True, # False for listeners / middleware / error handlers ) -> Dict[str, Any]: - all_available_args = { + all_available_args: Dict[str, Any] = { "logger": logger, "client": request.context.client, "req": request, @@ -83,6 +85,16 @@ def build_async_required_kwargs( if k not in all_available_args: all_available_args[k] = v + # Defer agent creation to avoid constructing AsyncBoltAgent on every request + if "agent" in required_arg_names or "args" in required_arg_names: + all_available_args["agent"] = request.context.agent + if "agent" in required_arg_names: + warnings.warn( + "The agent listener argument is experimental and may change in future versions.", + category=ExperimentalWarning, + stacklevel=2, # Point to the caller, not this internal helper + ) + if len(required_arg_names) > 0: # To support instance/class methods in a class for listeners/middleware, # check if the first argument is either self or cls @@ -102,7 +114,7 @@ def build_async_required_kwargs( for name in required_arg_names: if name == "args": if isinstance(request, AsyncBoltRequest): - kwargs[name] = AsyncArgs(**all_available_args) # type: ignore[arg-type] + kwargs[name] = AsyncArgs(**all_available_args) else: logger.warning(f"Unknown Request object type detected ({type(request)})") diff --git a/slack_bolt/kwargs_injection/utils.py b/slack_bolt/kwargs_injection/utils.py index c1909c67a..8f9fc9886 100644 --- a/slack_bolt/kwargs_injection/utils.py +++ b/slack_bolt/kwargs_injection/utils.py @@ -1,9 +1,11 @@ import inspect import logging +import warnings from typing import Callable, Dict, MutableSequence, Optional, Any from slack_bolt.request import BoltRequest from slack_bolt.response import BoltResponse +from slack_bolt.warning import ExperimentalWarning from .args import Args from slack_bolt.request.payload_utils import ( to_options, @@ -29,7 +31,7 @@ def build_required_kwargs( error: Optional[Exception] = None, # for error handlers next_keys_required: bool = True, # False for listeners / middleware / error handlers ) -> Dict[str, Any]: - all_available_args = { + all_available_args: Dict[str, Any] = { "logger": logger, "client": request.context.client, "req": request, @@ -82,6 +84,16 @@ def build_required_kwargs( if k not in all_available_args: all_available_args[k] = v + # Defer agent creation to avoid constructing BoltAgent on every request + if "agent" in required_arg_names or "args" in required_arg_names: + all_available_args["agent"] = request.context.agent + if "agent" in required_arg_names: + warnings.warn( + "The agent listener argument is experimental and may change in future versions.", + category=ExperimentalWarning, + stacklevel=2, # Point to the caller, not this internal helper + ) + if len(required_arg_names) > 0: # To support instance/class methods in a class for listeners/middleware, # check if the first argument is either self or cls @@ -101,7 +113,7 @@ def build_required_kwargs( for name in required_arg_names: if name == "args": if isinstance(request, BoltRequest): - kwargs[name] = Args(**all_available_args) # type: ignore[arg-type] + kwargs[name] = Args(**all_available_args) else: logger.warning(f"Unknown Request object type detected ({type(request)})") diff --git a/slack_bolt/warning/__init__.py b/slack_bolt/warning/__init__.py new file mode 100644 index 000000000..4991f4cd9 --- /dev/null +++ b/slack_bolt/warning/__init__.py @@ -0,0 +1,7 @@ +"""Bolt specific warning types.""" + + +class ExperimentalWarning(FutureWarning): + """Warning for features that are still in experimental phase.""" + + pass diff --git a/tests/scenario_tests/test_events_agent.py b/tests/scenario_tests/test_events_agent.py new file mode 100644 index 000000000..636ade669 --- /dev/null +++ b/tests/scenario_tests/test_events_agent.py @@ -0,0 +1,189 @@ +import json +from time import sleep + +import pytest +from slack_sdk.web import WebClient + +from slack_bolt import App, BoltRequest, BoltContext, BoltAgent +from slack_bolt.agent.agent import BoltAgent as BoltAgentDirect +from slack_bolt.warning import ExperimentalWarning +from tests.mock_web_api_server import ( + setup_mock_web_api_server, + cleanup_mock_web_api_server, +) +from tests.utils import remove_os_env_temporarily, restore_os_env + + +class TestEventsAgent: + valid_token = "xoxb-valid" + mock_api_server_base_url = "http://localhost:8888" + web_client = WebClient( + token=valid_token, + base_url=mock_api_server_base_url, + ) + + def setup_method(self): + self.old_os_env = remove_os_env_temporarily() + setup_mock_web_api_server(self) + + def teardown_method(self): + cleanup_mock_web_api_server(self) + restore_os_env(self.old_os_env) + + def test_agent_injected_for_app_mention(self): + app = App(client=self.web_client) + + state = {"called": False} + + def assert_target_called(): + count = 0 + while state["called"] is False and count < 20: + sleep(0.1) + count += 1 + assert state["called"] is True + state["called"] = False + + @app.event("app_mention") + def handle_mention(agent: BoltAgent, context: BoltContext): + assert agent is not None + assert isinstance(agent, BoltAgentDirect) + assert context.channel_id == "C111" + state["called"] = True + + request = BoltRequest(body=app_mention_event_body, mode="socket_mode") + response = app.dispatch(request) + assert response.status == 200 + assert_target_called() + + def test_agent_available_in_action_listener(self): + app = App(client=self.web_client) + + state = {"called": False} + + def assert_target_called(): + count = 0 + while state["called"] is False and count < 20: + sleep(0.1) + count += 1 + assert state["called"] is True + state["called"] = False + + @app.action("test_action") + def handle_action(ack, agent: BoltAgent): + ack() + assert agent is not None + assert isinstance(agent, BoltAgentDirect) + state["called"] = True + + request = BoltRequest(body=json.dumps(action_event_body), mode="socket_mode") + response = app.dispatch(request) + assert response.status == 200 + assert_target_called() + + def test_agent_accessible_via_context(self): + app = App(client=self.web_client) + + state = {"called": False} + + def assert_target_called(): + count = 0 + while state["called"] is False and count < 20: + sleep(0.1) + count += 1 + assert state["called"] is True + state["called"] = False + + @app.event("app_mention") + def handle_mention(context: BoltContext): + agent = context.agent + assert agent is not None + assert isinstance(agent, BoltAgentDirect) + # Verify the same instance is returned on subsequent access + assert context.agent is agent + state["called"] = True + + request = BoltRequest(body=app_mention_event_body, mode="socket_mode") + response = app.dispatch(request) + assert response.status == 200 + assert_target_called() + + def test_agent_kwarg_emits_experimental_warning(self): + app = App(client=self.web_client) + + state = {"called": False} + + def assert_target_called(): + count = 0 + while state["called"] is False and count < 20: + sleep(0.1) + count += 1 + assert state["called"] is True + state["called"] = False + + @app.event("app_mention") + def handle_mention(agent: BoltAgent): + state["called"] = True + + request = BoltRequest(body=app_mention_event_body, mode="socket_mode") + with pytest.warns(ExperimentalWarning, match="agent listener argument is experimental"): + response = app.dispatch(request) + assert response.status == 200 + assert_target_called() + + +# ---- Test event bodies ---- + + +def build_payload(event: dict) -> dict: + return { + "token": "verification_token", + "team_id": "T111", + "enterprise_id": "E111", + "api_app_id": "A111", + "event": event, + "type": "event_callback", + "event_id": "Ev111", + "event_time": 1599616881, + "authorizations": [ + { + "enterprise_id": "E111", + "team_id": "T111", + "user_id": "W111", + "is_bot": True, + "is_enterprise_install": False, + } + ], + } + + +app_mention_event_body = build_payload( + { + "type": "app_mention", + "user": "W222", + "text": "<@W111> hello", + "ts": "1234567890.123456", + "channel": "C111", + "event_ts": "1234567890.123456", + } +) + +action_event_body = { + "type": "block_actions", + "user": {"id": "W222", "username": "test_user", "name": "test_user", "team_id": "T111"}, + "api_app_id": "A111", + "token": "verification_token", + "container": {"type": "message", "message_ts": "1234567890.123456", "channel_id": "C111", "is_ephemeral": False}, + "channel": {"id": "C111", "name": "test-channel"}, + "team": {"id": "T111", "domain": "test"}, + "enterprise": {"id": "E111", "name": "test"}, + "trigger_id": "111.222.xxx", + "actions": [ + { + "type": "button", + "block_id": "b", + "action_id": "test_action", + "text": {"type": "plain_text", "text": "Button"}, + "action_ts": "1234567890.123456", + } + ], +} diff --git a/tests/scenario_tests_async/test_events_agent.py b/tests/scenario_tests_async/test_events_agent.py new file mode 100644 index 000000000..a665d786b --- /dev/null +++ b/tests/scenario_tests_async/test_events_agent.py @@ -0,0 +1,197 @@ +import asyncio +import json + +import pytest +from slack_sdk.web.async_client import AsyncWebClient + +from slack_bolt.agent.async_agent import AsyncBoltAgent +from slack_bolt.app.async_app import AsyncApp +from slack_bolt.context.async_context import AsyncBoltContext +from slack_bolt.request.async_request import AsyncBoltRequest +from slack_bolt.warning import ExperimentalWarning +from tests.mock_web_api_server import ( + cleanup_mock_web_api_server_async, + setup_mock_web_api_server_async, +) +from tests.utils import remove_os_env_temporarily, restore_os_env + + +class TestAsyncEventsAgent: + valid_token = "xoxb-valid" + mock_api_server_base_url = "http://localhost:8888" + web_client = AsyncWebClient( + token=valid_token, + base_url=mock_api_server_base_url, + ) + + @pytest.fixture(scope="function", autouse=True) + def setup_teardown(self): + old_os_env = remove_os_env_temporarily() + setup_mock_web_api_server_async(self) + try: + yield + finally: + cleanup_mock_web_api_server_async(self) + restore_os_env(old_os_env) + + @pytest.mark.asyncio + async def test_agent_injected_for_app_mention(self): + app = AsyncApp(client=self.web_client) + + state = {"called": False} + + async def assert_target_called(): + count = 0 + while state["called"] is False and count < 20: + await asyncio.sleep(0.1) + count += 1 + assert state["called"] is True + state["called"] = False + + @app.event("app_mention") + async def handle_mention(agent: AsyncBoltAgent, context: AsyncBoltContext): + assert agent is not None + assert isinstance(agent, AsyncBoltAgent) + assert context.channel_id == "C111" + state["called"] = True + + request = AsyncBoltRequest(body=app_mention_event_body, mode="socket_mode") + response = await app.async_dispatch(request) + assert response.status == 200 + await assert_target_called() + + @pytest.mark.asyncio + async def test_agent_available_in_action_listener(self): + app = AsyncApp(client=self.web_client) + + state = {"called": False} + + async def assert_target_called(): + count = 0 + while state["called"] is False and count < 20: + await asyncio.sleep(0.1) + count += 1 + assert state["called"] is True + state["called"] = False + + @app.action("test_action") + async def handle_action(ack, agent: AsyncBoltAgent): + await ack() + assert agent is not None + assert isinstance(agent, AsyncBoltAgent) + state["called"] = True + + request = AsyncBoltRequest(body=json.dumps(action_event_body), mode="socket_mode") + response = await app.async_dispatch(request) + assert response.status == 200 + await assert_target_called() + + @pytest.mark.asyncio + async def test_agent_accessible_via_context(self): + app = AsyncApp(client=self.web_client) + + state = {"called": False} + + async def assert_target_called(): + count = 0 + while state["called"] is False and count < 20: + await asyncio.sleep(0.1) + count += 1 + assert state["called"] is True + state["called"] = False + + @app.event("app_mention") + async def handle_mention(context: AsyncBoltContext): + agent = context.agent + assert agent is not None + assert isinstance(agent, AsyncBoltAgent) + # Verify the same instance is returned on subsequent access + assert context.agent is agent + state["called"] = True + + request = AsyncBoltRequest(body=app_mention_event_body, mode="socket_mode") + response = await app.async_dispatch(request) + assert response.status == 200 + await assert_target_called() + + @pytest.mark.asyncio + async def test_agent_kwarg_emits_experimental_warning(self): + app = AsyncApp(client=self.web_client) + + state = {"called": False} + + async def assert_target_called(): + count = 0 + while state["called"] is False and count < 20: + await asyncio.sleep(0.1) + count += 1 + assert state["called"] is True + state["called"] = False + + @app.event("app_mention") + async def handle_mention(agent: AsyncBoltAgent): + state["called"] = True + + request = AsyncBoltRequest(body=app_mention_event_body, mode="socket_mode") + with pytest.warns(ExperimentalWarning, match="agent listener argument is experimental"): + response = await app.async_dispatch(request) + assert response.status == 200 + await assert_target_called() + + +# ---- Test event bodies ---- + + +def build_payload(event: dict) -> dict: + return { + "token": "verification_token", + "team_id": "T111", + "enterprise_id": "E111", + "api_app_id": "A111", + "event": event, + "type": "event_callback", + "event_id": "Ev111", + "event_time": 1599616881, + "authorizations": [ + { + "enterprise_id": "E111", + "team_id": "T111", + "user_id": "W111", + "is_bot": True, + "is_enterprise_install": False, + } + ], + } + + +app_mention_event_body = build_payload( + { + "type": "app_mention", + "user": "W222", + "text": "<@W111> hello", + "ts": "1234567890.123456", + "channel": "C111", + "event_ts": "1234567890.123456", + } +) + +action_event_body = { + "type": "block_actions", + "user": {"id": "W222", "username": "test_user", "name": "test_user", "team_id": "T111"}, + "api_app_id": "A111", + "token": "verification_token", + "container": {"type": "message", "message_ts": "1234567890.123456", "channel_id": "C111", "is_ephemeral": False}, + "channel": {"id": "C111", "name": "test-channel"}, + "team": {"id": "T111", "domain": "test"}, + "enterprise": {"id": "E111", "name": "test"}, + "trigger_id": "111.222.xxx", + "actions": [ + { + "type": "button", + "block_id": "b", + "action_id": "test_action", + "text": {"type": "plain_text", "text": "Button"}, + "action_ts": "1234567890.123456", + } + ], +} diff --git a/tests/slack_bolt/agent/__init__.py b/tests/slack_bolt/agent/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/slack_bolt/agent/test_agent.py b/tests/slack_bolt/agent/test_agent.py new file mode 100644 index 000000000..00e998379 --- /dev/null +++ b/tests/slack_bolt/agent/test_agent.py @@ -0,0 +1,103 @@ +from unittest.mock import MagicMock + +import pytest +from slack_sdk.web import WebClient +from slack_sdk.web.chat_stream import ChatStream + +from slack_bolt.agent.agent import BoltAgent + + +class TestBoltAgent: + def test_chat_stream_uses_context_defaults(self): + """BoltAgent.chat_stream() passes context defaults to WebClient.chat_stream().""" + client = MagicMock(spec=WebClient) + client.chat_stream.return_value = MagicMock(spec=ChatStream) + + agent = BoltAgent( + client=client, + channel_id="C111", + thread_ts="1234567890.123456", + team_id="T111", + user_id="W222", + ) + stream = agent.chat_stream() + + client.chat_stream.assert_called_once_with( + channel="C111", + thread_ts="1234567890.123456", + recipient_team_id="T111", + recipient_user_id="W222", + ) + assert stream is not None + + def test_chat_stream_overrides_context_defaults(self): + """Explicit kwargs to chat_stream() override context defaults.""" + client = MagicMock(spec=WebClient) + client.chat_stream.return_value = MagicMock(spec=ChatStream) + + agent = BoltAgent( + client=client, + channel_id="C111", + thread_ts="1234567890.123456", + team_id="T111", + user_id="W222", + ) + stream = agent.chat_stream( + channel="C999", + thread_ts="9999999999.999999", + recipient_team_id="T999", + recipient_user_id="U999", + ) + + client.chat_stream.assert_called_once_with( + channel="C999", + thread_ts="9999999999.999999", + recipient_team_id="T999", + recipient_user_id="U999", + ) + assert stream is not None + + def test_chat_stream_rejects_partial_overrides(self): + """Passing only some of the four context args raises ValueError.""" + client = MagicMock(spec=WebClient) + agent = BoltAgent( + client=client, + channel_id="C111", + thread_ts="1234567890.123456", + team_id="T111", + user_id="W222", + ) + with pytest.raises(ValueError, match="Either provide all of"): + agent.chat_stream(channel="C999") + + def test_chat_stream_passes_extra_kwargs(self): + """Extra kwargs are forwarded to WebClient.chat_stream().""" + client = MagicMock(spec=WebClient) + client.chat_stream.return_value = MagicMock(spec=ChatStream) + + agent = BoltAgent( + client=client, + channel_id="C111", + thread_ts="1234567890.123456", + team_id="T111", + user_id="W222", + ) + agent.chat_stream(buffer_size=512) + + client.chat_stream.assert_called_once_with( + channel="C111", + thread_ts="1234567890.123456", + recipient_team_id="T111", + recipient_user_id="W222", + buffer_size=512, + ) + + def test_import_from_slack_bolt(self): + from slack_bolt import BoltAgent as ImportedBoltAgent + + assert ImportedBoltAgent is BoltAgent + + def test_import_from_agent_module(self): + from slack_bolt.agent import BoltAgent as ImportedBoltAgent + + assert ImportedBoltAgent is BoltAgent diff --git a/tests/slack_bolt_async/agent/__init__.py b/tests/slack_bolt_async/agent/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/slack_bolt_async/agent/test_async_agent.py b/tests/slack_bolt_async/agent/test_async_agent.py new file mode 100644 index 000000000..02251fa4b --- /dev/null +++ b/tests/slack_bolt_async/agent/test_async_agent.py @@ -0,0 +1,114 @@ +from unittest.mock import MagicMock + +import pytest +from slack_sdk.web.async_client import AsyncWebClient +from slack_sdk.web.async_chat_stream import AsyncChatStream + +from slack_bolt.agent.async_agent import AsyncBoltAgent + + +def _make_async_chat_stream_mock(): + mock_stream = MagicMock(spec=AsyncChatStream) + call_tracker = MagicMock() + + async def fake_chat_stream(**kwargs): + call_tracker(**kwargs) + return mock_stream + + return fake_chat_stream, call_tracker, mock_stream + + +class TestAsyncBoltAgent: + @pytest.mark.asyncio + async def test_chat_stream_uses_context_defaults(self): + """AsyncBoltAgent.chat_stream() passes context defaults to AsyncWebClient.chat_stream().""" + client = MagicMock(spec=AsyncWebClient) + client.chat_stream, call_tracker, _ = _make_async_chat_stream_mock() + + agent = AsyncBoltAgent( + client=client, + channel_id="C111", + thread_ts="1234567890.123456", + team_id="T111", + user_id="W222", + ) + stream = await agent.chat_stream() + + call_tracker.assert_called_once_with( + channel="C111", + thread_ts="1234567890.123456", + recipient_team_id="T111", + recipient_user_id="W222", + ) + assert stream is not None + + @pytest.mark.asyncio + async def test_chat_stream_overrides_context_defaults(self): + """Explicit kwargs to chat_stream() override context defaults.""" + client = MagicMock(spec=AsyncWebClient) + client.chat_stream, call_tracker, _ = _make_async_chat_stream_mock() + + agent = AsyncBoltAgent( + client=client, + channel_id="C111", + thread_ts="1234567890.123456", + team_id="T111", + user_id="W222", + ) + stream = await agent.chat_stream( + channel="C999", + thread_ts="9999999999.999999", + recipient_team_id="T999", + recipient_user_id="U999", + ) + + call_tracker.assert_called_once_with( + channel="C999", + thread_ts="9999999999.999999", + recipient_team_id="T999", + recipient_user_id="U999", + ) + assert stream is not None + + @pytest.mark.asyncio + async def test_chat_stream_rejects_partial_overrides(self): + """Passing only some of the four context args raises ValueError.""" + client = MagicMock(spec=AsyncWebClient) + agent = AsyncBoltAgent( + client=client, + channel_id="C111", + thread_ts="1234567890.123456", + team_id="T111", + user_id="W222", + ) + with pytest.raises(ValueError, match="Either provide all of"): + await agent.chat_stream(channel="C999") + + @pytest.mark.asyncio + async def test_chat_stream_passes_extra_kwargs(self): + """Extra kwargs are forwarded to AsyncWebClient.chat_stream().""" + client = MagicMock(spec=AsyncWebClient) + client.chat_stream, call_tracker, _ = _make_async_chat_stream_mock() + + agent = AsyncBoltAgent( + client=client, + channel_id="C111", + thread_ts="1234567890.123456", + team_id="T111", + user_id="W222", + ) + await agent.chat_stream(buffer_size=512) + + call_tracker.assert_called_once_with( + channel="C111", + thread_ts="1234567890.123456", + recipient_team_id="T111", + recipient_user_id="W222", + buffer_size=512, + ) + + @pytest.mark.asyncio + async def test_import_from_agent_module(self): + from slack_bolt.agent.async_agent import AsyncBoltAgent as ImportedAsyncBoltAgent + + assert ImportedAsyncBoltAgent is AsyncBoltAgent