Skip to content
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import at.hannibal2.skyhanni.discord.command.PullRequestCommand
import at.hannibal2.skyhanni.discord.command.ServerCommands
import at.hannibal2.skyhanni.discord.command.TagCommands
import at.hannibal2.skyhanni.discord.command.TagUndo
import net.dv8tion.jda.api.entities.channel.ChannelType
import net.dv8tion.jda.api.events.message.MessageReceivedEvent
import org.reflections.Reflections
import java.lang.reflect.Modifier
Expand Down Expand Up @@ -62,7 +63,7 @@ object CommandListener {
return
}

if (!command.userCommand) {
if (!command.userCommand && channelType != ChannelType.GUILD_PUBLIC_THREAD) {
if (!hasAdminPermissions()) {
reply("No permissions $PLEADING_FACE")
return
Expand All @@ -74,7 +75,7 @@ object CommandListener {
}
}

// allows to use `!<command> -help` instaed of `!help -<command>`
// allows to use `!<command> -help` instead of `!help -<command>`
if (args.size == 1 && args.first() == "-help") {
with(HelpCommand) {
sendUsageReply(literal)
Expand Down
51 changes: 51 additions & 0 deletions src/main/kotlin/at/hannibal2/skyhanni/discord/Database.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ import java.sql.Connection
import java.sql.DriverManager

data class Tag(val keyword: String, var response: String, var uses: Int)
data class Link(val channel: String, val pr: Int)

object Database {
private val connection: Connection = DriverManager.getConnection("jdbc:sqlite:bot.db")
private val tags = mutableMapOf<String, Tag>()
private val links = mutableMapOf<String, Link>()

init {
val statement = connection.createStatement()
Expand All @@ -19,7 +21,16 @@ object Database {
append("response TEXT, ")
append("count INTEGER DEFAULT 0)")
})

statement.execute(buildString {
append("CREATE TABLE IF NOT EXISTS linkedposts (")
append("id INTEGER PRIMARY KEY AUTOINCREMENT, ")
append("channelid STRING UNIQUE, ")
append("pullrequestid INTEGER UNIQUE)")
})

loadTagCache()
loadLinkCache()
}

private fun loadTagCache() {
Expand All @@ -34,6 +45,17 @@ object Database {
resultSet.close()
}

private fun loadLinkCache() {
val statement = connection.prepareStatement("SELECT channelid, pullrequestid FROM linkedposts")
val resultSet = statement.executeQuery()
while (resultSet.next()) {
val channel = resultSet.getString("channelid")
val pr = resultSet.getInt("pullrequestid")
links[channel] = Link(channel, pr)
}
resultSet.close()
}

private fun ensureCountColumnExists() {
val statement = connection.prepareStatement("PRAGMA table_info(keywords)")
val resultSet = statement.executeQuery()
Expand Down Expand Up @@ -66,6 +88,19 @@ object Database {
return updated
}

fun addLink(channel: String, pr: Int): Boolean {
val statement = connection.prepareStatement(
"INSERT OR REPLACE INTO linkedposts (channelid, pullrequestid) VALUES (?, ?)"
)
statement.setString(1, channel)
statement.setInt(2, pr)
val updated = statement.executeUpdate() > 0
if (updated) {
links[channel] = Link(channel, pr)
}
return updated
}

fun getResponse(keyword: String, increment: Boolean = false): String? {
val key = keyword.lowercase()
val kObj = tags[key] ?: return null
Expand All @@ -81,6 +116,10 @@ object Database {
return kObj.response
}

fun getLink(prNumber: Int): Link? = links.values.find { it.pr == prNumber }

fun getPullrequest(channel: String): Int? = links[channel]?.pr

fun deleteTag(keyword: String): Boolean {
val key = keyword.lowercase()
val statement = connection.prepareStatement("DELETE FROM keywords WHERE keyword = ?")
Expand All @@ -90,11 +129,23 @@ object Database {
return updated
}

fun deleteLink(channel: String): Boolean {
val statement = connection.prepareStatement("DELETE FROM linkedposts WHERE channelid = ?")
statement.setString(1, channel)
val updated = statement.executeUpdate() > 0
if (updated) links.remove(channel)
return updated
}

fun listTags(): List<Tag> = tags.values.toList()

fun listLinks(): List<Link> = links.values.toList()

fun getTagCount(keyword: String): Int? {
return tags[keyword.lowercase()]?.uses
}

fun containsKeyword(keyword: String): Boolean = tags.containsKey(keyword.lowercase())

fun isLinked(channel: String): Boolean = links.containsKey(channel)
}
2 changes: 2 additions & 0 deletions src/main/kotlin/at/hannibal2/skyhanni/discord/DiscordBot.kt
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ const val PLEADING_FACE = "🥺"
const val BIG_X = "❌"
const val CHECK_MARK = "✅"
const val PING_HANNIBAL = "<@239858538959077376>"
const val OPEN_PR_TAG = "1350893914768277624"

fun main() {
val bot = startBot()
Expand Down Expand Up @@ -75,6 +76,7 @@ private fun startBot(): DiscordBot {
val messageListener = object : ListenerAdapter() {
override fun onMessageReceived(event: MessageReceivedEvent) {
CommandListener.onMessage(bot, event)
LinkListener.onMessage(bot, event)
}
}
jda.addEventListener(messageListener)
Expand Down
38 changes: 38 additions & 0 deletions src/main/kotlin/at/hannibal2/skyhanni/discord/LinkListener.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package at.hannibal2.skyhanni.discord

import at.hannibal2.skyhanni.discord.Utils.messageSend
import net.dv8tion.jda.api.events.message.MessageReceivedEvent

object LinkListener {

private const val GITHUB_WEBHOOK_ID = "1347997547368550461"
private val githubPattern =
"\\[ILike2WatchMemes/DiscordBot] (New comment on pull request|Pull request review submitted:) #(?<pr>\\d+):? .+".toPattern()

fun onMessage(bot: DiscordBot, event: MessageReceivedEvent) {
event.onMessage(bot)
}

private fun MessageReceivedEvent.onMessage(bot: DiscordBot) {
if (this.author.id == GITHUB_WEBHOOK_ID) {
val embed = this.message.embeds[0]

val title = embed.title ?: return

val prNumber = getPr(title)?.toInt() ?: return
val link = Database.getLink(prNumber) ?: return

val guild = bot.jda.getGuildById(BOT.config.allowedServerId) ?: return
val channel = guild.getThreadChannelById(link.channel) ?: return


channel.messageSend(this.message.embeds[0])
}
}

private fun getPr(title: String): String? {
val matcher = githubPattern.matcher(title)
if (!matcher.matches()) return null
return matcher.group("pr")
}
}
13 changes: 12 additions & 1 deletion src/main/kotlin/at/hannibal2/skyhanni/discord/Utils.kt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import net.dv8tion.jda.api.entities.channel.middleman.MessageChannel
import net.dv8tion.jda.api.entities.channel.unions.MessageChannelUnion
import net.dv8tion.jda.api.events.message.MessageReceivedEvent
import net.dv8tion.jda.api.utils.FileUpload
import org.slf4j.LoggerFactory
import java.awt.Color
import java.awt.Toolkit.getDefaultToolkit
import java.io.File
Expand All @@ -34,6 +33,10 @@ object Utils {
logAction("Error: $text")
}

fun MessageReceivedEvent.userSuccess(text: String) {
message.messageReply("✅ $text")
}

fun MessageReceivedEvent.reply(embed: MessageEmbed) {
message.messageReply(embed)
}
Expand Down Expand Up @@ -69,6 +72,14 @@ object Utils {
}
}

fun MessageChannel.messageSend(embed: MessageEmbed, instantly: Boolean = false) {
if (instantly) {
sendMessageEmbeds(embed).complete()
} else {
sendMessageEmbeds(embed).queue()
}
}

fun Message.replyWithConsumer(text: String, consumer: (MessageReceivedEvent) -> Unit) {
BotMessageHandler.log(text, consumer)
messageReply(text)
Expand Down
106 changes: 106 additions & 0 deletions src/main/kotlin/at/hannibal2/skyhanni/discord/command/LinkCommands.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package at.hannibal2.skyhanni.discord.command

import at.hannibal2.skyhanni.discord.Database
import at.hannibal2.skyhanni.discord.OPEN_PR_TAG
import at.hannibal2.skyhanni.discord.Option
import at.hannibal2.skyhanni.discord.PLEADING_FACE
import at.hannibal2.skyhanni.discord.Utils.logAction
import at.hannibal2.skyhanni.discord.Utils.reply
import at.hannibal2.skyhanni.discord.Utils.userError
import at.hannibal2.skyhanni.discord.Utils.userSuccess
import at.hannibal2.skyhanni.discord.command.LinkCommand.setTags
import at.hannibal2.skyhanni.discord.command.LinkCommand.setTitle
import net.dv8tion.jda.api.entities.channel.ChannelType
import net.dv8tion.jda.api.entities.channel.forums.ForumTag
import net.dv8tion.jda.api.events.message.MessageReceivedEvent
import net.dv8tion.jda.api.managers.channel.concrete.ThreadChannelManager

object LinkCommand : BaseCommand() {
override val name = "link"

override val description = "Link a forum post to a pull request."
override val options: List<Option> = listOf(
Option("number", "Number of the pull request you want the post to be linked to.")
)

override fun MessageReceivedEvent.execute(args: List<String>) {
if (args.size != 1) return wrongUsage("<number>")
val first = args.first()
val prNumber = first.toIntOrNull() ?: run {
userError("Unknown number $PLEADING_FACE ($first})")
return
}
if (prNumber < 1) {
userError("PR number needs to be positive $PLEADING_FACE")
return
}

if (!isFromType(ChannelType.GUILD_PUBLIC_THREAD)) {
userError("Wrong channel $PLEADING_FACE")
return
}

val post = channel.asThreadChannel()
val manager = post.manager

if (Database.isLinked(post.id)) {
reply("Post already linked to ${Database.getPullrequest(channel.id)} $PLEADING_FACE")
return
}

Database.addLink(post.id, prNumber)
logAction("${author.name} linked pr $prNumber")

if (!post.name.contains("(PR #")) manager.setTitle("${post.name} (PR #$prNumber)")

val tags = post.appliedTags
if (tags.none { it.id == OPEN_PR_TAG }) {
val tag = post.parentChannel.asForumChannel().getAvailableTagById(OPEN_PR_TAG) ?: return
manager.setTags(tags + tag)
}

userSuccess("Successfully linked PR $prNumber to this post.")
}

fun ThreadChannelManager.setTags(tags: List<ForumTag>) {
setAppliedTags(tags).queue()
}

fun ThreadChannelManager.setTitle(name: String) {
setName(name).queue()
}
}

object UnlinkCommand : BaseCommand() {
override val name = "unlink"

override val description = "Unlink a forum post from a pull request."

override fun MessageReceivedEvent.execute(args: List<String>) {
if (!isFromType(ChannelType.GUILD_PUBLIC_THREAD)) {
userError("Wrong channel $PLEADING_FACE")
return
}

if (!Database.isLinked(channel.id)) {
userError("Post isn't linked to any pull request $PLEADING_FACE")
return
}

val post = channel.asThreadChannel()
val manager = post.manager

Database.deleteLink(post.id)
logAction("${author.name} unlinked the pull request")

if (post.name.contains("(PR #")) manager.setTitle(post.name.split("(PR #")[0])

val tags = post.appliedTags
if (tags.any { it.id == OPEN_PR_TAG }) {
val tag = post.parentChannel.asForumChannel().getAvailableTagById(OPEN_PR_TAG) ?: return
manager.setTags(tags.filter { it != tag })
}

userSuccess("Successfully unlinked this post.")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ object PullRequestCommand : BaseCommand() {
append("\n")
}

var inBeta: Boolean = false
var inBeta = false

val labels = pr.labels.map { it.name }.toSet()

Expand Down Expand Up @@ -226,7 +226,7 @@ object PullRequestCommand : BaseCommand() {
)

private fun appendLabelCategory(labelType: String, labels: Set<String>, stringBuilder: StringBuilder, suffix: String = ""): StringBuilder {
val labelsWithType = labels.intersect(labelTypes[labelType] ?: setOf())
val labelsWithType = labels.intersect((labelTypes[labelType] ?: setOf()).toSet())
if (labelsWithType.isEmpty()) return stringBuilder.append(if (suffix.isNotEmpty()) "> $labelType: $suffix\n" else "")
return stringBuilder.append("> $labelType: `${labelsWithType.joinToString("` `")}`$suffix\n")
}
Expand Down Expand Up @@ -258,7 +258,7 @@ object PullRequestCommand : BaseCommand() {
return
}
val prNumber = args[1].toLongOrNull() ?: run {
reply("unknwon number $PLEADING_FACE (${args[1]})")
reply("unknown number $PLEADING_FACE (${args[1]})")
return
}

Expand Down Expand Up @@ -288,7 +288,7 @@ object PullRequestCommand : BaseCommand() {
val (_, downloadTime) = timeExecution {
github.downloadArtifact(artifactId, fileRaw)
}
reply("artifact downnloaded in ${downloadTime.format()}")
reply("artifact downloaded in ${downloadTime.format()}")

Utils.unzipFile(fileRaw, fileUnzipped)
fileRaw.delete()
Expand Down Expand Up @@ -324,4 +324,4 @@ object PullRequestCommand : BaseCommand() {
return true
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ object TagCommands {

fun handleTag(event: MessageReceivedEvent): Boolean {
val message = event.message
var keyword = message.contentRaw.substring(1)
var keyword = message.contentRaw.substring(1).trimStart()
var deleting = false
if (keyword.endsWith(" -d")) {
keyword = keyword.dropLast(3)
Expand Down