diff --git a/bot/extensions/threads_cog.py b/bot/extensions/threads_cog.py index 4398ac1f..389e931d 100644 --- a/bot/extensions/threads_cog.py +++ b/bot/extensions/threads_cog.py @@ -1,15 +1,16 @@ import traceback from logging import info +from pytz import timezone from discord import Embed, Interaction, TextStyle from discord.app_commands import Choice, autocomplete -from discord.ext.commands import Cog, Context, has_permissions, hybrid_group + from discord.ui import Modal, TextInput -from pytz import timezone +from discord.ext.commands import Cog, has_permissions, hybrid_group, Context +from bot.models.extensions.thread import Thread from bot.classes.recurrence import Recurrence from bot.extensions.command_error_handler import send_command_help -from bot.models.extensions.thread import Thread from lib.config_required import cog_config_required @@ -28,7 +29,12 @@ class ThreadModal(Modal, title="Thread"): style=TextStyle.paragraph, ) - def __init__(self, recurrence: Recurrence, thread: Thread = None): + def __init__( + self, + recurrence: Recurrence, + reminder: bool = None, + thread: Thread = None, + ): super().__init__() if thread: @@ -36,6 +42,7 @@ def __init__(self, recurrence: Recurrence, thread: Thread = None): self.thread_content.default = thread.content self.thread = thread + self.thread_reminder = reminder self.thread_recurrence = recurrence async def on_submit(self, interaction: Interaction): @@ -49,17 +56,20 @@ async def create_thread(self, interaction: Interaction): title=self.thread_title.value, content=self.thread_content.value, recurrence=self.thread_recurrence, + daily_reminder=self.thread_reminder, ) + await interaction.response.send_message( f"Thread __**{thread.id}**__ created!", ephemeral=True ) async def update_thread(self, interaction: Interaction): - self.thread.title = (self.thread_title.value,) - self.thread.content = (self.thread_content.value,) - self.thread.recurrence = self.thread_recurrence - - self.thread.save() + self.thread.update( + title=self.thread_title.value, + content=self.thread_content.value, + recurrence=self.thread_recurrence, + daily_reminder=self.thread_reminder, + ) await interaction.response.send_message( f"Thread __**{self.thread.id}**__ updated!", ephemeral=True @@ -72,7 +82,10 @@ async def on_error(self, interaction: Interaction, error: Exception): traceback.print_exception(type(error), error, error.__traceback__) -async def thread_autocomplete(_: Interaction, current: str) -> list[Choice[str]]: +async def thread_autocomplete( + _: Interaction, + current: str, +) -> list[Choice[str]]: return [ Choice(name=t.title, value=str(t.id)) for t in Thread.all() @@ -120,6 +133,42 @@ def cog_load(self): ) ) + # Runs reminders everyday at 12:30 + self.jobs.append( + self.bot.scheduler.add_job( + self.daily_reminder, "cron", hour=12, minute=30, timezone=self.timezone + ) + ) + + async def daily_reminder(self): + """Send a daily reminder for active threads.""" + info("Posting daily threads's reminder") + + embed = Embed( + color=self.bot.default_color, + title="🔔 Daily Reminder", + description="Join the discussion in the latest active threads:", + ) + + if threads := Thread.where( + Thread.latest_thread_id.isnot(None), daily_reminder=True + ).all(): + for thread in threads: + discord_thread = await self.bot.fetch_channel( + int(thread.latest_thread_id) + ) + if discord_thread.archived or discord_thread.locked: + continue # Skip archieved and locked threads + + embed.add_field( + name="", value=f"- <#{thread.latest_thread_id}>", inline=False + ) + + if embed.fields: + channel = self.bot.get_channel(self.threads_channel_id) + if channel: + await channel.send(embed=embed) + def cog_unload(self): for job in self.jobs: self.bot.scheduler.remove_job(job.id) @@ -153,7 +202,10 @@ async def post_thread(self, thread: Thread): if channel: message = await channel.send(content=content, embed=embed) - await message.create_thread(name=thread.title) + discord_thread = await message.create_thread(name=thread.title) + + if thread.daily_reminder: + thread.update(latest_thread_id=discord_thread.id) @hybrid_group(name="threads", help="Commands to manage threads") @has_permissions(administrator=True) @@ -170,7 +222,10 @@ async def list(self, ctx: Context): for thread in threads: embed.add_field( name=f"[{thread.id}] {thread.title}", - value=f"**Recurrence**: {thread.recurrence}", + value=( + f"**Recurrence**: {thread.recurrence}\n" + f"**Reminder**: {thread.daily_reminder}" + ), inline=False, ) else: @@ -180,8 +235,13 @@ async def list(self, ctx: Context): @threads_group.command(help="Creates a new thread") @has_permissions(administrator=True) - async def create(self, ctx: Context, recurrence: Recurrence): - modal = ThreadModal(recurrence) + async def create( + self, + ctx: Context, + recurrence: Recurrence, + reminder: bool, + ): + modal = ThreadModal(recurrence, reminder=reminder) await ctx.interaction.response.send_modal(modal) @threads_group.command(help="Deletes a given thread") @@ -197,9 +257,15 @@ async def delete(self, ctx: Context, thread: int): @threads_group.command(help="Update a thread") @has_permissions(administrator=True) @autocomplete(thread=thread_autocomplete) - async def update(self, ctx: Context, thread: int, recurrence: Recurrence): + async def update( + self, + ctx: Context, + thread: int, + recurrence: Recurrence, + reminder: bool, + ): if thread := Thread.find(thread): - modal = ThreadModal(recurrence, thread=thread) + modal = ThreadModal(recurrence, reminder=reminder, thread=thread) await ctx.interaction.response.send_modal(modal) else: await ctx.send("Thread not found!", ephemeral=True) diff --git a/bot/models/channel.py b/bot/models/channel.py index 4b88c23f..1cbecb84 100644 --- a/bot/models/channel.py +++ b/bot/models/channel.py @@ -10,4 +10,4 @@ class Channel(Model): ) channel_name: str = Field(primary_key=True) - channel_id: int = Field(primary_key=True) \ No newline at end of file + channel_id: int = Field(primary_key=True) diff --git a/bot/models/extensions/thread.py b/bot/models/extensions/thread.py index d5caa26c..94e57a81 100644 --- a/bot/models/extensions/thread.py +++ b/bot/models/extensions/thread.py @@ -11,7 +11,11 @@ class Thread(Model): id: int | None = Field(default=None, primary_key=True) title: str content: str = Field(sa_type=Text) - recurrence: Recurrence = EnumField(Recurrence, default=Recurrence.NONE, nullable=False) + recurrence: Recurrence = EnumField( + Recurrence, default=Recurrence.NONE, nullable=False + ) + latest_thread_id: int + daily_reminder: bool @classmethod def find_by_recurrence(cls, recurrence: Recurrence) -> "Recurrence": diff --git a/db/alembic/versions/b7c695397ab2_add_latest_thread_and_daily_reminder_to_.py b/db/alembic/versions/b7c695397ab2_add_latest_thread_and_daily_reminder_to_.py new file mode 100644 index 00000000..6f5f443d --- /dev/null +++ b/db/alembic/versions/b7c695397ab2_add_latest_thread_and_daily_reminder_to_.py @@ -0,0 +1,29 @@ +"""add_latest_thread_and_daily_reminder_to_threads + +Revision ID: b7c695397ab2 +Revises: cc8da39749e7 +Create Date: 2025-10-15 15:58:03.467659 + +""" + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "b7c695397ab2" +down_revision = "cc8da39749e7" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "threads", sa.Column("latest_thread_id", sa.BigInteger(), nullable=True) + ) + op.add_column("threads", sa.Column("daily_reminder", sa.Boolean())) + + +def downgrade() -> None: + op.drop_column("threads", "latest_thread_id") + op.drop_column("threads", "daily_reminder") diff --git a/tests/extensions/__init__.py b/tests/extensions/__init__.py index e69de29b..83527638 100644 --- a/tests/extensions/__init__.py +++ b/tests/extensions/__init__.py @@ -0,0 +1,10 @@ +from bot import app +from grace.database import up_migration + +app.load("test") + +app.drop_tables() +app.drop_database() + +app.create_database() +up_migration(app, "head") diff --git a/tests/extensions/test_threads_cog.py b/tests/extensions/test_threads_cog.py new file mode 100644 index 00000000..c9e681df --- /dev/null +++ b/tests/extensions/test_threads_cog.py @@ -0,0 +1,202 @@ +import pytest + +from bot.extensions.threads_cog import ThreadsCog +from unittest.mock import AsyncMock, MagicMock +from bot.models.extensions.thread import Thread + + +@pytest.fixture +def mock_bot(): + """Create a mock Discord bot instance.""" + bot = MagicMock() + bot.default_color = 0xFFFFFF + bot.app.config.get = MagicMock(return_value=None) + bot.scheduler = MagicMock() + return bot + + +@pytest.fixture +def threads_cog(mock_bot): + """Instantiate the ThreadsCog with a mock bot.""" + + return ThreadsCog(mock_bot) + + +@pytest.fixture +def dummy_modal(monkeypatch): + """Fixture that patches ThreadModal and records constructor args.""" + called_args = {} + + class DummyModal: + def __init__(self, recurrence, reminder=None, thread=None): + called_args["recurrence"] = recurrence + called_args["reminder"] = reminder + called_args["thread"] = thread + + monkeypatch.setattr("bot.extensions.threads_cog.ThreadModal", DummyModal) + + return called_args + + +@pytest.mark.asyncio +async def test_create_thread_modal_called(threads_cog): + """Verify that thread modal is called.""" + ctx = MagicMock() + ctx.interaction = MagicMock() + ctx.interaction.response = MagicMock() + ctx.interaction.response.send_modal = AsyncMock() + + recurrence = "DAILY" + reminder = True + await threads_cog.create.callback( + threads_cog, + ctx, + recurrence=recurrence, + reminder=reminder, + ) + + ctx.interaction.response.send_modal.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_create_modal_args(threads_cog, dummy_modal): + ctx = MagicMock() + ctx.interaction = MagicMock() + ctx.interaction.response = MagicMock() + ctx.interaction.response.send_modal = AsyncMock() + + recurrence = "DAILY" + reminder = True + + await threads_cog.create.callback( + threads_cog, + ctx, + recurrence=recurrence, + reminder=reminder, + ) + + ctx.interaction.response.send_modal.assert_awaited_once() + assert dummy_modal["recurrence"] == recurrence + assert dummy_modal["reminder"] == reminder + assert dummy_modal["thread"] is None + + +@pytest.mark.asyncio +async def test_daily_reminder_no_threads(threads_cog, mock_bot): + """Test daily_reminder when there are no threads.""" + _ = Thread.create( + title="Thread 1", + content="Content 1", + recurrence=None, + daily_reminder=False, # Should not be included + latest_thread_id=123, + ) + + mock_channel = MagicMock() + mock_channel.send = AsyncMock() + mock_bot.get_channel.return_value = mock_channel + await threads_cog.daily_reminder() + mock_channel.send.assert_not_called() + + +@pytest.mark.asyncio +async def test_daily_reminder_with_active_threads(threads_cog, mock_bot): + """Test daily_reminder sends reminders for active threads.""" + thread1 = Thread.create( + title="Thread 1", + content="Content 1", + recurrence=None, + daily_reminder=True, + latest_thread_id=123, + ) + + thread2 = Thread.create( + title="Thread 2", + content="Content 2", + recurrence=None, + daily_reminder=False, # Should not be included + latest_thread_id=456, + ) + + discord_thread = MagicMock() + discord_thread.archived = False + discord_thread.locked = False + + mock_channel = MagicMock() + mock_channel.send = AsyncMock() + + mock_bot.get_channel.return_value = mock_channel + mock_bot.fetch_channel = AsyncMock(return_value=discord_thread) + + await threads_cog.daily_reminder() + mock_channel.send.assert_awaited_once() + + args, kwargs = mock_channel.send.await_args + embed = kwargs.get("embed") + assert embed is not None + + assert "Daily Reminder" in embed.title + assert len(embed.fields) == 1 + assert embed.fields[0].value == f"- <#{thread1.latest_thread_id}>" + assert any(thread2.latest_thread_id != field.value for field in embed.fields) + + +@pytest.mark.asyncio +async def test_daily_reminder_skips_archived_and_locked(threads_cog, mock_bot): + """Test daily_reminder skips archived and locked threads.""" + discord_thread = MagicMock() + discord_thread.archived = True + discord_thread.locked = True + + mock_channel = MagicMock() + mock_channel.send = AsyncMock() + mock_bot.get_channel.return_value = mock_channel + mock_bot.fetch_channel = AsyncMock(return_value=discord_thread) + + await threads_cog.daily_reminder() + mock_channel.send.assert_not_called() + + +@pytest.mark.asyncio +async def test_post_thread(threads_cog, mock_bot): + """Test post_thread creates a thread in the specified channel.""" + thread = Thread.create( + title="Test Thread", + content="This is a test thread.", + recurrence=None, + daily_reminder=False, + latest_thread_id=None, + ) + + mock_channel = MagicMock() + mock_channel.send = AsyncMock() + mock_bot.get_channel.return_value = mock_channel + + message_mock = MagicMock() + message_mock.create_thread = AsyncMock() + mock_channel.send.return_value = message_mock + + await threads_cog.post_thread(thread) + + mock_channel.send.assert_awaited_once() + args, kwargs = mock_channel.send.await_args + embed = kwargs.get("embed") + assert embed is not None + assert embed.title == thread.title + assert embed.description == thread.content + + message_mock.create_thread.assert_awaited_once_with(name=thread.title) + + +@pytest.mark.asyncio +async def test_cog_unload_removes_jobs(threads_cog, mock_bot): + """Test that cog_unload removes scheduled jobs.""" + job1 = MagicMock() + job2 = MagicMock() + threads_cog.jobs = [job1, job2] + + threads_cog.cog_unload() + + mock_bot.scheduler.remove_job.assert_any_call(job1.id) + mock_bot.scheduler.remove_job.assert_any_call(job2.id) + assert mock_bot.scheduler.remove_job.call_count == 2