diff --git a/bot/core/manage_mandatory_membership.py b/bot/core/manage_mandatory_membership.py index f921ab5..ac325b5 100644 --- a/bot/core/manage_mandatory_membership.py +++ b/bot/core/manage_mandatory_membership.py @@ -1,47 +1,126 @@ import json import logging -import mysql.connector -from redis import Redis from datetime import datetime +from typing import List, Optional, Dict +import mysql.connector +from redis import Redis +from bot.data.config import ADMIN from bot.db.database import Database class ManageMandatoryMembership: + """Manage list of mandatory channels with Redis and MySQL.""" + _REDIS_HASH = "mandatory_membership" - def __init__(self, db: Database, redis_client: Redis, root_logger: logging.Logger): + def __init__(self, db: Database, redis_client: Redis, root_logger: logging.Logger) -> None: self.db = db self.redis = redis_client self.log = root_logger - self._create_table() + self._ensure_table() - def channels(self): - pass + # ------------------------------------------------------------------ + def channels(self) -> List[int]: + """Return list of active channel IDs.""" + try: + redis_key = f"{self._REDIS_HASH}:all" + cached = self.redis.get(redis_key) + if cached: + return json.loads(cached) - def update(self, channel_id: int, initiator_user_id: int, is_active: bool): + self.db.cursor.execute( + "SELECT channel_id FROM channels WHERE is_active = TRUE AND deleted_at IS NULL" + ) + rows = self.db.cursor.fetchall() or [] + channel_ids = [row["channel_id"] for row in rows] + self.redis.set(redis_key, json.dumps(channel_ids)) + return channel_ids + except mysql.connector.Error as err: + self.log.error(f"MySQL error in channels(): {err}") + self.db.reconnect() + return [] + except Exception as err: + self.log.error(f"Error in channels(): {err}") + return [] - current_time = datetime.now() + # ------------------------------------------------------------------ + def update(self, channel_id: int, is_active: bool, updater_user_id: int) -> None: + """Insert or update a channel record. - redis_key = f"{self.ns}:channels" + The channel is created if it does not exist. Existing records can only be + updated by their creator or the super admin defined in :data:`ADMIN`. + All changes are synchronized between MySQL and Redis. + """ + try: + now = datetime.now() + redis_key = f"{self._REDIS_HASH}:{channel_id}" + cached = self.redis.get(redis_key) + channel: Optional[Dict] = json.loads(cached) if cached else None - redis_data = self.redis.get(redis_key) - - + if channel is None: + self.db.cursor.execute( + "SELECT * FROM channels WHERE channel_id=%s AND deleted_at IS NULL", + (channel_id,), + ) + channel = self.db.cursor.fetchone() + if channel: + self.redis.set(redis_key, json.dumps(channel)) - cache_data = { - "channel_id": channel_id, - "is_active": True, - "initiator_user_id": initiator_user_id, - "created_at": current_time.isoformat() - } - self.redis.set(redis_key, json.dumps(cache_data)) + if channel: + initiator_id = channel.get("initiator_user_id") if isinstance(channel, dict) else channel["initiator_user_id"] + if updater_user_id != ADMIN and updater_user_id != initiator_id: + self.log.warning( + "User %s is not allowed to update channel %s", updater_user_id, channel_id + ) + return + sql = ( + "UPDATE channels SET is_active=%s, updater_user_id=%s, updated_at=%s WHERE channel_id=%s" + ) + self.db.cursor.execute(sql, (is_active, updater_user_id, now, channel_id)) + self.db.connection.commit() + channel = dict(channel) + channel.update( + { + "is_active": is_active, + "updater_user_id": updater_user_id, + "updated_at": now.isoformat(), + } + ) + else: + sql = ( + "INSERT INTO channels (channel_id, initiator_user_id, updater_user_id, " + "is_active, created_at, updated_at) VALUES (%s, %s, %s, %s, %s, %s)" + ) + self.db.cursor.execute( + sql, + (channel_id, updater_user_id, updater_user_id, is_active, now, now), + ) + self.db.connection.commit() + channel = { + "channel_id": channel_id, + "initiator_user_id": updater_user_id, + "updater_user_id": updater_user_id, + "is_active": is_active, + "created_at": now.isoformat(), + "updated_at": now.isoformat(), + "deleted_at": None, + } + self.redis.set(redis_key, json.dumps(channel)) + self.redis.delete(f"{self._REDIS_HASH}:all") + except mysql.connector.Error as err: + self.log.error(f"MySQL error in update(): {err}") + self.db.reconnect() + except Exception as err: + self.log.error(f"Error in update(): {err}") - def _create_table(self): + # ------------------------------------------------------------------ + def _ensure_table(self) -> None: + """Ensure the ``channels`` table exists.""" try: sql = """ CREATE TABLE IF NOT EXISTS `channels` ( @@ -49,7 +128,7 @@ def _create_table(self): `channel_id` BIGINT, `initiator_user_id` BIGINT, `updater_user_id` BIGINT, - `is_active`BOOLEAN DEFAULT TRUE, + `is_active` BOOLEAN DEFAULT TRUE, `created_at` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, `updated_at` TIMESTAMP NULL DEFAULT NULL ON UPDATE CURRENT_TIMESTAMP, `deleted_at` TIMESTAMP NULL DEFAULT NULL @@ -59,6 +138,7 @@ def _create_table(self): self.db.connection.commit() except mysql.connector.Error as err: self.log.error(err) - self.db() + self.db.reconnect() except Exception as err: self.log.error(err) + diff --git a/bot/loader.py b/bot/loader.py index dd2e645..a7de697 100755 --- a/bot/loader.py +++ b/bot/loader.py @@ -11,6 +11,7 @@ from bot.core.feature_manager import FeatureManager from bot.core.settings_manager import SettingsManager from bot.core.admin_rights_manager import AdminsManager +from bot.core.manage_mandatory_membership import ManageMandatoryMembership @@ -45,6 +46,7 @@ FM = FeatureManager(db=db, root_logger=root_logger, redis_client=redis) AM = AdminsManager(db=db, redis_client=redis, root_logger=root_logger) SM = SettingsManager(db=db, redis_client=redis, root_logger=root_logger) +MMM = ManageMandatoryMembership(db=db, redis_client=redis, root_logger=root_logger) translator = Translator(db=db, FM=FM, root_logger=root_logger, redis_client=redis) diff --git a/tests/moduls/test_manage_mandatory_membership.py b/tests/moduls/test_manage_mandatory_membership.py new file mode 100644 index 0000000..0a1af71 --- /dev/null +++ b/tests/moduls/test_manage_mandatory_membership.py @@ -0,0 +1,26 @@ +import pytest + +from bot.loader import MMM, db, redis +from bot.data.config import ADMIN + +TEST_CHANNEL = 999999999 + +@pytest.fixture(autouse=True) +def cleanup(): + redis.delete(f"{MMM._REDIS_HASH}:{TEST_CHANNEL}") + redis.delete(f"{MMM._REDIS_HASH}:all") + db.cursor.execute("DELETE FROM channels WHERE channel_id=%s", (TEST_CHANNEL,)) + db.connection.commit() + yield + redis.delete(f"{MMM._REDIS_HASH}:{TEST_CHANNEL}") + redis.delete(f"{MMM._REDIS_HASH}:all") + db.cursor.execute("DELETE FROM channels WHERE channel_id=%s", (TEST_CHANNEL,)) + db.connection.commit() + + +def test_update_and_channels(): + MMM.update(TEST_CHANNEL, True, ADMIN) + assert redis.get(f"{MMM._REDIS_HASH}:{TEST_CHANNEL}") is not None + channels = MMM.channels() + assert TEST_CHANNEL in channels +