Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion slack_bolt/agent/agent.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Dict, List, Optional, Sequence, Union

from slack_sdk import WebClient
from slack_sdk.web import SlackResponse
Expand Down Expand Up @@ -101,3 +101,40 @@ def set_status(
loading_messages=loading_messages,
**kwargs,
)

def set_suggested_prompts(
self,
*,
prompts: Sequence[Union[str, Dict[str, str]]],
title: Optional[str] = None,
channel: Optional[str] = None,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

📝 note: I'm noticing this is different than the expected API value!

🔗 https://docs.slack.dev/reference/methods/assistant.threads.setSuggestedPrompts/

thread_ts: Optional[str] = None,
**kwargs,
) -> SlackResponse:
"""Sets suggested prompts for an assistant thread.

Args:
prompts: A sequence of prompts. Each prompt can be either a string
(used as both title and message) or a dict with 'title' and 'message' keys.
title: Optional title for the suggested prompts section.
channel: Channel ID. Defaults to the channel from the event context.
thread_ts: Thread timestamp. Defaults to the thread_ts from the event context.
**kwargs: Additional arguments passed to ``WebClient.assistant_threads_setSuggestedPrompts()``.

Returns:
``SlackResponse`` from the API call.
"""
prompts_arg: List[Dict[str, str]] = []
for prompt in prompts:
if isinstance(prompt, str):
prompts_arg.append({"title": prompt, "message": prompt})
else:
prompts_arg.append(prompt)

return self._client.assistant_threads_setSuggestedPrompts(
channel_id=channel or self._channel_id, # type: ignore[arg-type]
thread_ts=thread_ts or self._thread_ts, # type: ignore[arg-type]
prompts=prompts_arg,
title=title,
**kwargs,
)
39 changes: 38 additions & 1 deletion slack_bolt/agent/async_agent.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Dict, List, Optional, Sequence, Union

from slack_sdk.web.async_client import AsyncSlackResponse, AsyncWebClient
from slack_sdk.web.async_chat_stream import AsyncChatStream
Expand Down Expand Up @@ -97,3 +97,40 @@ async def set_status(
loading_messages=loading_messages,
**kwargs,
)

async def set_suggested_prompts(
self,
*,
prompts: Sequence[Union[str, Dict[str, str]]],
title: Optional[str] = None,
channel: Optional[str] = None,
thread_ts: Optional[str] = None,
**kwargs,
) -> AsyncSlackResponse:
"""Sets suggested prompts for an assistant thread.

Args:
prompts: A sequence of prompts. Each prompt can be either a string
(used as both title and message) or a dict with 'title' and 'message' keys.
title: Optional title for the suggested prompts section.
channel: Channel ID. Defaults to the channel from the event context.
thread_ts: Thread timestamp. Defaults to the thread_ts from the event context.
**kwargs: Additional arguments passed to ``AsyncWebClient.assistant_threads_setSuggestedPrompts()``.

Returns:
``AsyncSlackResponse`` from the API call.
"""
prompts_arg: List[Dict[str, str]] = []
for prompt in prompts:
if isinstance(prompt, str):
prompts_arg.append({"title": prompt, "message": prompt})
else:
prompts_arg.append(prompt)

return await self._client.assistant_threads_setSuggestedPrompts(
channel_id=channel or self._channel_id, # type: ignore[arg-type]
thread_ts=thread_ts or self._thread_ts, # type: ignore[arg-type]
prompts=prompts_arg,
title=title,
**kwargs,
)
112 changes: 112 additions & 0 deletions tests/slack_bolt/agent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,118 @@ def test_set_status_requires_status(self):
with pytest.raises(TypeError):
agent.set_status()

def test_set_suggested_prompts_uses_context_defaults(self):
"""BoltAgent.set_suggested_prompts() passes context defaults to WebClient.assistant_threads_setSuggestedPrompts()."""
client = MagicMock(spec=WebClient)
client.assistant_threads_setSuggestedPrompts.return_value = MagicMock()

agent = BoltAgent(
client=client,
channel_id="C111",
thread_ts="1234567890.123456",
team_id="T111",
user_id="W222",
)
agent.set_suggested_prompts(prompts=["What can you do?", "Help me write code"])

client.assistant_threads_setSuggestedPrompts.assert_called_once_with(
channel_id="C111",
thread_ts="1234567890.123456",
prompts=[
{"title": "What can you do?", "message": "What can you do?"},
{"title": "Help me write code", "message": "Help me write code"},
],
title=None,
)

def test_set_suggested_prompts_with_dict_prompts(self):
"""BoltAgent.set_suggested_prompts() accepts dict prompts with title and message."""
client = MagicMock(spec=WebClient)
client.assistant_threads_setSuggestedPrompts.return_value = MagicMock()

agent = BoltAgent(
client=client,
channel_id="C111",
thread_ts="1234567890.123456",
team_id="T111",
user_id="W222",
)
agent.set_suggested_prompts(
prompts=[
{"title": "Short title", "message": "A much longer message for this prompt"},
],
title="Suggestions",
)

client.assistant_threads_setSuggestedPrompts.assert_called_once_with(
channel_id="C111",
thread_ts="1234567890.123456",
prompts=[
{"title": "Short title", "message": "A much longer message for this prompt"},
],
title="Suggestions",
)

def test_set_suggested_prompts_overrides_context_defaults(self):
"""Explicit channel/thread_ts override context defaults."""
client = MagicMock(spec=WebClient)
client.assistant_threads_setSuggestedPrompts.return_value = MagicMock()

agent = BoltAgent(
client=client,
channel_id="C111",
thread_ts="1234567890.123456",
team_id="T111",
user_id="W222",
)
agent.set_suggested_prompts(
prompts=["Hello"],
channel="C999",
thread_ts="9999999999.999999",
)

client.assistant_threads_setSuggestedPrompts.assert_called_once_with(
channel_id="C999",
thread_ts="9999999999.999999",
prompts=[{"title": "Hello", "message": "Hello"}],
title=None,
)

def test_set_suggested_prompts_passes_extra_kwargs(self):
"""Extra kwargs are forwarded to WebClient.assistant_threads_setSuggestedPrompts()."""
client = MagicMock(spec=WebClient)
client.assistant_threads_setSuggestedPrompts.return_value = MagicMock()

agent = BoltAgent(
client=client,
channel_id="C111",
thread_ts="1234567890.123456",
team_id="T111",
user_id="W222",
)
agent.set_suggested_prompts(prompts=["Hello"], token="xoxb-override")

client.assistant_threads_setSuggestedPrompts.assert_called_once_with(
channel_id="C111",
thread_ts="1234567890.123456",
prompts=[{"title": "Hello", "message": "Hello"}],
title=None,
token="xoxb-override",
)

def test_set_suggested_prompts_requires_prompts(self):
"""set_suggested_prompts() raises TypeError when prompts is not provided."""
client = MagicMock(spec=WebClient)
agent = BoltAgent(
client=client,
channel_id="C111",
thread_ts="1234567890.123456",
team_id="T111",
user_id="W222",
)
with pytest.raises(TypeError):
agent.set_suggested_prompts()

def test_import_from_slack_bolt(self):
from slack_bolt import BoltAgent as ImportedBoltAgent

Expand Down
117 changes: 117 additions & 0 deletions tests/slack_bolt_async/agent/test_async_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,123 @@ async def test_set_status_requires_status(self):
with pytest.raises(TypeError):
await agent.set_status()

@pytest.mark.asyncio
async def test_set_suggested_prompts_uses_context_defaults(self):
"""AsyncBoltAgent.set_suggested_prompts() passes context defaults to AsyncWebClient.assistant_threads_setSuggestedPrompts()."""
client = MagicMock(spec=AsyncWebClient)
client.assistant_threads_setSuggestedPrompts, call_tracker, _ = _make_async_api_mock()

agent = AsyncBoltAgent(
client=client,
channel_id="C111",
thread_ts="1234567890.123456",
team_id="T111",
user_id="W222",
)
await agent.set_suggested_prompts(prompts=["What can you do?", "Help me write code"])

call_tracker.assert_called_once_with(
channel_id="C111",
thread_ts="1234567890.123456",
prompts=[
{"title": "What can you do?", "message": "What can you do?"},
{"title": "Help me write code", "message": "Help me write code"},
],
title=None,
)

@pytest.mark.asyncio
async def test_set_suggested_prompts_with_dict_prompts(self):
"""AsyncBoltAgent.set_suggested_prompts() accepts dict prompts with title and message."""
client = MagicMock(spec=AsyncWebClient)
client.assistant_threads_setSuggestedPrompts, call_tracker, _ = _make_async_api_mock()

agent = AsyncBoltAgent(
client=client,
channel_id="C111",
thread_ts="1234567890.123456",
team_id="T111",
user_id="W222",
)
await agent.set_suggested_prompts(
prompts=[
{"title": "Short title", "message": "A much longer message for this prompt"},
],
title="Suggestions",
)

call_tracker.assert_called_once_with(
channel_id="C111",
thread_ts="1234567890.123456",
prompts=[
{"title": "Short title", "message": "A much longer message for this prompt"},
],
title="Suggestions",
)

@pytest.mark.asyncio
async def test_set_suggested_prompts_overrides_context_defaults(self):
"""Explicit channel/thread_ts override context defaults."""
client = MagicMock(spec=AsyncWebClient)
client.assistant_threads_setSuggestedPrompts, call_tracker, _ = _make_async_api_mock()

agent = AsyncBoltAgent(
client=client,
channel_id="C111",
thread_ts="1234567890.123456",
team_id="T111",
user_id="W222",
)
await agent.set_suggested_prompts(
prompts=["Hello"],
channel="C999",
thread_ts="9999999999.999999",
)

call_tracker.assert_called_once_with(
channel_id="C999",
thread_ts="9999999999.999999",
prompts=[{"title": "Hello", "message": "Hello"}],
title=None,
)

@pytest.mark.asyncio
async def test_set_suggested_prompts_passes_extra_kwargs(self):
"""Extra kwargs are forwarded to AsyncWebClient.assistant_threads_setSuggestedPrompts()."""
client = MagicMock(spec=AsyncWebClient)
client.assistant_threads_setSuggestedPrompts, call_tracker, _ = _make_async_api_mock()

agent = AsyncBoltAgent(
client=client,
channel_id="C111",
thread_ts="1234567890.123456",
team_id="T111",
user_id="W222",
)
await agent.set_suggested_prompts(prompts=["Hello"], token="xoxb-override")

call_tracker.assert_called_once_with(
channel_id="C111",
thread_ts="1234567890.123456",
prompts=[{"title": "Hello", "message": "Hello"}],
title=None,
token="xoxb-override",
)

@pytest.mark.asyncio
async def test_set_suggested_prompts_requires_prompts(self):
"""set_suggested_prompts() raises TypeError when prompts is not provided."""
client = MagicMock(spec=AsyncWebClient)
agent = AsyncBoltAgent(
client=client,
channel_id="C111",
thread_ts="1234567890.123456",
team_id="T111",
user_id="W222",
)
with pytest.raises(TypeError):
await agent.set_suggested_prompts()

@pytest.mark.asyncio
async def test_import_from_agent_module(self):
from slack_bolt.agent.async_agent import AsyncBoltAgent as ImportedAsyncBoltAgent
Expand Down