Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
318d0f3
Added latest_thread and daily_reminder to Thread table
chrisdedman Oct 10, 2025
1d1d210
added daily reminder option for recurrent threads
chrisdedman Oct 10, 2025
d79e102
Merge branch 'main' into feature/daily_reminder
chrisdedman Oct 10, 2025
70cb9da
Merge branch 'main' into feature/daily_reminder
chrisdedman Oct 14, 2025
e700ec7
Added ignore reminder that was archived or locked + skip if there is …
chrisdedman Oct 15, 2025
84d3d4c
Remove unused import from Thread model and add newline at EOF
chrisdedman Oct 15, 2025
b79d561
update grace-framework version for stagging version purposes [will be…
chrisdedman Oct 15, 2025
2526221
Added initial threads cog unit test
chrisdedman Oct 15, 2025
d85664c
Merge branch 'main' into feature/daily_reminder
chrisdedman Oct 15, 2025
8bf83b7
Add Alembic migration to add latest_thread (BigInteger) and daily_rem…
chrisdedman Oct 15, 2025
c263d19
fix ruff formatting for test
chrisdedman Oct 15, 2025
a58874e
Refactor Thread model to use Field annotations and simplify recurrenc…
chrisdedman Oct 15, 2025
8ebd90f
Refactor threads cog: pass daily_reminder to Thread.create, use Threa…
chrisdedman Oct 15, 2025
e6806da
Ran ruff for formatting
chrisdedman Oct 16, 2025
3b396de
Merge branch 'main' into feature/daily_reminder
chrisdedman Oct 16, 2025
2469760
Rename latest_thread to latest_thread_id in Thread model and Alembic …
chrisdedman Oct 19, 2025
2f14160
ran ruff format
chrisdedman Oct 19, 2025
04a54c3
simplify threads fetch
chrisdedman Oct 19, 2025
408ba59
refactor unit test to use the test db instead of patches + added unit…
chrisdedman Oct 19, 2025
1b30f3e
add test db
chrisdedman Oct 19, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 88 additions & 16 deletions bot/extensions/threads_cog.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import traceback
from logging import info
from pytz import timezone

from discord import Embed, Interaction, TextStyle
from discord.app_commands import Choice, autocomplete
from discord.ext.commands import Cog, Context, has_permissions, hybrid_group

from discord.ui import Modal, TextInput
from pytz import timezone
from discord.ext.commands import Cog, has_permissions, hybrid_group, Context

from bot.models.extensions.thread import Thread
from bot.classes.recurrence import Recurrence
from bot.extensions.command_error_handler import send_command_help
from bot.models.extensions.thread import Thread
from lib.config_required import cog_config_required


Expand All @@ -28,14 +29,20 @@ class ThreadModal(Modal, title="Thread"):
style=TextStyle.paragraph,
)

def __init__(self, recurrence: Recurrence, thread: Thread = None):
def __init__(
self,
recurrence: Recurrence,
reminder: bool = None,
thread: Thread = None,
):
super().__init__()

if thread:
self.thread_title.default = thread.title
self.thread_content.default = thread.content

self.thread = thread
self.thread_reminder = reminder
self.thread_recurrence = recurrence

async def on_submit(self, interaction: Interaction):
Expand All @@ -49,17 +56,20 @@ async def create_thread(self, interaction: Interaction):
title=self.thread_title.value,
content=self.thread_content.value,
recurrence=self.thread_recurrence,
daily_reminder=self.thread_reminder,
)

await interaction.response.send_message(
f"Thread __**{thread.id}**__ created!", ephemeral=True
)

async def update_thread(self, interaction: Interaction):
self.thread.title = (self.thread_title.value,)
self.thread.content = (self.thread_content.value,)
self.thread.recurrence = self.thread_recurrence

self.thread.save()
self.thread.update(
title=self.thread_title.value,
content=self.thread_content.value,
recurrence=self.thread_recurrence,
daily_reminder=self.thread_reminder,
)

await interaction.response.send_message(
f"Thread __**{self.thread.id}**__ updated!", ephemeral=True
Expand All @@ -72,7 +82,10 @@ async def on_error(self, interaction: Interaction, error: Exception):
traceback.print_exception(type(error), error, error.__traceback__)


async def thread_autocomplete(_: Interaction, current: str) -> list[Choice[str]]:
async def thread_autocomplete(
_: Interaction,
current: str,
) -> list[Choice[str]]:
return [
Choice(name=t.title, value=str(t.id))
for t in Thread.all()
Expand Down Expand Up @@ -120,6 +133,47 @@ def cog_load(self):
)
)

# Runs reminders everyday at 12:30
self.jobs.append(
self.bot.scheduler.add_job(
self.daily_reminder, "cron", hour=12, minute=30, timezone=self.timezone
)
)

async def daily_reminder(self):
"""Send a daily reminder for active threads."""
info("Posting daily threads's reminder")

embed = Embed(
color=self.bot.default_color,
title="🔔 Daily Reminder",
description="Join the discussion in the latest active threads:",
)

if threads := Thread.all():
for thread in threads:
if (
hasattr(thread, "latest_thread")
and thread.latest_thread
and thread.daily_reminder
):
discord_thread = await self.bot.fetch_channel(
int(thread.latest_thread)
)
if getattr(discord_thread, "archived", False) or getattr(
discord_thread, "locked", False
):
continue # Skip archieved and locked threads

embed.add_field(
name="", value=f"- <#{thread.latest_thread}>", inline=False
)

if embed.fields:
channel = self.bot.get_channel(self.threads_channel_id)
if channel:
await channel.send(embed=embed)

def cog_unload(self):
for job in self.jobs:
self.bot.scheduler.remove_job(job.id)
Expand Down Expand Up @@ -153,7 +207,11 @@ async def post_thread(self, thread: Thread):

if channel:
message = await channel.send(content=content, embed=embed)
await message.create_thread(name=thread.title)
discord_thread = await message.create_thread(name=thread.title)

if thread.daily_reminder:
thread.latest_thread = discord_thread.id
thread.save()

@hybrid_group(name="threads", help="Commands to manage threads")
@has_permissions(administrator=True)
Expand All @@ -170,7 +228,10 @@ async def list(self, ctx: Context):
for thread in threads:
embed.add_field(
name=f"[{thread.id}] {thread.title}",
value=f"**Recurrence**: {thread.recurrence}",
value=(
f"**Recurrence**: {thread.recurrence}\n"
f"**Reminder**: {thread.daily_reminder}"
),
inline=False,
)
else:
Expand All @@ -180,8 +241,13 @@ async def list(self, ctx: Context):

@threads_group.command(help="Creates a new thread")
@has_permissions(administrator=True)
async def create(self, ctx: Context, recurrence: Recurrence):
modal = ThreadModal(recurrence)
async def create(
self,
ctx: Context,
recurrence: Recurrence,
reminder: bool,
):
modal = ThreadModal(recurrence, reminder=reminder)
await ctx.interaction.response.send_modal(modal)

@threads_group.command(help="Deletes a given thread")
Expand All @@ -197,9 +263,15 @@ async def delete(self, ctx: Context, thread: int):
@threads_group.command(help="Update a thread")
@has_permissions(administrator=True)
@autocomplete(thread=thread_autocomplete)
async def update(self, ctx: Context, thread: int, recurrence: Recurrence):
async def update(
self,
ctx: Context,
thread: int,
recurrence: Recurrence,
reminder: bool,
):
if thread := Thread.find(thread):
modal = ThreadModal(recurrence, thread=thread)
modal = ThreadModal(recurrence, reminder=reminder, thread=thread)
await ctx.interaction.response.send_modal(modal)
else:
await ctx.send("Thread not found!", ephemeral=True)
Expand Down
6 changes: 5 additions & 1 deletion bot/models/extensions/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@ class Thread(Model):
id: int | None = Field(default=None, primary_key=True)
title: str
content: str = Field(sa_type=Text)
recurrence: Recurrence = EnumField(Recurrence, default=Recurrence.NONE, nullable=False)
recurrence: Recurrence = EnumField(
Recurrence, default=Recurrence.NONE, nullable=False
)
latest_thread: int
daily_reminder: bool

@classmethod
def find_by_recurrence(cls, recurrence: Recurrence) -> "Recurrence":
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""add_latest_thread_and_daily_reminder_to_threads

Revision ID: b7c695397ab2
Revises: cc8da39749e7
Create Date: 2025-10-15 15:58:03.467659

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "b7c695397ab2"
down_revision = "cc8da39749e7"
branch_labels = None
depends_on = None


def upgrade() -> None:
op.add_column("threads", sa.Column("latest_thread", sa.BigInteger(), nullable=True))
op.add_column("threads", sa.Column("daily_reminder", sa.Boolean()))


def downgrade() -> None:
op.drop_column("threads", "latest_thread")
op.drop_column("threads", "daily_reminder")
150 changes: 150 additions & 0 deletions tests/extensions/test_threads_cog.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import pytest

from bot.extensions.threads_cog import ThreadsCog
from unittest.mock import AsyncMock, MagicMock, patch


@pytest.fixture
def mock_bot():
"""Create a mock Discord bot instance."""
bot = MagicMock()
bot.default_color = 0xFFFFFF
bot.app.config.get = MagicMock(return_value=None)
bot.scheduler = MagicMock()
return bot


@pytest.fixture
def threads_cog(mock_bot):
"""Instantiate the ThreadsCog with a mock bot."""

return ThreadsCog(mock_bot)


@pytest.fixture
def dummy_modal(monkeypatch):
"""Fixture that patches ThreadModal and records constructor args."""
called_args = {}

class DummyModal:
def __init__(self, recurrence, reminder=None, thread=None):
called_args["recurrence"] = recurrence
called_args["reminder"] = reminder
called_args["thread"] = thread

monkeypatch.setattr("bot.extensions.threads_cog.ThreadModal", DummyModal)

return called_args


@pytest.mark.asyncio
async def test_create_thread_modal_called(threads_cog):
"""Verify that thread modal is called."""
ctx = MagicMock()
ctx.interaction = MagicMock()
ctx.interaction.response = MagicMock()
ctx.interaction.response.send_modal = AsyncMock()

recurrence = "DAILY"
reminder = True
await threads_cog.create.callback(
threads_cog,
ctx,
recurrence=recurrence,
reminder=reminder,
)

ctx.interaction.response.send_modal.assert_awaited_once()


@pytest.mark.asyncio
async def test_create_modal_args(threads_cog, dummy_modal):
ctx = MagicMock()
ctx.interaction = MagicMock()
ctx.interaction.response = MagicMock()
ctx.interaction.response.send_modal = AsyncMock()

recurrence = "DAILY"
reminder = True

await threads_cog.create.callback(
threads_cog,
ctx,
recurrence=recurrence,
reminder=reminder,
)

ctx.interaction.response.send_modal.assert_awaited_once()
assert dummy_modal["recurrence"] == recurrence
assert dummy_modal["reminder"] == reminder
assert dummy_modal["thread"] is None


@pytest.mark.asyncio
async def test_daily_reminder_no_threads(threads_cog, mock_bot):
"""Test daily_reminder when there are no threads."""
with patch("bot.extensions.threads_cog.Thread.all", return_value=[]):
mock_channel = MagicMock()
mock_bot.get_channel.return_value = mock_channel

await threads_cog.daily_reminder()
mock_channel.send.assert_not_called()


@pytest.mark.asyncio
async def test_daily_reminder_with_active_threads(threads_cog, mock_bot):
"""Test daily_reminder sends reminders for active threads."""
thread1 = MagicMock()
thread1.latest_thread = 123
thread1.daily_reminder = True
thread1.title = "Thread 1"
thread1.content = "Content 1"

thread2 = MagicMock()
thread2.latest_thread = 456
thread2.daily_reminder = False # Should not be included

# discord_thread is not archived or locked
discord_thread = MagicMock()
discord_thread.archived = False
discord_thread.locked = False

with patch(
"bot.extensions.threads_cog.Thread.all", return_value=[thread1, thread2]
):
mock_bot.fetch_channel = AsyncMock(return_value=discord_thread)
mock_channel = MagicMock()
mock_channel.send = AsyncMock()
mock_bot.get_channel.return_value = mock_channel

await threads_cog.daily_reminder()

mock_channel.send.assert_awaited_once()
args, kwargs = mock_channel.send.await_args
embed = kwargs.get("embed")
assert embed is not None
assert "Daily Reminder" in embed.title

assert len(embed.fields) == 1
assert embed.fields[0].value == f"- <#{thread1.latest_thread}>"
assert any(thread2.latest_thread != field.value for field in embed.fields)


@pytest.mark.asyncio
async def test_daily_reminder_skips_archived_and_locked(threads_cog, mock_bot):
"""Test daily_reminder skips archived and locked threads."""
thread = MagicMock()
thread.latest_thread = 789
thread.daily_reminder = True

discord_thread = MagicMock()
discord_thread.archived = True
discord_thread.locked = True

with patch("bot.extensions.threads_cog.Thread.all", return_value=[thread]):
mock_bot.fetch_channel = AsyncMock(return_value=discord_thread)
mock_channel = MagicMock()
mock_bot.get_channel.return_value = mock_channel

await threads_cog.daily_reminder()
mock_channel.send.assert_not_called()