Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 36 additions & 22 deletions labconnect/main/auth_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from uuid import uuid4

from flask import current_app, make_response, redirect, request
from flask_jwt_extended import create_access_token
from flask_jwt_extended import create_access_token, get_jwt_identity, jwt_required
from onelogin.saml2.auth import OneLogin_Saml2_Auth

from labconnect import db
Expand All @@ -14,31 +14,36 @@
temp_codes = {}


def generate_temporary_code(user_email: str) -> str:
def generate_temporary_code(user_email: str, registered: bool) -> str:
# Generate a unique temporary code
code = str(uuid4())
expires_at = datetime.now() + timedelta(seconds=5) # expires in 5 seconds
temp_codes[code] = {"email": user_email, "expires_at": expires_at}
temp_codes[code] = {
"email": user_email,
"expires_at": expires_at,
"registered": registered,
}
return code


def validate_code_and_get_user_email(code: str) -> str | None:
def validate_code_and_get_user_email(code: str) -> tuple[str | None, bool | None]:
token_data = temp_codes.get(code, {})
if not token_data:
return None

user_email = token_data.get("email", None)
expire = token_data.get("expires_at", None)
registered = token_data.get("registered", None)

if user_email and expire and expire > datetime.now():
# If found, delete the code to prevent reuse
del temp_codes[code]
return user_email
return user_email, registered
elif expire:
# If the code has expired, delete it
del temp_codes[code]

return None
return None, None


@main_blueprint.get("/login")
Expand Down Expand Up @@ -70,36 +75,45 @@ def saml_callback():
errors = auth.get_errors()

if not errors:
registered = True
user_info = auth.get_attributes()
# user_id = auth.get_nameid()

data = db.session.execute(db.select(User).where(User.email == "email")).scalar()

# User doesn't exist, create a new user
if data is None:

# TODO: add data
user = User(
# email=email,
# first_name=first_name,
# last_name=last_name,
# preferred_name=json_request_data.get("preferred_name", None),
# class_year=class_year,
)

db.session.add(user)
db.session.commit()

registered = False
# Generate JWT
# token = create_access_token(identity=[user_id, datetime.now()])
code = generate_temporary_code(user_info["email"][0])
code = generate_temporary_code(user_info["email"][0], registered)

# Send the JWT to the frontend
return redirect(f"{current_app.config['FRONTEND_URL']}/callback/?code={code}")

return {"errors": errors}, 500


@main_blueprint.post("/register")
@jwt_required()
def registerUser():
user_id = get_jwt_identity()

user = db.session.execute(db.select(User).where(User.email == user_id))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should be making a new User instead of requesting an existing one, all logic for making a user should be here.


# Gather the new user's information
json_data = request.get_json()
user.first_name = json_data.get("first_name")
user.last_name = json_data.get("last_name")
user.preferred_name = json_data.get("preferred_name")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For some of these we should add default values for example preferred_name should be blank if empty I believe

user.class_year = json_data.get("class_year")
user.profile_picture = json_data.get("profile_pictures")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If not profile picture is added we should add the default we have been using for test users

user.website = json_data.get("website")
user.description = json_data.get("description")

return {"msg": "Information added"}


@main_blueprint.post("/token")
def tokenRoute():
if request.json is None or request.json.get("code", None) is None:
Expand All @@ -108,13 +122,13 @@ def tokenRoute():
code = request.json["code"]
if code is None:
return {"msg": "Missing code in request"}, 400
user_email = validate_code_and_get_user_email(code)
user_email, registered = validate_code_and_get_user_email(code)

if user_email is None:
return {"msg": "Invalid code"}, 400

token = create_access_token(identity=[user_email, datetime.now()])
return {"token": token}
return {"token": token, "registered": registered}


@main_blueprint.get("/metadata/")
Expand Down