diff --git a/src/kernelbot/api/main.py b/src/kernelbot/api/main.py index 071a8219..808f5edf 100644 --- a/src/kernelbot/api/main.py +++ b/src/kernelbot/api/main.py @@ -426,6 +426,32 @@ async def _stream_submission_response( pass +@app.post("/admin/ban/{user_id}") +async def admin_ban_user( + user_id: str, + _: Annotated[None, Depends(require_admin)], + db_context=Depends(get_db), +) -> dict: + with db_context as db: + found = db.ban_user(user_id) + if not found: + raise HTTPException(status_code=404, detail=f"User {user_id} not found") + return {"status": "ok", "user_id": user_id, "banned": True} + + +@app.delete("/admin/ban/{user_id}") +async def admin_unban_user( + user_id: str, + _: Annotated[None, Depends(require_admin)], + db_context=Depends(get_db), +) -> dict: + with db_context as db: + found = db.unban_user(user_id) + if not found: + raise HTTPException(status_code=404, detail=f"User {user_id} not found") + return {"status": "ok", "user_id": user_id, "banned": False} + + @app.post("/{leaderboard_name}/{gpu_type}/{submission_mode}") async def run_submission( # noqa: C901 leaderboard_name: str, diff --git a/src/kernelbot/cogs/admin_cog.py b/src/kernelbot/cogs/admin_cog.py index adce7616..8b21747a 100644 --- a/src/kernelbot/cogs/admin_cog.py +++ b/src/kernelbot/cogs/admin_cog.py @@ -123,6 +123,14 @@ def __init__(self, bot: "ClusterBot"): name="set-forum-ids", description="Sets forum IDs" )(self.set_forum_ids) + self.ban_user = bot.admin_group.command( + name="ban", description="Ban a user from making submissions" + )(self.ban_user) + + self.unban_user = bot.admin_group.command( + name="unban", description="Unban a user" + )(self.unban_user) + self.export_to_hf = bot.admin_group.command( name="export-hf", description="Export competition data to Hugging Face dataset" )(self.export_to_hf) @@ -154,6 +162,44 @@ async def is_creator_check( return True return False + @discord.app_commands.describe(user_id="Discord user ID to ban") + @with_error_handling + async def ban_user(self, interaction: discord.Interaction, user_id: str): + if not await self.admin_check(interaction): + await send_discord_message( + interaction, "You need to have Admin permissions to run this command", ephemeral=True + ) + return + + with self.bot.leaderboard_db as db: + if db.ban_user(user_id): + await send_discord_message( + interaction, f"User `{user_id}` has been banned.", ephemeral=True + ) + else: + await send_discord_message( + interaction, f"User `{user_id}` not found.", ephemeral=True + ) + + @discord.app_commands.describe(user_id="Discord user ID to unban") + @with_error_handling + async def unban_user(self, interaction: discord.Interaction, user_id: str): + if not await self.admin_check(interaction): + await send_discord_message( + interaction, "You need to have Admin permissions to run this command", ephemeral=True + ) + return + + with self.bot.leaderboard_db as db: + if db.unban_user(user_id): + await send_discord_message( + interaction, f"User `{user_id}` has been unbanned.", ephemeral=True + ) + else: + await send_discord_message( + interaction, f"User `{user_id}` not found.", ephemeral=True + ) + @discord.app_commands.describe( directory="Directory of the kernel definition. Also used as the leaderboard's name", gpu="The GPU to submit to. Leave empty for interactive selection/multiple GPUs", diff --git a/src/libkernelbot/leaderboard_db.py b/src/libkernelbot/leaderboard_db.py index c59bc271..650599af 100644 --- a/src/libkernelbot/leaderboard_db.py +++ b/src/libkernelbot/leaderboard_db.py @@ -1445,6 +1445,59 @@ def validate_cli_id(self, cli_id: str) -> Optional[dict[str, str]]: raise KernelBotError("Error validating CLI ID") from e + def ban_user(self, user_id: str) -> bool: + """Ban a user by their ID. Returns True if the user was found and banned.""" + try: + self.cursor.execute( + """ + UPDATE leaderboard.user_info + SET is_banned = TRUE + WHERE id = %s + """, + (str(user_id),), + ) + self.connection.commit() + return self.cursor.rowcount > 0 + except psycopg2.Error as e: + self.connection.rollback() + logger.exception("Error banning user %s", user_id, exc_info=e) + raise KernelBotError("Error banning user") from e + + def unban_user(self, user_id: str) -> bool: + """Unban a user by their ID. Returns True if the user was found and unbanned.""" + try: + self.cursor.execute( + """ + UPDATE leaderboard.user_info + SET is_banned = FALSE + WHERE id = %s + """, + (str(user_id),), + ) + self.connection.commit() + return self.cursor.rowcount > 0 + except psycopg2.Error as e: + self.connection.rollback() + logger.exception("Error unbanning user %s", user_id, exc_info=e) + raise KernelBotError("Error unbanning user") from e + + def is_user_banned(self, user_id: str) -> bool: + """Check if a user is banned.""" + try: + self.cursor.execute( + """ + SELECT is_banned FROM leaderboard.user_info + WHERE id = %s + """, + (str(user_id),), + ) + row = self.cursor.fetchone() + return row[0] if row else False + except psycopg2.Error as e: + self.connection.rollback() + logger.exception("Error checking ban status for user %s", user_id, exc_info=e) + raise KernelBotError("Error checking ban status") from e + def set_rate_limit(self, leaderboard_name: str, mode_category: str, max_per_hour: int) -> RateLimitItem: try: self.cursor.execute( diff --git a/src/libkernelbot/submission.py b/src/libkernelbot/submission.py index 60ecf4e6..69b83b24 100644 --- a/src/libkernelbot/submission.py +++ b/src/libkernelbot/submission.py @@ -49,6 +49,10 @@ def prepare_submission( # noqa: C901 "The bot is currently not accepting any new submissions, please try again later." ) + with backend.db as db: + if db.is_user_banned(str(req.user_id)): + raise KernelBotError("You are banned from making submissions.") + if profanity.contains_profanity(req.file_name): raise KernelBotError("Please provide a non-rude filename") diff --git a/src/migrations/20260318_01_ban-user.py b/src/migrations/20260318_01_ban-user.py new file mode 100644 index 00000000..c858a038 --- /dev/null +++ b/src/migrations/20260318_01_ban-user.py @@ -0,0 +1,22 @@ +""" +add_is_banned_to_user_info +""" + +from yoyo import step + +__depends__ = {'20260317_01_rate-limits'} + +steps = [ + step( + # forward + """ + ALTER TABLE leaderboard.user_info + ADD COLUMN is_banned BOOLEAN NOT NULL DEFAULT FALSE + """, + # backward + """ + ALTER TABLE leaderboard.user_info + DROP COLUMN is_banned; + """ + ) +] diff --git a/tests/test_submission.py b/tests/test_submission.py index e22fcb8e..f2bced05 100644 --- a/tests/test_submission.py +++ b/tests/test_submission.py @@ -31,6 +31,7 @@ def mock_backend(): "name": "test_board", } db_context.get_leaderboard_gpu_types.return_value = ["A100", "V100"] + db_context.is_user_banned.return_value = False return backend