diff --git a/db_init.py b/db_init.py index 01af2ba1..bfdd6aa9 100644 --- a/db_init.py +++ b/db_init.py @@ -8,7 +8,7 @@ import sys import requests -import validators +import re from datetime import date, datetime @@ -34,6 +34,8 @@ Codes, ) +url_regex = re.compile(r"^(https?|ftp)://[^\s/$.?#].[^\s]*$") + def fetch_json_data(json_url): response = requests.get(json_url) @@ -63,7 +65,10 @@ def insert_courses_from_json(session, courses_data): if existing_course.name != course_name: existing_course.name = course_name else: - new_courses.append(Courses(code=course_code, name=course_name)) + new_course = Courses() + new_course.code = course_code + new_course.name = course_name + new_courses.append(new_course) if new_courses: session.add_all(new_courses) @@ -91,9 +96,10 @@ def insert_schools_and_departments(session, schools_data): if school.description != school_description: school.description = school_description else: - new_schools.append( - RPISchools(name=school_name, description=school_description) - ) + new_school = RPISchools() + new_school.name = school_name + new_school.description = school_description + new_schools.append(new_school) for department_data in school_data.get("depts", []): department_id = department_data.get("code") @@ -110,27 +116,24 @@ def insert_schools_and_departments(session, schools_data): if department.school_id != school_name: department.school_id = school_name else: - new_depts.append( - RPIDepartments( - id=department_id, - name=department_name, - description=department_description, - school_id=school_name, - ) - ) + new_department = RPIDepartments() + new_department.id = department_id + new_department.name = department_name + new_department.description = department_description + new_department.school_id = school_name + new_depts.append(new_department) if new_schools or new_depts: session.add_all(new_schools + new_depts) session.commit() -def main(): - app = create_app() - +def main() -> None: if len(sys.argv) < 2: sys.exit("No argument or existing argument found") if sys.argv[1] == "start": + app = create_app() with app.app_context(): if db.inspect(db.engine).get_table_names(): print("Tables already exist.") @@ -141,6 +144,7 @@ def main(): db.create_all() elif sys.argv[1] == "clear": + app = create_app() with app.app_context(): db.drop_all() @@ -151,9 +155,10 @@ def main(): j_url = sys.argv[2] # Validate that j_url is a valid URL - if not validators.url(j_url): + if not url_regex.match(j_url): sys.exit("Error: Invalid URL provided.") + app = create_app() with app.app_context(): db.create_all() @@ -172,9 +177,10 @@ def main(): j_url = sys.argv[2] # Validate that j_url is a valid URL - if not validators.url(j_url): + if not url_regex.match(j_url): sys.exit("Error: Invalid URL provided.") + app = create_app() with app.app_context(): db.create_all() @@ -187,6 +193,7 @@ def main(): db.session.close() elif sys.argv[1] == "create": + app = create_app() with app.app_context(): db.create_all()