diff --git a/slack_bolt/agent/agent.py b/slack_bolt/agent/agent.py index 3663b245b..056dba986 100644 --- a/slack_bolt/agent/agent.py +++ b/slack_bolt/agent/agent.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Dict, List, Optional, Sequence, Union from slack_sdk import WebClient from slack_sdk.web import SlackResponse @@ -101,3 +101,40 @@ def set_status( loading_messages=loading_messages, **kwargs, ) + + def set_suggested_prompts( + self, + *, + prompts: Sequence[Union[str, Dict[str, str]]], + title: Optional[str] = None, + channel: Optional[str] = None, + thread_ts: Optional[str] = None, + **kwargs, + ) -> SlackResponse: + """Sets suggested prompts for an assistant thread. + + Args: + prompts: A sequence of prompts. Each prompt can be either a string + (used as both title and message) or a dict with 'title' and 'message' keys. + title: Optional title for the suggested prompts section. + channel: Channel ID. Defaults to the channel from the event context. + thread_ts: Thread timestamp. Defaults to the thread_ts from the event context. + **kwargs: Additional arguments passed to ``WebClient.assistant_threads_setSuggestedPrompts()``. + + Returns: + ``SlackResponse`` from the API call. + """ + prompts_arg: List[Dict[str, str]] = [] + for prompt in prompts: + if isinstance(prompt, str): + prompts_arg.append({"title": prompt, "message": prompt}) + else: + prompts_arg.append(prompt) + + return self._client.assistant_threads_setSuggestedPrompts( + channel_id=channel or self._channel_id, # type: ignore[arg-type] + thread_ts=thread_ts or self._thread_ts, # type: ignore[arg-type] + prompts=prompts_arg, + title=title, + **kwargs, + ) diff --git a/slack_bolt/agent/async_agent.py b/slack_bolt/agent/async_agent.py index 5b86533e6..5630e1b81 100644 --- a/slack_bolt/agent/async_agent.py +++ b/slack_bolt/agent/async_agent.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Dict, List, Optional, Sequence, Union from slack_sdk.web.async_client import AsyncSlackResponse, AsyncWebClient from slack_sdk.web.async_chat_stream import AsyncChatStream @@ -97,3 +97,40 @@ async def set_status( loading_messages=loading_messages, **kwargs, ) + + async def set_suggested_prompts( + self, + *, + prompts: Sequence[Union[str, Dict[str, str]]], + title: Optional[str] = None, + channel: Optional[str] = None, + thread_ts: Optional[str] = None, + **kwargs, + ) -> AsyncSlackResponse: + """Sets suggested prompts for an assistant thread. + + Args: + prompts: A sequence of prompts. Each prompt can be either a string + (used as both title and message) or a dict with 'title' and 'message' keys. + title: Optional title for the suggested prompts section. + channel: Channel ID. Defaults to the channel from the event context. + thread_ts: Thread timestamp. Defaults to the thread_ts from the event context. + **kwargs: Additional arguments passed to ``AsyncWebClient.assistant_threads_setSuggestedPrompts()``. + + Returns: + ``AsyncSlackResponse`` from the API call. + """ + prompts_arg: List[Dict[str, str]] = [] + for prompt in prompts: + if isinstance(prompt, str): + prompts_arg.append({"title": prompt, "message": prompt}) + else: + prompts_arg.append(prompt) + + return await self._client.assistant_threads_setSuggestedPrompts( + channel_id=channel or self._channel_id, # type: ignore[arg-type] + thread_ts=thread_ts or self._thread_ts, # type: ignore[arg-type] + prompts=prompts_arg, + title=title, + **kwargs, + ) diff --git a/tests/slack_bolt/agent/test_agent.py b/tests/slack_bolt/agent/test_agent.py index 7dad481b0..1d14eda06 100644 --- a/tests/slack_bolt/agent/test_agent.py +++ b/tests/slack_bolt/agent/test_agent.py @@ -197,6 +197,118 @@ def test_set_status_requires_status(self): with pytest.raises(TypeError): agent.set_status() + def test_set_suggested_prompts_uses_context_defaults(self): + """BoltAgent.set_suggested_prompts() passes context defaults to WebClient.assistant_threads_setSuggestedPrompts().""" + client = MagicMock(spec=WebClient) + client.assistant_threads_setSuggestedPrompts.return_value = MagicMock() + + agent = BoltAgent( + client=client, + channel_id="C111", + thread_ts="1234567890.123456", + team_id="T111", + user_id="W222", + ) + agent.set_suggested_prompts(prompts=["What can you do?", "Help me write code"]) + + client.assistant_threads_setSuggestedPrompts.assert_called_once_with( + channel_id="C111", + thread_ts="1234567890.123456", + prompts=[ + {"title": "What can you do?", "message": "What can you do?"}, + {"title": "Help me write code", "message": "Help me write code"}, + ], + title=None, + ) + + def test_set_suggested_prompts_with_dict_prompts(self): + """BoltAgent.set_suggested_prompts() accepts dict prompts with title and message.""" + client = MagicMock(spec=WebClient) + client.assistant_threads_setSuggestedPrompts.return_value = MagicMock() + + agent = BoltAgent( + client=client, + channel_id="C111", + thread_ts="1234567890.123456", + team_id="T111", + user_id="W222", + ) + agent.set_suggested_prompts( + prompts=[ + {"title": "Short title", "message": "A much longer message for this prompt"}, + ], + title="Suggestions", + ) + + client.assistant_threads_setSuggestedPrompts.assert_called_once_with( + channel_id="C111", + thread_ts="1234567890.123456", + prompts=[ + {"title": "Short title", "message": "A much longer message for this prompt"}, + ], + title="Suggestions", + ) + + def test_set_suggested_prompts_overrides_context_defaults(self): + """Explicit channel/thread_ts override context defaults.""" + client = MagicMock(spec=WebClient) + client.assistant_threads_setSuggestedPrompts.return_value = MagicMock() + + agent = BoltAgent( + client=client, + channel_id="C111", + thread_ts="1234567890.123456", + team_id="T111", + user_id="W222", + ) + agent.set_suggested_prompts( + prompts=["Hello"], + channel="C999", + thread_ts="9999999999.999999", + ) + + client.assistant_threads_setSuggestedPrompts.assert_called_once_with( + channel_id="C999", + thread_ts="9999999999.999999", + prompts=[{"title": "Hello", "message": "Hello"}], + title=None, + ) + + def test_set_suggested_prompts_passes_extra_kwargs(self): + """Extra kwargs are forwarded to WebClient.assistant_threads_setSuggestedPrompts().""" + client = MagicMock(spec=WebClient) + client.assistant_threads_setSuggestedPrompts.return_value = MagicMock() + + agent = BoltAgent( + client=client, + channel_id="C111", + thread_ts="1234567890.123456", + team_id="T111", + user_id="W222", + ) + agent.set_suggested_prompts(prompts=["Hello"], token="xoxb-override") + + client.assistant_threads_setSuggestedPrompts.assert_called_once_with( + channel_id="C111", + thread_ts="1234567890.123456", + prompts=[{"title": "Hello", "message": "Hello"}], + title=None, + token="xoxb-override", + ) + + def test_set_suggested_prompts_requires_prompts(self): + """set_suggested_prompts() raises TypeError when prompts is not provided.""" + client = MagicMock(spec=WebClient) + agent = BoltAgent( + client=client, + channel_id="C111", + thread_ts="1234567890.123456", + team_id="T111", + user_id="W222", + ) + with pytest.raises(TypeError): + agent.set_suggested_prompts() + def test_import_from_slack_bolt(self): from slack_bolt import BoltAgent as ImportedBoltAgent diff --git a/tests/slack_bolt_async/agent/test_async_agent.py b/tests/slack_bolt_async/agent/test_async_agent.py index 8e4c4d5c8..b934bbaeb 100644 --- a/tests/slack_bolt_async/agent/test_async_agent.py +++ b/tests/slack_bolt_async/agent/test_async_agent.py @@ -228,6 +228,123 @@ async def test_set_status_requires_status(self): with pytest.raises(TypeError): await agent.set_status() + @pytest.mark.asyncio + async def test_set_suggested_prompts_uses_context_defaults(self): + """AsyncBoltAgent.set_suggested_prompts() passes context defaults to AsyncWebClient.assistant_threads_setSuggestedPrompts().""" + client = MagicMock(spec=AsyncWebClient) + client.assistant_threads_setSuggestedPrompts, call_tracker, _ = _make_async_api_mock() + + agent = AsyncBoltAgent( + client=client, + channel_id="C111", + thread_ts="1234567890.123456", + team_id="T111", + user_id="W222", + ) + await agent.set_suggested_prompts(prompts=["What can you do?", "Help me write code"]) + + call_tracker.assert_called_once_with( + channel_id="C111", + thread_ts="1234567890.123456", + prompts=[ + {"title": "What can you do?", "message": "What can you do?"}, + {"title": "Help me write code", "message": "Help me write code"}, + ], + title=None, + ) + + @pytest.mark.asyncio + async def test_set_suggested_prompts_with_dict_prompts(self): + """AsyncBoltAgent.set_suggested_prompts() accepts dict prompts with title and message.""" + client = MagicMock(spec=AsyncWebClient) + client.assistant_threads_setSuggestedPrompts, call_tracker, _ = _make_async_api_mock() + + agent = AsyncBoltAgent( + client=client, + channel_id="C111", + thread_ts="1234567890.123456", + team_id="T111", + user_id="W222", + ) + await agent.set_suggested_prompts( + prompts=[ + {"title": "Short title", "message": "A much longer message for this prompt"}, + ], + title="Suggestions", + ) + + call_tracker.assert_called_once_with( + channel_id="C111", + thread_ts="1234567890.123456", + prompts=[ + {"title": "Short title", "message": "A much longer message for this prompt"}, + ], + title="Suggestions", + ) + + @pytest.mark.asyncio + async def test_set_suggested_prompts_overrides_context_defaults(self): + """Explicit channel/thread_ts override context defaults.""" + client = MagicMock(spec=AsyncWebClient) + client.assistant_threads_setSuggestedPrompts, call_tracker, _ = _make_async_api_mock() + + agent = AsyncBoltAgent( + client=client, + channel_id="C111", + thread_ts="1234567890.123456", + team_id="T111", + user_id="W222", + ) + await agent.set_suggested_prompts( + prompts=["Hello"], + channel="C999", + thread_ts="9999999999.999999", + ) + + call_tracker.assert_called_once_with( + channel_id="C999", + thread_ts="9999999999.999999", + prompts=[{"title": "Hello", "message": "Hello"}], + title=None, + ) + + @pytest.mark.asyncio + async def test_set_suggested_prompts_passes_extra_kwargs(self): + """Extra kwargs are forwarded to AsyncWebClient.assistant_threads_setSuggestedPrompts().""" + client = MagicMock(spec=AsyncWebClient) + client.assistant_threads_setSuggestedPrompts, call_tracker, _ = _make_async_api_mock() + + agent = AsyncBoltAgent( + client=client, + channel_id="C111", + thread_ts="1234567890.123456", + team_id="T111", + user_id="W222", + ) + await agent.set_suggested_prompts(prompts=["Hello"], token="xoxb-override") + + call_tracker.assert_called_once_with( + channel_id="C111", + thread_ts="1234567890.123456", + prompts=[{"title": "Hello", "message": "Hello"}], + title=None, + token="xoxb-override", + ) + + @pytest.mark.asyncio + async def test_set_suggested_prompts_requires_prompts(self): + """set_suggested_prompts() raises TypeError when prompts is not provided.""" + client = MagicMock(spec=AsyncWebClient) + agent = AsyncBoltAgent( + client=client, + channel_id="C111", + thread_ts="1234567890.123456", + team_id="T111", + user_id="W222", + ) + with pytest.raises(TypeError): + await agent.set_suggested_prompts() + @pytest.mark.asyncio async def test_import_from_agent_module(self): from slack_bolt.agent.async_agent import AsyncBoltAgent as ImportedAsyncBoltAgent