diff --git a/.gitignore b/.gitignore index 748c3b62..cb6c53bb 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ env/ *.db .coverage *.vscode +.env \ No newline at end of file diff --git a/config.py b/config.py index 466c554d..b2716f9a 100644 --- a/config.py +++ b/config.py @@ -1,40 +1,51 @@ # Import os -import os +from os import getenv, path from datetime import timedelta -basedir = os.path.abspath(os.path.dirname(__file__)) +from dotenv import load_dotenv + +basedir = path.abspath(path.dirname(__file__)) class Config: + load_dotenv() + # Configuration - SECRET_KEY = os.environ.get("SECRET_KEY", os.urandom(32)) - JWT_SECRET_KEY = os.environ.get("JWT_SECRET_KEY", os.urandom(32)) - JWT_ACCESS_TOKEN_EXPIRES = timedelta(hours=1) + SECRET_KEY = getenv("SECRET_KEY", "main-secret") TESTING = False DEBUG = False PREFERRED_URL_SCHEME = "https" - SAML_CONFIG = os.path.join(basedir, "config/saml/") - FRONTEND_URL = os.environ.get("FRONTEND_URL", "http://localhost:3000") + SAML_CONFIG = path.join(basedir, "config/saml/") + FRONTEND_URL = getenv("FRONTEND_URL", "http://localhost:3000") - SENTRY_DSN = os.environ.get("SENTRY_DSN", "") - SENTRY_TRACES_SAMPLE_RATE = float(os.environ.get("SENTRY_TRACES_SAMPLE_RATE", 1.0)) - SENTRY_PROFILES_SAMPLE_RATE = float( - os.environ.get("SENTRY_PROFILES_SAMPLE_RATE", 1.0) - ) + SENTRY_DSN = getenv("SENTRY_DSN", "") + SENTRY_TRACES_SAMPLE_RATE = float(getenv("SENTRY_TRACES_SAMPLE_RATE", 1.0)) + SENTRY_PROFILES_SAMPLE_RATE = float(getenv("SENTRY_PROFILES_SAMPLE_RATE", 1.0)) - SQLALCHEMY_DATABASE_URI = os.environ.get( + SQLALCHEMY_DATABASE_URI = getenv( "DB", "postgresql+psycopg2://postgres:root@localhost/labconnect" ) - TOKEN_BLACKLIST = set() + JWT_SECRET_KEY = getenv("JWT_SECRET_KEY", "jwt-secret") + JWT_ACCESS_TOKEN_EXPIRES = timedelta(hours=1) + JWT_REFRESH_TOKEN_EXPIRES = timedelta(days=7) + JWT_SESSION_COOKIE = timedelta(hours=1) + JWT_TOKEN_LOCATION = ["cookies"] + JWT_COOKIE_CSRF_PROTECT = True + JWT_CSRF_CHECK_FORM = True + JWT_COOKIE_SECURE = True + JWT_COOKIE_SAMESITE = "Strict" + JWT_ACCESS_COOKIE_NAME = "access_token" + JWT_REFRESH_COOKIE_NAME = "refresh_token" class TestingConfig(Config): TESTING = True DEBUG = True + JWT_COOKIE_SECURE = False class ProductionConfig(Config): diff --git a/db_init.py b/db_init.py index 8998f689..0cb2225c 100644 --- a/db_init.py +++ b/db_init.py @@ -29,6 +29,7 @@ UserDepartments, UserMajors, UserSavedOpportunities, + Codes, ) app = create_app() @@ -115,8 +116,20 @@ lab_manager_rows = ( ("led", "Duy", "Le", "CSCI", "database database database"), - ("turner", "Wes","Turner","CSCI","open source stuff is cool",), - ("kuzmin","Konstantine","Kuzmin","CSCI","java, psoft, etc.",), + ( + "turner", + "Wes", + "Turner", + "CSCI", + "open source stuff is cool", + ), + ( + "kuzmin", + "Konstantine", + "Kuzmin", + "CSCI", + "java, psoft, etc.", + ), ("goldd", "David", "Goldschmidt", "CSCI", "VIM master"), ("rami", "Rami", "Rami", "MTLE", "cubes are cool"), ("holm", "Mark", "Holmes", "MATH", "all about that math"), @@ -399,7 +412,12 @@ db.session.add(row) db.session.commit() - participates_rows = (("cenzar", 1),("cenzar", 2),("test", 3),("test", 4),) + participates_rows = ( + ("cenzar", 1), + ("cenzar", 2), + ("test", 3), + ("test", 4), + ) for r in participates_rows: row = Participates() @@ -409,7 +427,7 @@ db.session.add(row) db.session.commit() - tables = [ + tables = [ ClassYears, Courses, Leads, @@ -425,7 +443,7 @@ UserCourses, UserDepartments, UserMajors, - UserSavedOpportunities + UserSavedOpportunities, ] for table in tables: diff --git a/docs/database_docs/Labconnect_DB.png b/docs/database_docs/Labconnect_DB.png index bcd496e7..83a3259e 100644 Binary files a/docs/database_docs/Labconnect_DB.png and b/docs/database_docs/Labconnect_DB.png differ diff --git a/labconnect/helpers.py b/labconnect/helpers.py index a39a15da..c9d42b4d 100644 --- a/labconnect/helpers.py +++ b/labconnect/helpers.py @@ -37,8 +37,15 @@ class LocationEnum(EnumPython): REMOTE = "Remote" -class OrJSONProvider(JSONProvider): +class LabManagerTypeEnum(EnumPython): + PI = "Principal Investigator" + CO_PI = "Co-Principal Investigator" + LAB_MANAGER = "Lab Manager" + POST_DOC = "Post Doctoral Researcher" + GRAD_STUDENT = "Graduate Student" + +class OrJSONProvider(JSONProvider): @staticmethod def dumps(obj, *, option=None, **kwargs): if option is None: diff --git a/labconnect/main/auth_routes.py b/labconnect/main/auth_routes.py index 3795f374..6092b720 100644 --- a/labconnect/main/auth_routes.py +++ b/labconnect/main/auth_routes.py @@ -2,8 +2,17 @@ from uuid import uuid4 from flask import current_app, make_response, redirect, request, abort -from flask_jwt_extended import create_access_token +from flask_jwt_extended import ( + get_jwt_identity, + create_access_token, + create_refresh_token, + set_access_cookies, + set_refresh_cookies, + unset_jwt_cookies, + jwt_required, +) from onelogin.saml2.auth import OneLogin_Saml2_Auth +from werkzeug.wrappers.response import Response from labconnect import db from labconnect.helpers import prepare_flask_request @@ -35,7 +44,7 @@ def generate_temporary_code(user_email: str, registered: bool) -> str: return code -def validate_code_and_get_user_email(code: str) -> tuple[str | None, bool | None]: +def validate_code_and_get_user_email(code: str) -> tuple[str, bool] | tuple[None, None]: code_data = db.session.execute(db.select(Codes).where(Codes.code == code)).scalar() if not code_data: return None, None @@ -44,20 +53,20 @@ def validate_code_and_get_user_email(code: str) -> tuple[str | None, bool | None expire = code_data.expires_at registered = code_data.registered - if user_email and expire and expire > datetime.now(): - # If found, delete the code to prevent reuse - db.session.delete(code_data) - return user_email, registered - elif expire: - # If the code has expired, delete it - db.session.delete(code_data) + if user_email and expire: + if expire > datetime.now(): + # If found, delete the code to prevent reuse + db.session.delete(code_data) + return user_email, registered + else: + # If the code has expired, delete it + db.session.delete(code_data) return None, None @main_blueprint.get("/login") -def saml_login(): - +def saml_login() -> Response: # In testing skip RPI login purely for local development if current_app.config["TESTING"] and ( current_app.config["FRONTEND_URL"] == "http://localhost:3000" @@ -76,7 +85,7 @@ def saml_login(): @main_blueprint.post("/callback") -def saml_callback(): +def saml_callback() -> Response: # Process SAML response req = prepare_flask_request(request) auth = OneLogin_Saml2_Auth(req, custom_base_path=current_app.config["SAML_CONFIG"]) @@ -100,12 +109,33 @@ def saml_callback(): return redirect(f"{current_app.config['FRONTEND_URL']}/callback/?code={code}") error_reason = auth.get_last_error_reason() - return {"errors": errors, "error_reason": error_reason}, 500 + return make_response({"errors": errors, "error_reason": error_reason}, 500) -@main_blueprint.post("/register") -def registerUser(): +@main_blueprint.post("/token") +def tokenRoute() -> Response: + if request.json is None or request.json.get("code", None) is None: + return make_response({"msg": "Missing JSON body in request"}, 400) + + # Validate the temporary code + code = request.json["code"] + if code is None: + return make_response({"msg": "Missing code in request"}, 400) + user_email, registered = validate_code_and_get_user_email(code) + if user_email is None: + return make_response({"msg": "Invalid code"}, 400) + + access_token = create_access_token(identity=user_email) + refresh_token = create_refresh_token(identity=user_email) + resp = make_response({"registered": registered}) + set_access_cookies(resp, access_token) + set_refresh_cookies(resp, refresh_token) + return resp + + +@main_blueprint.post("/register") +def registerUser() -> Response: # Gather the new user's information json_data = request.get_json() if not json_data: @@ -162,28 +192,11 @@ def registerUser(): db.session.add(management_permissions) db.session.commit() - return {"msg": "New user added"} - - -@main_blueprint.post("/token") -def tokenRoute(): - if request.json is None or request.json.get("code", None) is None: - return {"msg": "Missing JSON body in request"}, 400 - # Validate the temporary code - code = request.json["code"] - if code is None: - return {"msg": "Missing code in request"}, 400 - user_email, registered = validate_code_and_get_user_email(code) - - if user_email is None: - return {"msg": "Invalid code"}, 400 - - token = create_access_token(identity=[user_email, datetime.now()]) - return {"token": token, "registered": registered} + return make_response({"msg": "New user added"}) @main_blueprint.get("/metadata/") -def metadataRoute(): +def metadataRoute() -> Response: req = prepare_flask_request(request) auth = auth = OneLogin_Saml2_Auth( req, custom_base_path=current_app.config["SAML_CONFIG"] @@ -200,7 +213,25 @@ def metadataRoute(): return resp +@main_blueprint.get("/authcheck") +@jwt_required() +def authcheck() -> Response: + return make_response({"msg": "authenticated"}) + + +@main_blueprint.get("/token/refresh") +@jwt_required(refresh=True) +def refresh() -> Response: + # Refreshing expired Access token + user_id = get_jwt_identity() + access_token = create_access_token(identity=str(user_id)) + resp = make_response({"msg": "refresh successful"}) + set_access_cookies(resp, access_token) + return resp + + @main_blueprint.get("/logout") -def logout(): - # TODO: add token to blacklist - return {"msg": "logout successful"} +def logout() -> Response: + resp = make_response({"msg": "logout successful"}) + unset_jwt_cookies(resp) + return resp diff --git a/labconnect/main/routes.py b/labconnect/main/routes.py index d216cbb8..1ab090f8 100644 --- a/labconnect/main/routes.py +++ b/labconnect/main/routes.py @@ -12,13 +12,14 @@ ClassYears, UserDepartments, Majors, + Courses, ) from . import main_blueprint @main_blueprint.get("/") -def index(): +def index() -> dict[str, str]: return {"Hello": "There"} @@ -104,8 +105,9 @@ def profile(): User.website, User.lab_manager_id, User.id, + User.pronouns, ) - .where(User.email == user_id[0]) + .where(User.email == user_id) .join(UserDepartments, UserDepartments.user_id == User.id) .join(RPIDepartments, UserDepartments.department_id == RPIDepartments.id) ).first() @@ -123,6 +125,7 @@ def profile(): "department": data[4], "description": data[5], "website": data[6], + "pronouns": data[9], } return result @@ -140,6 +143,7 @@ def getProfessorProfile(id: str): RPIDepartments.name, User.description, User.website, + User.pronouns, ) .where(User.id == id) .join(LabManager, User.lab_manager_id == LabManager.id) @@ -155,6 +159,7 @@ def getProfessorProfile(id: str): "department": data[4], "description": data[5], "website": data[6], + "pronouns": data[7], } return result @@ -235,84 +240,36 @@ def years() -> list[int]: return result -# @main_blueprint.get("/courses") -# def courses() -> list[Any]: -# if not request.data: -# abort(400) +@main_blueprint.get("/courses") +def courses() -> list[str]: + if not request.data: + abort(400) -# json_request_data = request.get_json() + json_request_data = request.get_json() -# if not json_request_data: -# abort(400) + if not json_request_data: + abort(400) -# partial_key = json_request_data.get("input", None) + partial_key = json_request_data.get("input", None) -# if not partial_key or not isinstance(partial_key, str): -# abort(400) + if not partial_key or not isinstance(partial_key, str): + abort(400) -# data = db.session.execute( -# db.select(Courses) -# .order_by(Courses.code) -# .where( -# (Courses.code.ilike(f"%{partial_key}%")) -# | (Courses.name.ilike(f"%{partial_key}%")) -# ) -# ).scalars() - -# if not data: -# abort(404) - -# result = [course.to_dict() for course in data] - -# if result == []: -# abort(404) + data = db.session.execute( + db.select(Courses) + .order_by(Courses.code) + .where( + (Courses.code.ilike(f"%{partial_key}%")) + | (Courses.name.ilike(f"%{partial_key}%")) + ) + ).scalars() -# return result + if not data: + abort(404) + result = [course.to_dict() for course in data] -# @main_blueprint.get("/user") -# def user(): -# if not request.data: -# abort(400) - -# id = request.get_json().get("id", None) - -# if not id: -# abort(400) - -# # Query for user -# user = db.first_or_404(db.select(User).where(User.id == id)) -# result = user.to_dict() - -# # Query for user's department(s) -# user_departments = db.session.execute( -# db.select(UserDepartments).where(UserDepartments.user_id == id) -# ).scalars() -# result["departments"] = [dept.to_dict() for dept in user_departments] - -# # Query for user's major(s) -# user_majors = db.session.execute( -# db.select(UserMajors).where(UserMajors.user_id == id) -# ).scalars() -# result["majors"] = [major.to_dict() for major in user_majors] - -# # Query for user's courses -# user_courses = db.session.execute( -# db.select(UserCourses) -# .order_by(UserCourses.in_progress) -# .where(UserCourses.user_id == id) -# ).scalars() -# result["courses"] = [course.to_dict() for course in user_courses] - -# # Query for user's opportunities -# user_opportunities = db.session.execute( -# db.select(Opportunities, Participates) -# .where(Participates.user_id == id) -# .join(Opportunities, Participates.opportunity_id == Opportunities.id) -# .order_by(Opportunities.active.desc()) -# ).scalars() -# result["opportunities"] = [ -# opportunity.to_dict() for opportunity in user_opportunities -# ] + if result == []: + abort(404) -# return result + return result diff --git a/labconnect/models.py b/labconnect/models.py index b7a192b3..8c3d4282 100644 --- a/labconnect/models.py +++ b/labconnect/models.py @@ -2,7 +2,12 @@ from sqlalchemy.dialects.postgresql import TSVECTOR from labconnect import db -from labconnect.helpers import CustomSerializerMixin, LocationEnum, SemesterEnum +from labconnect.helpers import ( + CustomSerializerMixin, + LocationEnum, + SemesterEnum, + LabManagerTypeEnum, +) # DD - Entities @@ -28,6 +33,7 @@ class User(db.Model, CustomSerializerMixin): first_name = db.Column(db.String(50), nullable=False, unique=False) last_name = db.Column(db.String(200), nullable=False, unique=False) preferred_name = db.Column(db.String(50), nullable=True, unique=False) + pronouns = db.Column(db.String(25), nullable=True, unique=False) phone_number = db.Column(db.String(15), nullable=True, unique=False) website = db.Column(db.String(512), nullable=True, unique=False) description = db.Column(db.String(4096), nullable=True, unique=False) @@ -84,6 +90,7 @@ class LabManager(db.Model, CustomSerializerMixin): serialize_rules = () id = db.Column(db.Integer, primary_key=True, autoincrement=True) + manager_type = db.Column(Enum(LabManagerTypeEnum), nullable=True, unique=False) department_id = db.Column(db.String(4), db.ForeignKey("rpi_departments.id")) user = db.relationship("User", back_populates="lab_manager") @@ -407,4 +414,4 @@ class Codes(db.Model): code = db.Column(db.String(64), primary_key=True) email = db.Column(db.String(64), nullable=False) expires_at = db.Column(db.DateTime, nullable=False) - registered = db.Column(db.Boolean, nullable=False) \ No newline at end of file + registered = db.Column(db.Boolean, nullable=False) diff --git a/migrations/versions/0e1d1657b500_.py b/migrations/versions/0e1d1657b500_.py new file mode 100644 index 00000000..010fe7a7 --- /dev/null +++ b/migrations/versions/0e1d1657b500_.py @@ -0,0 +1,56 @@ +"""empty message + +Revision ID: 0e1d1657b500 +Revises: +Create Date: 2025-02-14 16:31:55.909541 + +""" + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "0e1d1657b500" +down_revision = None +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("lab_manager", schema=None) as batch_op: + op.execute( + "CREATE TYPE labmanagertypeenum AS ENUM ('PI', 'CO_PI', 'LAB_MANAGER', 'POST_DOC', 'GRAD_STUDENT')" + ) + batch_op.add_column( + sa.Column( + "manager_type", + sa.Enum( + "PI", + "CO_PI", + "LAB_MANAGER", + "POST_DOC", + "GRAD_STUDENT", + name="labmanagertypeenum", + ), + nullable=False, + server_default="LAB_MANAGER", + ) + ) + + with op.batch_alter_table("user", schema=None) as batch_op: + batch_op.add_column(sa.Column("pronouns", sa.String(length=25), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("user", schema=None) as batch_op: + batch_op.drop_column("pronouns") + + with op.batch_alter_table("lab_manager", schema=None) as batch_op: + batch_op.drop_column("manager_type") + + # ### end Alembic commands ### diff --git a/requirements.txt b/requirements.txt index 253bfce0..916cf1e1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,6 +24,7 @@ psycopg2-binary==2.9.9 PyJWT==2.10.1 pytest==8.3.4 pytest-cov==6.0.0 +python-dotenv==1.0.1 python3-saml==1.16.0 pytz==2024.2 ruff==0.9.5