Skip to content

Commit

Permalink
Using csv files to store the data. Fixes #5
Browse files Browse the repository at this point in the history
  • Loading branch information
Era Dorta committed Apr 3, 2022
1 parent 7e04213 commit 3c986dc
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 35 deletions.
24 changes: 6 additions & 18 deletions signalblast/bot_answers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@


class BotAnswers():
subscribers_data_path = get_code_data_path() / 'subscribers.txt'
banned_users_data_path = get_code_data_path() / 'banned_users.txt'
subscribers_data_path = get_code_data_path() / 'subscribers.csv'
banned_users_data_path = get_code_data_path() / 'banned_users.csv'

subscribers_phone_data_path = get_code_data_path() / 'subscribers_phone.txt'
banned_users_phone_data_path = get_code_data_path() / 'banned_users_phone.txt'

def __init__(self) -> None:
self.subscribers: Users = None
Expand All @@ -30,8 +28,6 @@ def __init__(self) -> None:
self.must_subscribe_message: str = None
self.logger: Logger = None
self.expiration_time = None
self.subscribers_phone: Users = None
self.banned_users_phone: Users = None
self.ping_job: Job = None

@classmethod
Expand All @@ -40,9 +36,6 @@ async def create(cls, logger: Logger, admin_pass: Optional[str], expiration_time
self.subscribers = await Users.load_from_file(self.subscribers_data_path)
self.banned_users = await Users.load_from_file(self.banned_users_data_path)

self.subscribers_phone = await Users.load_from_file(self.subscribers_phone_data_path)
self.banned_users_phone = await Users.load_from_file(self.banned_users_phone_data_path)

self.admin = await Admin.load_from_file(admin_pass)
self.message_handler = MessageHandler()

Expand Down Expand Up @@ -80,8 +73,7 @@ async def subscribe(self, ctx: ChatContext) -> None:
self.logger.info(f"{subscriber_uuid} was not allowed to subscribe")
return

await self.subscribers.add(subscriber_uuid)
await self.subscribers_phone.add(ctx.message.source.number)
await self.subscribers.add(subscriber_uuid, ctx.message.source.number)
await self.reply_with_warn_on_failure(ctx, "Subscription successful!")
if self.expiration_time is not None:
await ctx.bot.set_expiration(subscriber_uuid, self.expiration_time)
Expand All @@ -105,7 +97,6 @@ async def unsubscribe(self, ctx: ChatContext) -> None:
return

await self.subscribers.remove(subscriber_uuid)
await self.subscribers_phone.remove(ctx.message.source.number)
await self.reply_with_warn_on_failure(ctx, "Successfully unsubscribed!")
self.logger.info(f"{subscriber_uuid} unsubscribed")
except Exception as e:
Expand Down Expand Up @@ -158,7 +149,6 @@ async def broadcast(self, ctx: ChatContext) -> None:
else:
self.logger.warning(f"Could not send message to {subscriber}")
await self.subscribers.remove(ctx.message.source.uuid)
await self.subscribers_phone.remove(ctx.message.source.number)

self.message_handler.delete_attachments(attachments)
except Exception as e:
Expand Down Expand Up @@ -331,7 +321,7 @@ async def msg_from_admin(self, ctx: ChatContext) -> None:
else:
confirmation = None
if confirmation != '!force':
warn_message = "User is not in subscribers list, use !reply !force to message them"
warn_message = "User is not in subscribers list, use !reply <uuid> !force to message them"
await self.reply_with_warn_on_failure(ctx, warn_message)
return

Expand All @@ -357,11 +347,10 @@ async def ban_user(self, ctx: ChatContext) -> None:
if not await self.is_user_admin(ctx, AdminCommandStrings.ban_subscriber):
return

user_phonenumber = self.subscribers.get_phone_number(user_id)
if user_id in self.subscribers:
await self.subscribers.remove(user_id)
await self.subscribers_phone.remove(ctx.message.source.number)
await self.banned_users.add(user_id)
await self.banned_users_phone.add(ctx.message.source.number)
await self.banned_users.add(user_id, user_phonenumber)

await ctx.bot.send_message(user_id, 'You have been banned')
await self.reply_with_warn_on_failure(ctx, "Successfully banned user")
Expand All @@ -386,7 +375,6 @@ async def lift_ban_user(self, ctx: ChatContext) -> None:

if user_id in self.banned_users:
await self.banned_users.remove(user_id)
await self.banned_users_phone.remove(ctx.message.source.number)
else:
await self.reply_with_warn_on_failure(ctx, "Could not lift the ban because the user was not banned")
self.logger.info(f"Could not lift the ban of {user_id} because the user was not banned")
Expand Down
41 changes: 24 additions & 17 deletions signalblast/users.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,49 @@
import csv
import os
from typing import Optional


class Users():
_uuid_str = 'uuid'
_phone_number_str = 'phone_number'

def __init__(self, save_path: str) -> None:
self.save_path = save_path
self.data = set()
self.data = dict()

async def add(self, element: str) -> None:
self.data.add(element)
async def add(self, uuid: str, phone_number: str) -> None:
self.data[uuid] = phone_number
await self.save_to_file()

async def remove(self, element: str) -> None:
self.data.remove(element)
async def remove(self, uuid: str) -> None:
del self.data[uuid]
await self.save_to_file()

async def save_to_file(self):
with open(self.save_path, "w") as f:
for i, subscriber in enumerate(self.data):
if i < len(self.data) - 1:
f.write(subscriber + '\n')
else:
f.write(subscriber)
csv_writer = csv.DictWriter(f, fieldnames=[self._uuid_str, self._phone_number_str])
csv_writer.writeheader()
for uuid, phone_number in self.data.items():
csv_writer.writerow({self._uuid_str: uuid, self._phone_number_str: phone_number})

@staticmethod
async def _load_from_file(save_path) -> 'Users':
users = Users(save_path)
with open(save_path, "r") as f:
lines = f.readlines()
for line in lines:
users.data.add(line.rstrip())
csv_reader = csv.DictReader(f)
for line in csv_reader:
users.data[line[Users._uuid_str]] = line[Users._phone_number_str]
return users

def get_phone_number(self, uuid: str) -> Optional[str]:
return self.data.get(uuid)

def __iter__(self):
for user in self.data:
yield user
for uuid in self.data:
yield uuid

def __contains__(self, user: str):
return user in self.data
def __contains__(self, uuid: str):
return uuid in self.data

def __len__(self):
return len(self.data)
Expand Down

0 comments on commit 3c986dc

Please sign in to comment.