Skip to content

Commit 430375c

Browse files
Merge branch 'main' into 185-parameterize-tests
2 parents 9a7063a + 9b01757 commit 430375c

File tree

5 files changed

+512
-420
lines changed

5 files changed

+512
-420
lines changed

db_init.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@
8989
db.session.add(row)
9090
db.session.commit()
9191

92-
class_years_rows = (2024, 2025, 2026, 2027, 2028, 2029, 2030, 2031)
92+
class_years_rows = (2025, 2026, 2027, 2028, 2029, 2030, 2031)
9393

9494
for row_item in class_years_rows:
9595
row = ClassYears(class_year=row_item, active=True)
@@ -182,7 +182,7 @@
182182
False,
183183
True,
184184
SemesterEnum.SPRING,
185-
2024,
185+
2025,
186186
date.today(),
187187
True,
188188
datetime.now(),
@@ -198,7 +198,7 @@
198198
True,
199199
True,
200200
SemesterEnum.SPRING,
201-
2024,
201+
2025,
202202
date.today(),
203203
True,
204204
datetime.now(),
@@ -214,7 +214,7 @@
214214
True,
215215
True,
216216
SemesterEnum.FALL,
217-
2024,
217+
2025,
218218
date.today(),
219219
True,
220220
datetime.now(),
@@ -230,7 +230,7 @@
230230
True,
231231
True,
232232
SemesterEnum.SUMMER,
233-
2024,
233+
2025,
234234
date.today(),
235235
True,
236236
datetime.now(),
@@ -246,10 +246,10 @@
246246
True,
247247
False,
248248
SemesterEnum.FALL,
249-
2024,
250-
"2024-10-31",
249+
2025,
250+
"2025-10-31",
251251
True,
252-
"2024-10-10T10:30:00",
252+
"2025-10-10T10:30:00",
253253
LocationEnum.JROWL,
254254
),
255255
)
@@ -331,7 +331,7 @@
331331
db.session.add(row)
332332
db.session.commit()
333333

334-
recommends_class_years_rows = ((2, 2024), (2, 2025), (2, 2026), (1, 2027))
334+
recommends_class_years_rows = ((3, 2025), (2, 2025), (2, 2026), (1, 2027))
335335

336336
for r in recommends_class_years_rows:
337337
row = RecommendsClassYears(opportunity_id=r[0], class_year=r[1])
-234 Bytes
Loading

labconnect/main/auth_routes.py

Lines changed: 84 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,56 +1,67 @@
11
from datetime import datetime, timedelta
22
from uuid import uuid4
33

4-
from flask import current_app, make_response, redirect, request
4+
from flask import current_app, make_response, redirect, request, abort
55
from flask_jwt_extended import create_access_token
66
from onelogin.saml2.auth import OneLogin_Saml2_Auth
77

88
from labconnect import db
99
from labconnect.helpers import prepare_flask_request
10-
from labconnect.models import User
10+
from labconnect.models import (
11+
User,
12+
UserCourses,
13+
UserDepartments,
14+
UserMajors,
15+
ManagementPermissions,
16+
)
1117

1218
from . import main_blueprint
1319

1420
temp_codes = {}
1521

1622

17-
def generate_temporary_code(user_email: str) -> str:
23+
def generate_temporary_code(user_email: str, registered: bool) -> str:
1824
# Generate a unique temporary code
1925
code = str(uuid4())
2026
expires_at = datetime.now() + timedelta(seconds=5) # expires in 5 seconds
21-
temp_codes[code] = {"email": user_email, "expires_at": expires_at}
27+
temp_codes[code] = {
28+
"email": user_email,
29+
"expires_at": expires_at,
30+
"registered": registered,
31+
}
2232
return code
2333

2434

25-
def validate_code_and_get_user_email(code: str) -> str | None:
35+
def validate_code_and_get_user_email(code: str) -> tuple[str | None, bool | None]:
2636
token_data = temp_codes.get(code, {})
2737
if not token_data:
2838
return None
2939

3040
user_email = token_data.get("email", None)
3141
expire = token_data.get("expires_at", None)
42+
registered = token_data.get("registered", None)
3243

3344
if user_email and expire and expire > datetime.now():
3445
# If found, delete the code to prevent reuse
3546
del temp_codes[code]
36-
return user_email
47+
return user_email, registered
3748
elif expire:
3849
# If the code has expired, delete it
3950
del temp_codes[code]
4051

41-
return None
52+
return None, None
4253

4354

4455
@main_blueprint.get("/login")
4556
def saml_login():
4657

4758
# In testing skip RPI login purely for local development
48-
if (
49-
current_app.config["TESTING"]
50-
and current_app.config["FRONTEND_URL"] == "http://localhost:3000"
59+
if current_app.config["TESTING"] and (
60+
current_app.config["FRONTEND_URL"] == "http://localhost:3000"
61+
or current_app.config["FRONTEND_URL"] == "http://127.0.0.1:3000"
5162
):
5263
# Generate JWT
53-
code = generate_temporary_code("[email protected]")
64+
code = generate_temporary_code("[email protected]", True)
5465

5566
# Send the JWT to the frontend
5667
return redirect(f"{current_app.config['FRONTEND_URL']}/callback/?code={code}")
@@ -70,36 +81,82 @@ def saml_callback():
7081
errors = auth.get_errors()
7182

7283
if not errors:
84+
registered = True
7385
user_info = auth.get_attributes()
7486
# user_id = auth.get_nameid()
7587

7688
data = db.session.execute(db.select(User).where(User.email == "email")).scalar()
7789

7890
# User doesn't exist, create a new user
7991
if data is None:
80-
81-
# TODO: add data
82-
user = User(
83-
# email=email,
84-
# first_name=first_name,
85-
# last_name=last_name,
86-
# preferred_name=json_request_data.get("preferred_name", None),
87-
# class_year=class_year,
88-
)
89-
90-
db.session.add(user)
91-
db.session.commit()
92-
92+
registered = False
9393
# Generate JWT
9494
# token = create_access_token(identity=[user_id, datetime.now()])
95-
code = generate_temporary_code(user_info["email"][0])
95+
code = generate_temporary_code(user_info["email"][0], registered)
9696

9797
# Send the JWT to the frontend
9898
return redirect(f"{current_app.config['FRONTEND_URL']}/callback/?code={code}")
9999

100100
return {"errors": errors}, 500
101101

102102

103+
@main_blueprint.post("/register")
104+
def registerUser():
105+
106+
# Gather the new user's information
107+
json_data = request.get_json()
108+
if not json_data:
109+
abort(400)
110+
111+
user = User(
112+
email=json_data.get("email"),
113+
first_name=json_data.get("first_name"),
114+
last_name=json_data.get("last_name"),
115+
preferred_name=json_data.get("preferred_name", ""),
116+
class_year=json_data.get("class_year", ""),
117+
profile_picture=json_data.get(
118+
"profile_picture", "https://www.svgrepo.com/show/206842/professor.svg"
119+
),
120+
website=json_data.get("website", ""),
121+
description=json_data.get("description", ""),
122+
)
123+
db.session.add(user)
124+
db.session.commit()
125+
126+
# Add UserDepartments if provided
127+
if json_data.get("departments"):
128+
for department_id in json_data["departments"]:
129+
user_department = UserDepartments(
130+
user_id=user.id, department_id=department_id
131+
)
132+
db.session.add(user_department)
133+
134+
# Additional auxiliary records (majors, courses, etc.)
135+
if json_data.get("majors"):
136+
for major_id in json_data["majors"]:
137+
user_major = UserMajors(user_id=user.id, major_id=major_id)
138+
db.session.add(user_major)
139+
# Add Courses if provided
140+
if json_data.get("courses"):
141+
for course_id in json_data["courses"]:
142+
user_course = UserCourses(user_id=user.id, course_id=course_id)
143+
db.session.add(user_course)
144+
145+
# Add ManagementPermissions if provided
146+
if json_data.get("permissions"):
147+
permissions = json_data["permissions"]
148+
management_permissions = ManagementPermissions(
149+
user_id=user.id,
150+
super_admin=permissions.get("super_admin", False),
151+
admin=permissions.get("admin", False),
152+
moderator=permissions.get("moderator", False),
153+
)
154+
db.session.add(management_permissions)
155+
156+
db.session.commit()
157+
return {"msg": "New user added"}
158+
159+
103160
@main_blueprint.post("/token")
104161
def tokenRoute():
105162
if request.json is None or request.json.get("code", None) is None:
@@ -108,13 +165,13 @@ def tokenRoute():
108165
code = request.json["code"]
109166
if code is None:
110167
return {"msg": "Missing code in request"}, 400
111-
user_email = validate_code_and_get_user_email(code)
168+
user_email, registered = validate_code_and_get_user_email(code)
112169

113170
if user_email is None:
114171
return {"msg": "Invalid code"}, 400
115172

116173
token = create_access_token(identity=[user_email, datetime.now()])
117-
return {"token": token}
174+
return {"token": token, "registered": registered}
118175

119176

120177
@main_blueprint.get("/metadata/")

0 commit comments

Comments
 (0)