Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backend/apps/slack/MANIFEST.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ settings:
- app_mention
- member_joined_channel
- message.channels
- message.im
- team_join
interactivity:
is_enabled: true
Expand Down
54 changes: 54 additions & 0 deletions backend/apps/slack/common/handlers/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from apps.ai.agent.tools.rag.rag_tool import RagTool
from apps.slack.blocks import markdown
from apps.slack.constants import CONVERSATION_CONTEXT_LIMIT
from apps.slack.models import Conversation, Workspace

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -46,6 +48,58 @@ def process_ai_query(query: str) -> str | None:
return rag_tool.query(question=query)


def get_dm_blocks(query: str, workspace_id: str, channel_id: str) -> list[dict]:
"""Get AI response blocks for DM with conversation context.

Args:
query (str): The user's question.
workspace_id (str): Slack workspace ID.
channel_id (str): Slack channel ID for the DM.

Returns:
list: A list of Slack blocks representing the AI response.

"""
ai_response = process_dm_ai_query(query.strip(), workspace_id, channel_id)

if ai_response:
return [markdown(ai_response)]
return get_error_blocks()


def process_dm_ai_query(query: str, workspace_id: str, channel_id: str) -> str | None:
"""Process the AI query with DM conversation context.

Args:
query (str): The user's question.
workspace_id (str): Slack workspace ID.
channel_id (str): Slack channel ID for the DM.

Returns:
str | None: The AI response or None if error occurred.

"""
workspace = Workspace.objects.get(slack_workspace_id=workspace_id)
conversation = Conversation.objects.get(slack_channel_id=channel_id, workspace=workspace)

context = conversation.get_context(conversation_context_limit=CONVERSATION_CONTEXT_LIMIT)

rag_tool = RagTool(
chat_model="gpt-4o",
embedding_model="text-embedding-3-small",
)

if context:
enhanced_query = f"Conversation context:\n{context}\n\nCurrent question: {query}"
else:
enhanced_query = query

response = rag_tool.query(question=enhanced_query)
conversation.add_to_context(query, response)

return response


def get_error_blocks() -> list[dict]:
"""Get error response blocks.

Expand Down
1 change: 1 addition & 0 deletions backend/apps/slack/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from apps.common.constants import NL

CONVERSATION_CONTEXT_LIMIT = 20
NEST_BOT_NAME = "NestBot"

OWASP_APPSEC_CHANNEL_ID = "#C0F7D6DFH"
Expand Down
69 changes: 59 additions & 10 deletions backend/apps/slack/events/message_posted.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
"""Slack message event template."""
"""Slack message event handler for OWASP NestBot."""

import logging
from datetime import timedelta

import django_rq

from apps.ai.common.constants import QUEUE_RESPONSE_TIME_MINUTES
from apps.slack.common.handlers.ai import get_dm_blocks
from apps.slack.common.question_detector import QuestionDetector
from apps.slack.events.event import EventBase
from apps.slack.models import Conversation, Member, Message
from apps.slack.models import Conversation, Member, Message, Workspace
from apps.slack.services.message_auto_reply import generate_ai_reply_if_unanswered

logger = logging.getLogger(__name__)


class MessagePosted(EventBase):
"""Handles new messages posted in channels."""
"""Handles new messages posted in channels or direct messages."""

event_type = "message"

Expand All @@ -24,25 +25,30 @@ def __init__(self):
self.question_detector = QuestionDetector()

def handle_event(self, event, client):
"""Handle an incoming message event."""
"""Handle incoming Slack message events."""
if event.get("subtype") or event.get("bot_id"):
logger.info("Ignored message due to subtype, bot_id, or thread_ts.")
logger.info("Ignored message due to subtype or bot_id.")
return

channel_id = event.get("channel")
user_id = event.get("user")
text = event.get("text", "")
channel_type = event.get("channel_type")

if channel_type == "im":
self.handle_dm(event, client, channel_id, user_id, text)
return

if event.get("thread_ts"):
try:
Message.objects.filter(
slack_message_id=event.get("thread_ts"),
conversation__slack_channel_id=event.get("channel"),
conversation__slack_channel_id=channel_id,
).update(has_replies=True)
except Message.DoesNotExist:
logger.warning("Thread message not found.")
return

channel_id = event.get("channel")
user_id = event.get("user")
text = event.get("text", "")

try:
conversation = Conversation.objects.get(
slack_channel_id=channel_id,
Expand Down Expand Up @@ -71,3 +77,46 @@ def handle_event(self, event, client):
generate_ai_reply_if_unanswered,
message.id,
)

def handle_dm(self, event, client, channel_id, user_id, text):
"""Handle direct messages with NestBot (DMs)."""
workspace_id = event.get("team")
channel_info = client.conversations_info(channel=channel_id)

try:
workspace = Workspace.objects.get(slack_workspace_id=workspace_id)
except Workspace.DoesNotExist:
logger.exception("Workspace not found for DM.")
return

Conversation.update_data(channel_info["channel"], workspace)

try:
Member.objects.get(slack_user_id=user_id, workspace=workspace)
except Member.DoesNotExist:
user_info = client.users_info(user=user_id)
Member.update_data(user_info["user"], workspace, save=True)
logger.info("Created new member for DM")

thread_ts = event.get("thread_ts")

try:
response_blocks = get_dm_blocks(text, workspace_id, channel_id)
if response_blocks:
client.chat_postMessage(
channel=channel_id,
blocks=response_blocks,
text=text,
thread_ts=thread_ts,
)

except Exception:
logger.exception("Error processing DM")
client.chat_postMessage(
channel=channel_id,
text=(
"I'm sorry, I'm having trouble processing your message right now. "
"Please try again later."
),
thread_ts=thread_ts,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Generated by Django 5.2.6 on 2025-10-08 07:21

from django.db import migrations, models


class Migration(migrations.Migration):
dependencies = [
("slack", "0019_conversation_is_nest_bot_assistant_enabled"),
]

operations = [
migrations.AddField(
model_name="conversation",
name="conversation_context",
field=models.TextField(blank=True, verbose_name="Conversation context"),
),
]
41 changes: 41 additions & 0 deletions backend/apps/slack/models/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class Meta:

# Additional attributes.
sync_messages = models.BooleanField(verbose_name="Sync messages", default=False)
conversation_context = models.TextField(blank=True, verbose_name="Conversation context")

def __str__(self):
"""Channel human readable representation."""
Expand Down Expand Up @@ -105,3 +106,43 @@ def update_data(conversation_data, workspace, *, save=True):
conversation.save()

return conversation

def add_to_context(self, user_message: str, bot_response: str | None = None) -> None:
"""Add messages to the conversation context.

Args:
user_message: The user's message to add to context.
bot_response: The bot's response to add to context.

"""
if not self.conversation_context:
self.conversation_context = ""

self.conversation_context = f"{self.conversation_context}{f'User: {user_message}\n'}"

if bot_response:
self.conversation_context = f"{self.conversation_context}{f'Bot: {bot_response}\n'}"

self.save(update_fields=["conversation_context"])

def get_context(self, conversation_context_limit: int | None = None) -> str:
"""Get the conversation context.

Args:
conversation_context_limit: Optional limit on number of exchanges to return.

Returns:
The conversation context, potentially limited to recent exchanges.

"""
if not self.conversation_context:
return ""

if conversation_context_limit is None:
return self.conversation_context

lines = self.conversation_context.strip().split("\n")
if len(lines) <= conversation_context_limit * 2:
return self.conversation_context

return "\n".join(lines[-(conversation_context_limit * 2) :])
29 changes: 17 additions & 12 deletions backend/tests/apps/slack/events/message_posted_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,17 @@ def test_handle_event_ignores_thread_messages(self, message_handler):
client = Mock()

with patch("apps.slack.events.message_posted.Message") as mock_message:
mock_message.DoesNotExist = Exception
mock_message.objects.get.side_effect = Exception("Message not found")
mock_filter = Mock()
mock_message.objects.filter.return_value = mock_filter

message_handler.handle_event(event, client)

mock_message.objects.filter.assert_called_once_with(
slack_message_id=event.get("thread_ts"),
conversation__slack_channel_id=event.get("channel"),
)
mock_filter.update.assert_called_once_with(has_replies=True)

client.chat_postMessage.assert_not_called()

def test_handle_event_conversation_not_found(self, message_handler):
Expand Down Expand Up @@ -260,6 +266,7 @@ def test_handle_event_member_not_found(self, message_handler, conversation_mock)
patch("apps.slack.events.message_posted.Conversation") as mock_conversation,
patch("apps.slack.events.message_posted.Member") as mock_member,
patch("apps.slack.events.message_posted.Message") as mock_message_model,
patch("apps.slack.events.message_posted.django_rq") as mock_django_rq,
):
mock_conversation.objects.get.return_value = conversation_mock

Expand All @@ -273,19 +280,17 @@ def test_handle_event_member_not_found(self, message_handler, conversation_mock)
mock_message.id = 1
mock_message_model.update_data.return_value = mock_message

with (
patch.object(
message_handler.question_detector,
"is_owasp_question",
return_value=True,
),
patch("apps.slack.events.message_posted.django_rq") as mock_django_rq,
):
mock_queue = Mock()
mock_django_rq.get_queue.return_value = mock_queue
mock_queue = Mock()
mock_django_rq.get_queue.return_value = mock_queue

with patch.object(
message_handler.question_detector,
"is_owasp_question",
return_value=True,
):
message_handler.handle_event(event, client)

mock_member.update_data.assert_called_once()
mock_django_rq.get_queue.assert_called_once()

def test_handle_event_empty_text(self, message_handler):
Expand Down
Loading