11from datetime import datetime , timedelta
22from uuid import uuid4
33
4- from flask import current_app , make_response , redirect , request
4+ from flask import current_app , make_response , redirect , request , abort
55from flask_jwt_extended import create_access_token
66from onelogin .saml2 .auth import OneLogin_Saml2_Auth
77
88from labconnect import db
99from labconnect .helpers import prepare_flask_request
10- from labconnect .models import User
10+ from labconnect .models import (
11+ User ,
12+ UserCourses ,
13+ UserDepartments ,
14+ UserMajors ,
15+ ManagementPermissions ,
16+ )
1117
1218from . import main_blueprint
1319
1420temp_codes = {}
1521
1622
17- def generate_temporary_code (user_email : str ) -> str :
23+ def generate_temporary_code (user_email : str , registered : bool ) -> str :
1824 # Generate a unique temporary code
1925 code = str (uuid4 ())
2026 expires_at = datetime .now () + timedelta (seconds = 5 ) # expires in 5 seconds
21- temp_codes [code ] = {"email" : user_email , "expires_at" : expires_at }
27+ temp_codes [code ] = {
28+ "email" : user_email ,
29+ "expires_at" : expires_at ,
30+ "registered" : registered ,
31+ }
2232 return code
2333
2434
25- def validate_code_and_get_user_email (code : str ) -> str | None :
35+ def validate_code_and_get_user_email (code : str ) -> tuple [ str | None , bool | None ] :
2636 token_data = temp_codes .get (code , {})
2737 if not token_data :
2838 return None
2939
3040 user_email = token_data .get ("email" , None )
3141 expire = token_data .get ("expires_at" , None )
42+ registered = token_data .get ("registered" , None )
3243
3344 if user_email and expire and expire > datetime .now ():
3445 # If found, delete the code to prevent reuse
3546 del temp_codes [code ]
36- return user_email
47+ return user_email , registered
3748 elif expire :
3849 # If the code has expired, delete it
3950 del temp_codes [code ]
4051
41- return None
52+ return None , None
4253
4354
4455@main_blueprint .get ("/login" )
4556def saml_login ():
4657
4758 # In testing skip RPI login purely for local development
48- if (
49- current_app .config ["TESTING" ]
50- and current_app .config ["FRONTEND_URL" ] == "http://localhost :3000"
59+ if current_app . config [ "TESTING" ] and (
60+ current_app .config ["FRONTEND_URL" ] == "http://localhost:3000"
61+ or current_app .config ["FRONTEND_URL" ] == "http://127.0.0.1 :3000"
5162 ):
5263 # Generate JWT
53- code = generate_temporary_code (
"[email protected] " )
64+ code = generate_temporary_code (
"[email protected] " , True )
5465
5566 # Send the JWT to the frontend
5667 return redirect (f"{ current_app .config ['FRONTEND_URL' ]} /callback/?code={ code } " )
@@ -70,36 +81,82 @@ def saml_callback():
7081 errors = auth .get_errors ()
7182
7283 if not errors :
84+ registered = True
7385 user_info = auth .get_attributes ()
7486 # user_id = auth.get_nameid()
7587
7688 data = db .session .execute (db .select (User ).where (User .email == "email" )).scalar ()
7789
7890 # User doesn't exist, create a new user
7991 if data is None :
80-
81- # TODO: add data
82- user = User (
83- # email=email,
84- # first_name=first_name,
85- # last_name=last_name,
86- # preferred_name=json_request_data.get("preferred_name", None),
87- # class_year=class_year,
88- )
89-
90- db .session .add (user )
91- db .session .commit ()
92-
92+ registered = False
9393 # Generate JWT
9494 # token = create_access_token(identity=[user_id, datetime.now()])
95- code = generate_temporary_code (user_info ["email" ][0 ])
95+ code = generate_temporary_code (user_info ["email" ][0 ], registered )
9696
9797 # Send the JWT to the frontend
9898 return redirect (f"{ current_app .config ['FRONTEND_URL' ]} /callback/?code={ code } " )
9999
100100 return {"errors" : errors }, 500
101101
102102
103+ @main_blueprint .post ("/register" )
104+ def registerUser ():
105+
106+ # Gather the new user's information
107+ json_data = request .get_json ()
108+ if not json_data :
109+ abort (400 )
110+
111+ user = User (
112+ email = json_data .get ("email" ),
113+ first_name = json_data .get ("first_name" ),
114+ last_name = json_data .get ("last_name" ),
115+ preferred_name = json_data .get ("preferred_name" , "" ),
116+ class_year = json_data .get ("class_year" , "" ),
117+ profile_picture = json_data .get (
118+ "profile_picture" , "https://www.svgrepo.com/show/206842/professor.svg"
119+ ),
120+ website = json_data .get ("website" , "" ),
121+ description = json_data .get ("description" , "" ),
122+ )
123+ db .session .add (user )
124+ db .session .commit ()
125+
126+ # Add UserDepartments if provided
127+ if json_data .get ("departments" ):
128+ for department_id in json_data ["departments" ]:
129+ user_department = UserDepartments (
130+ user_id = user .id , department_id = department_id
131+ )
132+ db .session .add (user_department )
133+
134+ # Additional auxiliary records (majors, courses, etc.)
135+ if json_data .get ("majors" ):
136+ for major_id in json_data ["majors" ]:
137+ user_major = UserMajors (user_id = user .id , major_id = major_id )
138+ db .session .add (user_major )
139+ # Add Courses if provided
140+ if json_data .get ("courses" ):
141+ for course_id in json_data ["courses" ]:
142+ user_course = UserCourses (user_id = user .id , course_id = course_id )
143+ db .session .add (user_course )
144+
145+ # Add ManagementPermissions if provided
146+ if json_data .get ("permissions" ):
147+ permissions = json_data ["permissions" ]
148+ management_permissions = ManagementPermissions (
149+ user_id = user .id ,
150+ super_admin = permissions .get ("super_admin" , False ),
151+ admin = permissions .get ("admin" , False ),
152+ moderator = permissions .get ("moderator" , False ),
153+ )
154+ db .session .add (management_permissions )
155+
156+ db .session .commit ()
157+ return {"msg" : "New user added" }
158+
159+
103160@main_blueprint .post ("/token" )
104161def tokenRoute ():
105162 if request .json is None or request .json .get ("code" , None ) is None :
@@ -108,13 +165,13 @@ def tokenRoute():
108165 code = request .json ["code" ]
109166 if code is None :
110167 return {"msg" : "Missing code in request" }, 400
111- user_email = validate_code_and_get_user_email (code )
168+ user_email , registered = validate_code_and_get_user_email (code )
112169
113170 if user_email is None :
114171 return {"msg" : "Invalid code" }, 400
115172
116173 token = create_access_token (identity = [user_email , datetime .now ()])
117- return {"token" : token }
174+ return {"token" : token , "registered" : registered }
118175
119176
120177@main_blueprint .get ("/metadata/" )
0 commit comments