Skip to content

Allow using Discord threads #3362

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ however, insignificant breaking changes do not guarantee a major version bump, s

# v4.1.1

### Breaking
- Modmail threads are now potentially Discord threads

### Fixed
- `?msglink` now supports threads with multiple recipients. ([PR #3341](https://github.com/modmail-dev/Modmail/pull/3341))
- Fixed persistent notes not working due to discord.py internal change. ([PR #3324](https://github.com/modmail-dev/Modmail/pull/3324))
Expand Down
37 changes: 23 additions & 14 deletions bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,13 +305,21 @@ def log_channel(self) -> typing.Optional[discord.TextChannel]:
logger.debug("LOG_CHANNEL_ID was invalid, removed.")
self.config.remove("log_channel_id")
if self.main_category is not None:
try:
channel = self.main_category.channels[0]
self.config["log_channel_id"] = channel.id
logger.warning("No log channel set, setting #%s to be the log channel.", channel.name)
return channel
except IndexError:
pass
if isinstance(self.main_category, discord.CategoryChannel):
try:
channel = self.main_category.channels[0]
self.config["log_channel_id"] = channel.id
logger.warning("No log channel set, setting #%s to be the log channel.", channel.name)
return channel
except IndexError:
pass
elif isinstance(self.main_category, discord.TextChannel):
self.config["log_channel_id"] = self.main_category.id
logger.warning(
"No log channel set, setting #%s to be the log channel.", self.main_category.name
)
return self.main_category

logger.warning(
"No log channel set, set one with `%ssetup` or `%sconfig set log_channel_id <id>`.",
self.prefix,
Expand Down Expand Up @@ -419,13 +427,13 @@ def using_multiple_server_setup(self) -> bool:
return self.modmail_guild != self.guild

@property
def main_category(self) -> typing.Optional[discord.CategoryChannel]:
def main_category(self) -> typing.Optional[discord.abc.GuildChannel]:
if self.modmail_guild is not None:
category_id = self.config["main_category_id"]
if category_id is not None:
try:
cat = discord.utils.get(self.modmail_guild.categories, id=int(category_id))
if cat is not None:
cat = discord.utils.get(self.modmail_guild.channels, id=int(category_id))
if cat is not None and isinstance(cat, (discord.CategoryChannel, discord.TextChannel)):
return cat
except ValueError:
pass
Expand Down Expand Up @@ -1351,11 +1359,12 @@ async def on_guild_channel_delete(self, channel):
if channel.guild != self.modmail_guild:
return

if self.main_category == channel:
logger.debug("Main category was deleted.")
self.config.remove("main_category_id")
await self.config.update()

if isinstance(channel, discord.CategoryChannel):
if self.main_category == channel:
logger.debug("Main category was deleted.")
self.config.remove("main_category_id")
await self.config.update()
return

if not isinstance(channel, discord.TextChannel):
Expand Down
9 changes: 9 additions & 0 deletions cogs/modmail.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,9 @@ async def unsubscribe(self, ctx, *, user_or_role: Union[discord.Role, User, str.
@checks.thread_only()
async def nsfw(self, ctx):
"""Flags a Modmail thread as NSFW (not safe for work)."""
if isinstance(ctx.channel, discord.Thread):
await ctx.send("Unable to set NSFW status for Discord threads.")
return
await ctx.channel.edit(nsfw=True)
sent_emoji, _ = await self.bot.retrieve_emoji()
await self.bot.add_reaction(ctx.message, sent_emoji)
Expand All @@ -687,6 +690,9 @@ async def nsfw(self, ctx):
@checks.thread_only()
async def sfw(self, ctx):
"""Flags a Modmail thread as SFW (safe for work)."""
if isinstance(ctx.channel, discord.Thread):
await ctx.send("Unable to set NSFW status for Discord threads.")
return
await ctx.channel.edit(nsfw=False)
sent_emoji, _ = await self.bot.retrieve_emoji()
await self.bot.add_reaction(ctx.message, sent_emoji)
Expand Down Expand Up @@ -775,6 +781,9 @@ def format_log_embeds(self, logs, avatar_url):
@commands.cooldown(1, 600, BucketType.channel)
async def title(self, ctx, *, name: str):
"""Sets title for a thread"""
if isinstance(ctx.channel, discord.Thread):
await ctx.send("Unable to set titles for Discord threads.")
return
await ctx.thread.set_title(name)
sent_emoji, _ = await self.bot.retrieve_emoji()
await ctx.message.pin()
Expand Down
119 changes: 77 additions & 42 deletions core/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,13 +150,17 @@ async def from_channel(cls, manager: "ThreadManager", channel: discord.TextChann

async def get_genesis_message(self) -> discord.Message:
if self._genesis_message is None:
async for m in self.channel.history(limit=5, oldest_first=True):
if m.author == self.bot.user:
if m.embeds and m.embeds[0].fields and m.embeds[0].fields[0].name == "Roles":
self._genesis_message = m
self._genesis_message = await self._get_genesis_message(self.channel, self.bot.user)

return self._genesis_message

@staticmethod
async def _get_genesis_message(channel, own_user) -> discord.Message | None:
async for m in channel.history(limit=5, oldest_first=True):
if m.author == own_user:
if m.embeds and m.embeds[0].fields and m.embeds[0].fields[0].name == "Roles":
return m

async def setup(self, *, creator=None, category=None, initial_message=None):
"""Create the thread channel and other io related initialisation tasks"""
self.bot.dispatch("thread_initiate", self, creator, category, initial_message)
Expand Down Expand Up @@ -294,6 +298,11 @@ async def activate_auto_triggers():
activate_auto_triggers(),
send_persistent_notes(),
)
if creator is not None:
# now that the genesis message is sent,
# we can cache things.
creator.cache[self.recipient.id] = self

self.bot.dispatch("thread_ready", self, creator, category, initial_message)

def _format_info_embed(self, user, log_url, log_count, color):
Expand Down Expand Up @@ -434,9 +443,11 @@ async def _close(self, closer, silent=False, delete_channel=True, message=None,
self.channel.id,
{
"open": False,
"title": match_title(self.channel.topic),
"title": match_title(self.channel.topic)
if isinstance(self.channel, discord.TextChannel)
else None,
"closed_at": str(discord.utils.utcnow()),
"nsfw": self.channel.nsfw,
"nsfw": self.channel.nsfw if isinstance(self.channel, discord.TextChannel) else False,
"close_message": message,
"closer": {
"id": str(closer.id),
Expand Down Expand Up @@ -466,7 +477,7 @@ async def _close(self, closer, silent=False, delete_channel=True, message=None,
else:
sneak_peak = "No content"

if self.channel.nsfw:
if isinstance(self.channel, discord.TextChannel) and self.channel.nsfw:
_nsfw = "NSFW-"
else:
_nsfw = ""
Expand Down Expand Up @@ -1230,39 +1241,39 @@ async def _update_users_genesis(self):
await genesis_message.edit(embed=embed)

async def add_users(self, users: typing.List[typing.Union[discord.Member, discord.User]]) -> None:
topic = ""
title, _, _ = parse_channel_topic(self.channel.topic)
if title is not None:
topic += f"Title: {title}\n"

topic += f"User ID: {self._id}"

self._other_recipients += users
self._other_recipients = list(set(self._other_recipients))
if isinstance(self.channel, discord.TextChannel):
topic = ""
title, _, _ = parse_channel_topic(self.channel.topic)
if title is not None:
topic += f"Title: {title}\n"

ids = ",".join(str(i.id) for i in self._other_recipients)
topic += f"User ID: {self._id}"

topic += f"\nOther Recipients: {ids}"
ids = ",".join(str(i.id) for i in self._other_recipients)

await self.channel.edit(topic=topic)
topic += f"\nOther Recipients: {ids}"

await self.channel.edit(topic=topic)
await self._update_users_genesis()

async def remove_users(self, users: typing.List[typing.Union[discord.Member, discord.User]]) -> None:
topic = ""
title, user_id, _ = parse_channel_topic(self.channel.topic)
if title is not None:
topic += f"Title: {title}\n"

topic += f"User ID: {user_id}"

for u in users:
self._other_recipients.remove(u)
if isinstance(self.channel, discord.TextChannel):
topic = ""
title, user_id, _ = parse_channel_topic(self.channel.topic)
if title is not None:
topic += f"Title: {title}\n"

if self._other_recipients:
ids = ",".join(str(i.id) for i in self._other_recipients)
topic += f"\nOther Recipients: {ids}"
topic += f"User ID: {user_id}"

await self.channel.edit(topic=topic)
if self._other_recipients:
ids = ",".join(str(i.id) for i in self._other_recipients)
topic += f"\nOther Recipients: {ids}"

await self.channel.edit(topic=topic)
await self._update_users_genesis()


Expand All @@ -1276,6 +1287,13 @@ def __init__(self, bot):
async def populate_cache(self) -> None:
for channel in self.bot.modmail_guild.text_channels:
await self.find(channel=channel)
for thread in self.bot.modmail_guild.threads:
await self.find(channel=thread)
# handle any threads archived while bot was offline (is this slow? yes. whatever....)
# (maybe this should only iterate until the archived_at timestamp is fine)
if isinstance(self.bot.main_category, discord.TextChannel):
async for thread in self.bot.main_category.archived_threads():
await self.find(channel=thread)

def __len__(self):
return len(self.cache)
Expand All @@ -1290,19 +1308,25 @@ async def find(
self,
*,
recipient: typing.Union[discord.Member, discord.User] = None,
channel: discord.TextChannel = None,
channel: discord.TextChannel | discord.Thread = None,
recipient_id: int = None,
) -> typing.Optional[Thread]:
"""Finds a thread from cache or from discord channel topics."""
if recipient is None and channel is not None and isinstance(channel, discord.TextChannel):
if (
recipient is None
and channel is not None
and isinstance(channel, (discord.TextChannel, discord.Thread))
):
# check cache *before* potentially awaiting
user_id, cache_thread = next(
((k, v) for k, v in self.cache.items() if v.channel == channel), (-1, None)
)

thread = await self._find_from_channel(channel)
if thread is None:
user_id, thread = next(
((k, v) for k, v in self.cache.items() if v.channel == channel), (-1, None)
)
if thread is not None:
logger.debug("Found thread with tempered ID.")
await channel.edit(topic=f"User ID: {user_id}")
if thread is None and cache_thread is not None:
logger.debug("Found thread with tampered ID.")
await channel.edit(topic=f"User ID: {user_id}")
thread = cache_thread
return thread

if recipient:
Expand Down Expand Up @@ -1357,10 +1381,23 @@ async def _find_from_channel(self, channel):
extracts user_id from that.
"""

if not channel.topic:
return None
if isinstance(channel, discord.Thread) or not channel.topic:
# actually check for genesis embed :)
msg = await Thread._get_genesis_message(channel, self.bot.user)
if not msg:
return None

_, user_id, other_ids = parse_channel_topic(channel.topic)
embed = msg.embeds[0]
user_id = int((embed.footer.text or "-1").removeprefix("User ID: ").split(" ", 1)[0])
other_ids = []
for field in embed.fields:
if field.name == "Other Recipients" and field.value:
other_ids = map(
lambda mention: int(mention.removeprefix("<@").removeprefix("!").removesuffix(">")),
field.value.split(" "),
)
else:
_, user_id, other_ids = parse_channel_topic(channel.topic)

if user_id == -1:
return None
Expand Down Expand Up @@ -1419,8 +1456,6 @@ async def create(

thread = Thread(self, recipient)

self.cache[recipient.id] = thread

if (message or not manual_trigger) and self.bot.config["confirm_thread_creation"]:
if not manual_trigger:
destination = recipient
Expand Down
18 changes: 11 additions & 7 deletions core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,13 +457,17 @@ async def create_thread_channel(bot, recipient, category, overwrites, *, name=No
errors_raised = errors_raised or []

try:
channel = await bot.modmail_guild.create_text_channel(
name=name,
category=category,
overwrites=overwrites,
topic=f"User ID: {recipient.id}",
reason="Creating a thread channel.",
)
if isinstance(category, discord.TextChannel):
# we ignore `overwrites`... maybe make private threads so it's similar?
channel = await category.create_thread(name=name, reason="Creating a thread channel.", type=discord.ChannelType.public_thread)
else:
channel = await bot.modmail_guild.create_text_channel(
name=name,
category=category,
overwrites=overwrites,
topic=f"User ID: {recipient.id}",
reason="Creating a thread channel.",
)
except discord.HTTPException as e:
if (e.text, (category, name)) in errors_raised:
# Just raise the error to prevent infinite recursion after retrying
Expand Down