From 85a8083e3f45cf7b8d93dd746d0c18d76613635d Mon Sep 17 00:00:00 2001 From: tha_orakkle Date: Sat, 1 Mar 2025 17:30:14 +0100 Subject: [PATCH 01/29] feat: add UserSession model --- api/v1/models/__init__.py | 2 +- api/v1/models/user.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/api/v1/models/__init__.py b/api/v1/models/__init__.py index a39ec4bae..0090f357c 100644 --- a/api/v1/models/__init__.py +++ b/api/v1/models/__init__.py @@ -33,4 +33,4 @@ from api.v1.models.wishlist import Wishlist from api.v1.models.totp_device import TOTPDevice from api.v1.models.bookmark import Bookmark - +from api.v1.models.session import UserSession diff --git a/api/v1/models/user.py b/api/v1/models/user.py index 85ad5fa12..fe028c875 100644 --- a/api/v1/models/user.py +++ b/api/v1/models/user.py @@ -34,6 +34,7 @@ class User(BaseTableModel): profile = relationship( "Profile", uselist=False, back_populates="user", cascade="all, delete-orphan" ) + organisations = relationship( "Organisation", secondary=user_organisation_roles, back_populates="users" ) @@ -117,6 +118,9 @@ class User(BaseTableModel): "Bookmark", back_populates="user", cascade="delete" ) + sessions = relationship( + "UserSession", back_populates="user", cascade="all, delete-orphan" + ) def to_dict(self): obj_dict = super().to_dict() obj_dict.pop("password") From aa7af95fedb290e1ac3189afca94cb672b25b623 Mon Sep 17 00:00:00 2001 From: tha_orakkle Date: Sat, 1 Mar 2025 17:34:35 +0100 Subject: [PATCH 02/29] update: make /register and /login endpoint to create a user session on request --- api/v1/routes/auth.py | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/api/v1/routes/auth.py b/api/v1/routes/auth.py index 2204b8876..1aa78a1b1 100644 --- a/api/v1/routes/auth.py +++ b/api/v1/routes/auth.py @@ -1,4 +1,5 @@ import logging +import datetime as dt from datetime import timedelta from fastapi.responses import JSONResponse from jose import ExpiredSignatureError, JWTError @@ -21,6 +22,8 @@ from api.core.dependencies.email_sender import send_email from api.utils.success_response import auth_response, success_response from api.utils.send_mail import send_magic_link +from api.utils.settings import settings +from api.utils.session_helpers import get_session_schema_data from api.v1.models import User from api.v1.schemas.user import Token, UserEmailSender from api.v1.schemas.user import ( @@ -31,7 +34,7 @@ UserData2, ) from api.v1.schemas.token import TokenRequest - +from api.v1.schemas.session import SessionCreate from api.v1.schemas.user import (MagicLinkRequest, ChangePasswordSchema, AuthMeResponse) @@ -50,6 +53,7 @@ ) from api.v1.services.totp import totp_service from api.utils.settings import settings +from api.v1.services.session import SessionService auth = APIRouter(prefix="/auth", tags=["Authentication"]) @@ -94,7 +98,15 @@ def register( access_token = user_service.create_access_token(user_id=user.id) refresh_token = user_service.create_refresh_token(user_id=user.id) cta_link = f"{settings.ANCHOR_PYTHON_BASE_URL}/about-us" - + expires = dt.datetime.now(dt.timezone.utc) + (dt.timedelta( + days=settings.JWT_REFRESH_EXPIRY) - dt.timedelta(seconds=1) + ) + session_schema: SessionCreate = get_session_schema_data( + request, + refresh_token=refresh_token, + expires_at=str(expires)) + session_service = SessionService(db) + session_service.create(db=db, schema=session_schema, user_id=user.id) # Send email in the background background_tasks.add_task( @@ -245,6 +257,15 @@ def login(request: Request, login_request: LoginRequest, background_tasks: Backg # Generate access and refresh tokens access_token = user_service.create_access_token(user_id=user.id) refresh_token = user_service.create_refresh_token(user_id=user.id) + expires = dt.datetime.now(dt.timezone.utc) + (dt.timedelta( + days=settings.JWT_REFRESH_EXPIRY) - dt.timedelta(seconds=1) + ) + session_schema: SessionCreate = get_session_schema_data( + request, + refresh_token=refresh_token, + expires_at=str(expires)) + session_service = SessionService(db) + session_service.create(db=db, schema=session_schema, user_id=user.id) # Background task for email notification logger.info(f"Queueing login notification for {user.email} in the background...") From 9ba7e4a082c7cdedfeccccc1f7a64a01987bfdfa Mon Sep 17 00:00:00 2001 From: tha_orakkle Date: Sat, 1 Mar 2025 17:37:10 +0100 Subject: [PATCH 03/29] feat: create endpoints for user session --- api/v1/routes/session.py | 79 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 api/v1/routes/session.py diff --git a/api/v1/routes/session.py b/api/v1/routes/session.py new file mode 100644 index 000000000..8ac55f34f --- /dev/null +++ b/api/v1/routes/session.py @@ -0,0 +1,79 @@ +from sqlalchemy.orm import Session + +from fastapi import APIRouter, Depends, HTTPException, status +from fastapi.encoders import jsonable_encoder +from fastapi.responses import JSONResponse + +from api.v1.models import User +from api.db.database import get_db +from api.utils.success_response import success_response +from api.v1.services.user import user_service +from api.v1.services.session import SessionService + + +session_router = APIRouter(prefix="/sessions", tags=["sessions"]) + +@session_router.get("/", response_model=success_response) +def get_all_sessions( + db: Session = Depends(get_db), + current_user: User = Depends(user_service.get_current_user), +): + """ + Endpoint to get all sessions. + + args: + - db: the database session + - current_user: current authenticated user + """ + session_service = SessionService(db) + sessions = session_service.fetch_all(current_user.id) + return success_response( + status_code=status.HTTP_200_OK, + message="Sessions retrieved successfully", + data=jsonable_encoder(sessions, exclude={"refresh_token"}) + ) + +@session_router.get('/{session_id}', response_model=success_response) +def get_session( + session_id: str, + db: Session = Depends(get_db), + current_user: User = Depends(user_service.get_current_user) +): + session_service = SessionService(db) + session = session_service.fetch( + user_id=current_user.id, + session_id=session_id + ) + return success_response( + status_code=status.HTTP_200_OK, + message="Session retrived successfully", + data=jsonable_encoder(session, exclude={"refresh_token"}) + ) + +@session_router.delete('/{session_id}', status_code=status.HTTP_204_NO_CONTENT) +def delete_session( + session_id: str, + db: Session = Depends(get_db), + current_user: User = Depends(user_service.get_current_user) +): + """ + Endpoint to delete a session. + + args: + - session_id (str): ID of the session + - db: the database session + - current_user: current authenticated user + """ + session_service = SessionService(db) + return session_service.delete( + user_id=current_user.id, + session_id=session_id + ) + +@session_router.delete('/', status_code=status.HTTP_204_NO_CONTENT) +def delete_all_sessions( + db: Session = Depends(get_db), + current_user: User = Depends(user_service.get_current_user) +): + session_service = SessionService(db) + return session_service.delete_all(current_user.id) \ No newline at end of file From 4cd81ebd9adf77db5964000b009ed17a91d2abd7 Mon Sep 17 00:00:00 2001 From: tha_orakkle Date: Sat, 1 Mar 2025 17:50:28 +0100 Subject: [PATCH 04/29] feat: create session service for the UserSession model --- api/v1/services/session.py | 99 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 99 insertions(+) create mode 100644 api/v1/services/session.py diff --git a/api/v1/services/session.py b/api/v1/services/session.py new file mode 100644 index 000000000..76c98b136 --- /dev/null +++ b/api/v1/services/session.py @@ -0,0 +1,99 @@ +from datetime import datetime, timezone +from sqlalchemy.orm import Session + +from fastapi import HTTPException + +from api.v1.schemas.session import SessionCreate +from api.v1.models.session import UserSession + +class SessionService: + """Session service functionality.""" + + def __init__(self, db: Session): + self.db = db + + def is_revoked_or_expired(self, refresh_token: str): + """Check if a session (refresh token) is revoked.""" + session = self.db.query(UserSession).filter(UserSession.refresh_token == refresh_token).first() + if not session: + return True + current_time = datetime.now(timezone.utc) + if session.is_revoked or (session.expires_at < current_time): + return True + return False + + def revoke_sessions(self, sessions): + """Revoke sessions associated with IP and user-agent.""" + try: + for session in sessions: + session.is_revoked = True + self.db.commit() + self.db.refresh(session) + except Exception as e: + self.db.rollback() + raise HTTPException( + status_code=400, detail="Could not update session" + ) + + def fetch_by_ip_and_user_agent(self, ip_address: str, user_agent: str): + """Fetch sessions by IP address and user agent.""" + sessions = self.db.query(UserSession).filter( + UserSession.ip_address == ip_address, + UserSession.device == user_agent, + UserSession.is_revoked == False + ).all() + return sessions + + def create(self, db: Session, schema: SessionCreate, user_id: str): + """Create a new session.""" + sessions = self.fetch_by_ip_and_user_agent(schema.ip_address, schema.device) + if sessions: + print(sessions) + self.revoke_sessions(sessions) + new_session = UserSession(**schema.model_dump(), user_id=user_id) + db.add(new_session) + db.commit() + db.refresh(new_session) + return new_session + + def fetch_all(self, user_id): + """Fetch all active sessions.""" + sessions = self.db.query(UserSession).filter(UserSession.is_revoked == False, UserSession.user_id == user_id).all() + return sessions + + def fetch(self, user_id, session_id): + """Fetch a session by its ID.""" + session = self.db.query(UserSession).filter(UserSession.id == session_id, UserSession.user_id == user_id).first() + if not session: + raise HTTPException(status_code=404, detail="Session not found") + return session + + def delete(self, user_id, session_id): + """Delete a session.""" + session = self.fetch(user_id, session_id) + if not session: + raise HTTPException( + status_code=404, detail="Session not found" + ) + try: + self.db.delete(session) + self.db.commit() + except Exception as e: + self.db.rollback() + raise HTTPException( + status_code=500, detail="Could not delete session" + ) + + def delete_all(self, user_id): + """Delete all sessions associated to a user""" + sessions = self.fetch_all(user_id) + try: + for session in sessions: + self.db.delete(session) + self.db.commit() + except Exception as e: + self.db.rollback() + raise HTTPException( + status_code=500, + detail="Could not delete sessions" + ) \ No newline at end of file From db7ed21ac4a30cd5940e7de8cdc575d5dade93f2 Mon Sep 17 00:00:00 2001 From: tha_orakkle Date: Sat, 1 Mar 2025 17:51:40 +0100 Subject: [PATCH 05/29] feat: add schema UserCreate for the UserSession --- api/v1/schemas/session.py | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 api/v1/schemas/session.py diff --git a/api/v1/schemas/session.py b/api/v1/schemas/session.py new file mode 100644 index 000000000..6f2a1121a --- /dev/null +++ b/api/v1/schemas/session.py @@ -0,0 +1,9 @@ +from pydantic import BaseModel, Field + +class SessionCreate(BaseModel): + ip_address: str + location: str = None + device: str = None + is_revoked: bool = False + refresh_token: str + expires_at: str \ No newline at end of file From bd2425c106e5976d187dd8994eea61c22d323988 Mon Sep 17 00:00:00 2001 From: tha_orakkle Date: Sat, 1 Mar 2025 17:53:40 +0100 Subject: [PATCH 06/29] feat: add session router to api_version_one --- api/v1/routes/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/api/v1/routes/__init__.py b/api/v1/routes/__init__.py index fdd167d85..1fb64b9fe 100644 --- a/api/v1/routes/__init__.py +++ b/api/v1/routes/__init__.py @@ -45,6 +45,7 @@ from api.v1.routes.privacy import privacies from api.v1.routes.settings import settings from api.v1.routes.terms_and_conditions import terms_and_conditions +from api.v1.routes.session import session_router from api.v1.routes.stripe import subscription_ from api.v1.routes.wishlist import wishlist @@ -98,3 +99,4 @@ api_version_one.include_router(product_comment) api_version_one.include_router(subscription_) api_version_one.include_router(wishlist) +api_version_one.include_router(session_router) From 4a430286dde9f43c272d243e3ccf1c0f1216c06e Mon Sep 17 00:00:00 2001 From: tha_orakkle Date: Sat, 1 Mar 2025 17:55:43 +0100 Subject: [PATCH 07/29] update: check if refresh token is revoked in verify_refresh_token method --- api/v1/services/user.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/api/v1/services/user.py b/api/v1/services/user.py index 48c665c80..c1bbb0058 100644 --- a/api/v1/services/user.py +++ b/api/v1/services/user.py @@ -25,6 +25,8 @@ from api.v1.schemas import token from api.v1.services.notification_settings import notification_setting_service from api.v1.services.newsletter import NewsletterService, EmailSchema +from api.v1.services.session import SessionService + oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login") pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") @@ -413,9 +415,14 @@ def verify_access_token(self, access_token: str, credentials_exception): return token_data def verify_refresh_token(self, refresh_token: str, credentials_exception): - """Funtcion to decode and verify refresh token""" + """Function to decode and verify refresh token""" try: + db: Session = next(get_db()) + session_service = SessionService(db) + is_revoked = session_service.is_revoked_or_expired(refresh_token) + if is_revoked: + raise credentials_exception payload = jwt.decode( refresh_token, settings.SECRET_KEY, From 034744d6696125b7a876eec46c1dbec4589a4d85 Mon Sep 17 00:00:00 2001 From: tha_orakkle Date: Sat, 1 Mar 2025 17:57:07 +0100 Subject: [PATCH 08/29] feat: helper functions to build data for the user session --- api/utils/session_helpers.py | 46 ++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 api/utils/session_helpers.py diff --git a/api/utils/session_helpers.py b/api/utils/session_helpers.py new file mode 100644 index 000000000..88b5df242 --- /dev/null +++ b/api/utils/session_helpers.py @@ -0,0 +1,46 @@ +import requests + +from fastapi import Request + +from api.v1.schemas.session import SessionCreate +from api.utils.client_helpers import get_ip_address + + +def get_ip_location(ip): + """ Get IP location. + + Args: + ip (str): IP address + """ + country = region = "Unknown" + + try: + response = requests.get(f"https://ipinfo.io/{ip}/json/", timeout=10) + if response.status_code != 200: + return f"{region}, {country}" + data = response.json() + region = data.get("region", "Unknown") + country = data.get("country", "Unknown") + except Exception as e: + return f"{region}, {country}" + return f"{region}, {country}" + + +def get_session_schema_data(request: Request, refresh_token: str = "", expires_at: str = ""): + """Get session schema data. + + Args: + request (Request): Request object + refresh_token (str): Refresh token + expires_at (str): Expiry date + """ + ip = get_ip_address(request) + user_agent= request.headers.get("User-Agent") + return SessionCreate( + ip_address=ip, + location=get_ip_location(ip), + device=user_agent, + is_revoked=False, + refresh_token=refresh_token, + expires_at=expires_at + ) From e130c43dc5a3807ebfcd432095e8dd04b0a0fcdd Mon Sep 17 00:00:00 2001 From: tha_orakkle Date: Sat, 1 Mar 2025 17:58:13 +0100 Subject: [PATCH 09/29] feat: add UserSession model --- api/v1/models/session.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 api/v1/models/session.py diff --git a/api/v1/models/session.py b/api/v1/models/session.py new file mode 100644 index 000000000..8772ec6e2 --- /dev/null +++ b/api/v1/models/session.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python3 +"""The Session Model.""" + +from api.v1.models.base_model import BaseTableModel +from sqlalchemy import Column, String, Text, ForeignKey, Boolean, text +from sqlalchemy.orm import relationship + + +class UserSession(BaseTableModel): + __tablename__ = "sessions" + + user_id = Column(String, ForeignKey("users.id", ondelete="CASCADE"), nullable=False) + ip_address = Column(String, nullable=False) + location = Column(String, nullable=True) + device = Column(String, nullable=True) + is_revoked = Column(Boolean, server_default=text("false")) + refresh_token = Column(String, nullable=False) + expires_at = Column(String, nullable=False) + + user = relationship("User", back_populates="sessions") + + def __str__(self): + return f"{self.user_id} - {self.ip_address}" \ No newline at end of file From cb794ddf65ddc9260e25fdbf2cffa865dd86d547 Mon Sep 17 00:00:00 2001 From: tha_orakkle Date: Sat, 1 Mar 2025 22:28:39 +0100 Subject: [PATCH 10/29] update: change expires_at field to Datetime --- api/v1/models/session.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/v1/models/session.py b/api/v1/models/session.py index 8772ec6e2..22c185088 100644 --- a/api/v1/models/session.py +++ b/api/v1/models/session.py @@ -2,7 +2,7 @@ """The Session Model.""" from api.v1.models.base_model import BaseTableModel -from sqlalchemy import Column, String, Text, ForeignKey, Boolean, text +from sqlalchemy import Column, String, DateTime, ForeignKey, Boolean, text from sqlalchemy.orm import relationship @@ -15,7 +15,7 @@ class UserSession(BaseTableModel): device = Column(String, nullable=True) is_revoked = Column(Boolean, server_default=text("false")) refresh_token = Column(String, nullable=False) - expires_at = Column(String, nullable=False) + expires_at = Column(DateTime, nullable=False) user = relationship("User", back_populates="sessions") From 3cc09fa3b2404b9715eec5b40ca8f8e9da013070 Mon Sep 17 00:00:00 2001 From: tha_orakkle Date: Sat, 1 Mar 2025 22:32:34 +0100 Subject: [PATCH 11/29] fix: remove typecast to str in get_session_schema_data when called --- api/v1/routes/auth.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/v1/routes/auth.py b/api/v1/routes/auth.py index 1aa78a1b1..56fe37f89 100644 --- a/api/v1/routes/auth.py +++ b/api/v1/routes/auth.py @@ -104,7 +104,7 @@ def register( session_schema: SessionCreate = get_session_schema_data( request, refresh_token=refresh_token, - expires_at=str(expires)) + expires_at=expires) session_service = SessionService(db) session_service.create(db=db, schema=session_schema, user_id=user.id) @@ -263,7 +263,7 @@ def login(request: Request, login_request: LoginRequest, background_tasks: Backg session_schema: SessionCreate = get_session_schema_data( request, refresh_token=refresh_token, - expires_at=str(expires)) + expires_at=expires) session_service = SessionService(db) session_service.create(db=db, schema=session_schema, user_id=user.id) From 96c443544469e609ffb380a4fcfcebafcb12e450 Mon Sep 17 00:00:00 2001 From: tha_orakkle Date: Sat, 1 Mar 2025 22:34:14 +0100 Subject: [PATCH 12/29] update: change expires_at in schema to datetime --- api/v1/schemas/session.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/api/v1/schemas/session.py b/api/v1/schemas/session.py index 6f2a1121a..99387e65f 100644 --- a/api/v1/schemas/session.py +++ b/api/v1/schemas/session.py @@ -1,3 +1,4 @@ +from datetime import datetime from pydantic import BaseModel, Field class SessionCreate(BaseModel): @@ -6,4 +7,4 @@ class SessionCreate(BaseModel): device: str = None is_revoked: bool = False refresh_token: str - expires_at: str \ No newline at end of file + expires_at: datetime \ No newline at end of file From 74eab92124f9999ead3191d32ead51d4dde67711 Mon Sep 17 00:00:00 2001 From: tha_orakkle Date: Sat, 1 Mar 2025 22:34:54 +0100 Subject: [PATCH 13/29] fix: format datetime before comparison --- api/v1/services/session.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/api/v1/services/session.py b/api/v1/services/session.py index 76c98b136..d88e2ccb5 100644 --- a/api/v1/services/session.py +++ b/api/v1/services/session.py @@ -17,6 +17,8 @@ def is_revoked_or_expired(self, refresh_token: str): session = self.db.query(UserSession).filter(UserSession.refresh_token == refresh_token).first() if not session: return True + if isinstance(session.expires_at, str): + session.expires_at = datetime.fromisoformat(session.expires_at) current_time = datetime.now(timezone.utc) if session.is_revoked or (session.expires_at < current_time): return True From 7dd152d00422e1867cb848a2bebc7d8058b8e27d Mon Sep 17 00:00:00 2001 From: tha_orakkle Date: Sat, 1 Mar 2025 22:35:43 +0100 Subject: [PATCH 14/29] feat: add background task to clear sessions table in db to celery --- api/utils/celery_config.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 api/utils/celery_config.py diff --git a/api/utils/celery_config.py b/api/utils/celery_config.py new file mode 100644 index 000000000..9fd19b358 --- /dev/null +++ b/api/utils/celery_config.py @@ -0,0 +1,35 @@ +from celery import Celery +from celery.schedules import crontab +from datetime import datetime, timezone + +from api.utils.settings import settings + +celery_app = Celery('api', broker='redis://localhost:6379/0', backend='redis://localhost:6379/0') + +celery_app.conf.beat_schedule = { + "clean_db_every_day": { + "task": "api.utils.celery_config.clean_expired_and_revoked_tokens_from_sessions_table", + "schedule": crontab(day_of_week=0, hour=0, minute=0)# run every sunday midnight + } +} + +celery_app.conf.timezone = "UTC" + +@celery_app.task +def clean_expired_and_revoked_tokens_from_sessions_table(): + """Clean expired and revoked tokens from the sessions table.""" + from api.db.database import get_db + from api.v1.models.session import UserSession + from sqlalchemy import cast, DateTime + + db = next(get_db()) + try: + current_time = datetime.now(timezone.utc) + db.query(UserSession).filter(UserSession.is_revoked == True).delete() + db.query(UserSession).filter(cast(UserSession.expires_at, DateTime) < current_time).delete() + db.commit() + except Exception as e: + db.rollback() + raise e + finally: + db.close() From a806a04d346f0eb2220787009c603438eebdf194 Mon Sep 17 00:00:00 2001 From: tha_orakkle Date: Sat, 1 Mar 2025 22:37:33 +0100 Subject: [PATCH 15/29] chore: add depencies for user session management --- requirements.txt | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/requirements.txt b/requirements.txt index 12775597c..9f578fa61 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ aiohttp-retry==2.8.3 aiosignal==1.3.1 aiosmtplib==2.0.2 alembic==1.13.2 +amqp==5.3.1 annotated-types==0.7.0 anyio==4.4.0 astroid==3.2.4 @@ -11,15 +12,20 @@ attrs==23.2.0 Authlib==1.3.1 autopep8==2.3.1 bcrypt==4.1.3 +billiard==4.2.1 black==24.4.2 bleach==6.1.0 blinker==1.8.2 cachetools==5.4.0 +celery==5.4.0 certifi==2024.7.4 cffi==1.16.0 cfgv==3.4.0 charset-normalizer==3.3.2 click==8.1.7 +click-didyoumean==0.3.1 +click-plugins==1.1.1 +click-repl==0.3.0 colorama==0.4.6 cryptography==43.0.0 cssselect==1.2.0 @@ -52,6 +58,7 @@ iniconfig==2.0.0 isort==5.13.2 itsdangerous==2.2.0 Jinja2==3.1.4 +kombu==5.4.2 limits==3.13.0 lxml==5.2.2 Mako==1.3.5 @@ -73,6 +80,7 @@ platformdirs==4.2.2 pluggy==1.5.0 pre-commit==3.7.1 premailer==3.10.0 +prompt_toolkit==3.0.50 psycopg2-binary==2.9.9 pyasn1==0.6.0 pycodestyle==2.12.0 @@ -96,7 +104,11 @@ python-jose==3.3.0 python-multipart==0.0.9 pytz==2024.1 PyYAML==6.0.1 +<<<<<<< HEAD qrcode==8.0 +======= +redis==5.2.1 +>>>>>>> ef17770a (chore: add depencies for user session management) requests==2.32.3 rich==13.7.1 rsa==4.9 @@ -112,6 +124,7 @@ tomlkit==0.13.0 twilio==9.2.3 typer==0.12.3 typing_extensions==4.12.2 +tzdata==2025.1 ua-parser==1.0.1 ua-parser-builtins==0.18.0.post1 urllib3==2.2.2 @@ -119,8 +132,10 @@ user-agents==2.2.0 uuid7==0.1.0 uvicorn==0.30.3 uvloop==0.19.0 +vine==5.1.0 virtualenv==20.26.3 watchfiles==0.22.0 +wcwidth==0.2.13 webencodings==0.5.1 websockets==12.0 wrapt==1.16.0 From b21aa382ea63a44a00ed9349dff7dea8b135fc83 Mon Sep 17 00:00:00 2001 From: tha_orakkle Date: Sat, 1 Mar 2025 23:07:45 +0100 Subject: [PATCH 16/29] doc: add steps to run background tasks with celery --- README.md | 240 +++++++++++++++++++++++++++++++++++------------------- 1 file changed, 154 insertions(+), 86 deletions(-) diff --git a/README.md b/README.md index bd395fb5a..8753f5045 100644 --- a/README.md +++ b/README.md @@ -1,69 +1,73 @@ -# **FASTAPI Boilerplate** -A FastAPI boilerplate for efficient project setup. +# **FASTAPI Boilerplate** -## **Cloning the Repository** +A FastAPI boilerplate for efficient project setup. -1. **Fork the repository** and clone it: +## **Cloning the Repository** + +1. **Fork the repository** and clone it: ```sh git clone https://github.com//hng_boilerplate_python_fastapi_web.git ``` -2. **Navigate into the project directory**: +2. **Navigate into the project directory**: ```sh cd hng_boilerplate_python_fastapi_web ``` -3. **Switch to the development branch** (if not already on `dev`): +3. **Switch to the development branch** (if not already on `dev`): ```sh git checkout dev ``` +## **Setup Instructions** -## **Setup Instructions** - -1. **Create a virtual environment**: +1. **Create a virtual environment**: ```sh python3 -m venv .venv ``` -2. **Activate the virtual environment**: - - On macOS/Linux: +2. **Activate the virtual environment**: + - On macOS/Linux: ```sh source .venv/bin/activate ``` - - On Windows (PowerShell): + - On Windows (PowerShell): ```sh .venv\Scripts\Activate ``` -3. **Install project dependencies**: +3. **Install project dependencies**: ```sh pip install -r requirements.txt ``` -4. **Create a `.env` file** from `.env.sample`: +4. **Create a `.env` file** from `.env.sample`: ```sh cp .env.sample .env ``` -5. **Start the server**: +5. **Start the server**: ```sh python main.py ``` --- -## **Database Setup** +## **Database Setup** -### **Replacing Placeholders in Database Setup** +### **Replacing Placeholders in Database Setup** When setting up the database, you need to replace **placeholders** with your actual values. Below is a breakdown of where to replace them: --- ## **Step 1: Create a Database User** + ```sql CREATE USER user WITH PASSWORD 'your_password'; ``` -🔹 **Replace:** -- `user` → Your **preferred database username** (e.g., `fastapi_user`). -- `'your_password'` → A **secure password** for the user (e.g., `'StrongP@ssw0rd'`). -✅ **Example:** +🔹 **Replace:** + +- `user` → Your **preferred database username** (e.g., `fastapi_user`). +- `'your_password'` → A **secure password** for the user (e.g., `'StrongP@ssw0rd'`). + +✅ **Example:** + ```sql CREATE USER fastapi_user WITH PASSWORD 'StrongP@ssw0rd'; ``` @@ -71,13 +75,17 @@ CREATE USER fastapi_user WITH PASSWORD 'StrongP@ssw0rd'; --- ## **Step 2: Create the Database** + ```sql CREATE DATABASE hng_fast_api; ``` -🔹 **Replace:** -- `hng_fast_api` → Your **preferred database name** (e.g., `fastapi_db`). -✅ **Example:** +🔹 **Replace:** + +- `hng_fast_api` → Your **preferred database name** (e.g., `fastapi_db`). + +✅ **Example:** + ```sql CREATE DATABASE fastapi_db; ``` @@ -85,14 +93,18 @@ CREATE DATABASE fastapi_db; --- ## **Step 3: Grant Permissions** + ```sql GRANT ALL PRIVILEGES ON DATABASE hng_fast_api TO user; ``` -🔹 **Replace:** -- `hng_fast_api` → The **database name you used** in Step 2. -- `user` → The **username you created** in Step 1. -✅ **Example:** +🔹 **Replace:** + +- `hng_fast_api` → The **database name you used** in Step 2. +- `user` → The **username you created** in Step 1. + +✅ **Example:** + ```sql GRANT ALL PRIVILEGES ON DATABASE fastapi_db TO fastapi_user; ``` @@ -100,17 +112,21 @@ GRANT ALL PRIVILEGES ON DATABASE fastapi_db TO fastapi_user; --- ## **Step 4: Update `.env` File** + Edit the `.env` file to match your setup. ```env DATABASE_URL=postgresql://user:your_password@localhost/hng_fast_api ``` -🔹 **Replace:** -- `user` → Your **database username**. -- `your_password` → Your **database password**. -- `hng_fast_api` → Your **database name**. -✅ **Example:** +🔹 **Replace:** + +- `user` → Your **database username**. +- `your_password` → Your **database password**. +- `hng_fast_api` → Your **database name**. + +✅ **Example:** + ```env DATABASE_URL=postgresql://fastapi_user:StrongP@ssw0rd@localhost/fastapi_db ``` @@ -118,59 +134,88 @@ DATABASE_URL=postgresql://fastapi_user:StrongP@ssw0rd@localhost/fastapi_db --- ## **Step 5: Verify Connection** + After setting up the database, test the connection: ```sh psql -U user -d hng_fast_api -h localhost ``` -🔹 **Replace:** -- `user` → Your **database username**. -- `hng_fast_api` → Your **database name**. -✅ **Example:** +🔹 **Replace:** + +- `user` → Your **database username**. +- `hng_fast_api` → Your **database name**. + +✅ **Example:** + ```sh psql -U fastapi_user -d fastapi_db -h localhost ``` -## **Step 6: Run database migrations** - ```sh - alembic upgrade head - ``` - _Do NOT run `alembic revision --autogenerate -m 'initial migration'` initially!_ +## **Step 6: Run database migrations** + +```sh +alembic upgrade head +``` + +_Do NOT run `alembic revision --autogenerate -m 'initial migration'` initially!_ + +## **Step 7: If making changes to database models, update migrations** -## **Step 7: If making changes to database models, update migrations** ```sh alembic revision --autogenerate -m 'your migration message' alembic upgrade head - ``` -## **Step 8: Seed dummy data** - ```sh - python3 seed.py - ``` +``` + +## **Step 8: Seed dummy data** + +```sh +python3 seed.py +``` --- -## **Adding Tables and Columns** +## **Adding Tables and Columns** -1. **After creating new tables or modifying models**: - - Run Alembic migrations: +1. **After creating new tables or modifying models**: + - Run Alembic migrations: ```sh alembic revision --autogenerate -m "Migration message" alembic upgrade head ``` - - Ensure you **import new models** into `api/v1/models/__init__.py`. + - Ensure you **import new models** into `api/v1/models/__init__.py`. - You do NOT need to manually import them in `alembic/env.py`. --- -## **Adding New Routes** +## CELERY SETUP FOR BACKGROUND TASKS + +Celery will regularly remove revoked and expired tokens from the UserSession table in the db. + +**Ensure** `redis-server` is installed and running on your device. + +- Run this to start worker + +```bash +celery -A api.utils.celery_config worker --loglevel=info +``` + +- Run beat scheduler in a new terminal + +```bash +celery -A api.utils.celery_config beat --loglevel=info +``` -1. **Check if a related route file already exists** in `api/v1/routes/`. - - If yes, add your route inside the existing file. - - If no, create a new file following the naming convention. -2. **Define the router** inside the new route file: - - Include the prefix (without `/api/v1` since it's already handled). -3. **Register the router in `api/v1/routes/__init__.py`**: +The scheduler will send tasks to the worker to process based on the interval set. + +## **Adding New Routes** + +1. **Check if a related route file already exists** in `api/v1/routes/`. + - If yes, add your route inside the existing file. + - If no, create a new file following the naming convention. +2. **Define the router** inside the new route file: + - Include the prefix (without `/api/v1` since it's already handled). +3. **Register the router in `api/v1/routes/__init__.py`**: ```python from .new_route import router as new_router api_version_one.include_router(new_router) @@ -178,79 +223,102 @@ psql -U fastapi_user -d fastapi_db -h localhost --- -## **Running Tests with Pytest** +## **Running Tests with Pytest** + +### **Install Pytest** + +Ensure `pytest` is installed in your virtual environment: -### **Install Pytest** -Ensure `pytest` is installed in your virtual environment: ```sh pip install pytest ``` -### **Run all tests in the project** -From the **project root directory**, run: +### **Run all tests in the project** + +From the **project root directory**, run: + ```sh pytest ``` + This will automatically discover and execute all test files in the `tests/` directory. -### **Run tests in a specific directory** -To run tests in a specific model directory (e.g., `tests/v1/user/`): +### **Run tests in a specific directory** + +To run tests in a specific model directory (e.g., `tests/v1/user/`): + ```sh pytest tests/v1/user/ ``` -### **Run a specific test file** -To run tests from a specific test file (e.g., `test_signup.py` inside `tests/v1/auth/`): +### **Run a specific test file** + +To run tests from a specific test file (e.g., `test_signup.py` inside `tests/v1/auth/`): + ```sh pytest tests/v1/auth/test_signup.py ``` -### **Run a specific test function** -If you want to run a specific test inside a file, use: +### **Run a specific test function** + +If you want to run a specific test inside a file, use: + ```sh pytest tests/v1/auth/test_signup.py::test_user_signup ``` -### **Run tests with detailed output** -For verbose output, add the `-v` flag: +### **Run tests with detailed output** + +For verbose output, add the `-v` flag: + ```sh pytest -v ``` -### **Run tests and generate coverage report** -To check test coverage, install `pytest-cov`: +### **Run tests and generate coverage report** + +To check test coverage, install `pytest-cov`: + ```sh pip install pytest-cov ``` -Then run: + +Then run: + ```sh pytest --cov=api ``` --- -## **Common Migration Issues & Solutions** +## **Common Migration Issues & Solutions** + +### **Error: "Target database is not up to date."** + +If you encounter this issue when running: -### **Error: "Target database is not up to date."** -If you encounter this issue when running: ```sh alembic revision --autogenerate -m 'your migration message' ``` -#### **Solution**: -Run the following command first: + +#### **Solution**: + +Run the following command first: + ```sh alembic upgrade head ``` -Then retry: + +Then retry: + ```sh alembic revision --autogenerate -m 'your migration message' ``` --- -## **Contribution Guidelines** - -- **Test your endpoints and models** before pushing changes. -- **Push Alembic migrations** if database models are modified. -- Ensure your code **follows project standards** and **passes tests** before submitting a pull request. +## **Contribution Guidelines** +- **Test your endpoints and models** before pushing changes. +- **Push Alembic migrations** if database models are modified. +- Ensure your code **follows project standards** and **passes tests** before submitting a pull request. From 96d83fa1ff4b3828fddd5f3ed695afb6a84cf3fc Mon Sep 17 00:00:00 2001 From: tha_orakkle Date: Mon, 3 Mar 2025 00:34:32 +0100 Subject: [PATCH 17/29] fix: create user session asynchronously to prevent latency --- api/utils/session_helpers.py | 66 +++++++++++++++++++++++++----------- api/v1/routes/auth.py | 37 ++++++++++---------- api/v1/services/session.py | 12 ++++--- 3 files changed, 73 insertions(+), 42 deletions(-) diff --git a/api/utils/session_helpers.py b/api/utils/session_helpers.py index 88b5df242..ca167d83e 100644 --- a/api/utils/session_helpers.py +++ b/api/utils/session_helpers.py @@ -1,32 +1,38 @@ -import requests - -from fastapi import Request +import httpx +from fastapi import Request, BackgroundTasks +from sqlalchemy.orm import Session +from api.db.database import get_db from api.v1.schemas.session import SessionCreate +from api.v1.services.session import SessionService from api.utils.client_helpers import get_ip_address -def get_ip_location(ip): +async def get_ip_location(ip): """ Get IP location. Args: ip (str): IP address """ country = region = "Unknown" - + try: - response = requests.get(f"https://ipinfo.io/{ip}/json/", timeout=10) - if response.status_code != 200: - return f"{region}, {country}" - data = response.json() - region = data.get("region", "Unknown") - country = data.get("country", "Unknown") - except Exception as e: - return f"{region}, {country}" + async with httpx.AsyncClient() as client: + response = await client.get(f"https://ipinfo.io/{ip}/json/", timeout=10) + response.raise_for_status() + data = response.json() + region = data.get("region", "Unknown") + country = data.get("country", "Unknown") + except httpx.RequestError as exc: + print(f"An error occurred while requesting {exc.request.url!r}.") + except httpx.HTTPStatusError as exc: + print(f"Error response {exc.response.status_code} while requesting {exc.request.url!r}.") + except Exception as exc: + print(f"An unexpected error occurred: {exc}") + return f"{region}, {country}" - - -def get_session_schema_data(request: Request, refresh_token: str = "", expires_at: str = ""): + +async def get_session_schema_data(request: Request, refresh_token: str = "", expires_at: str = ""): """Get session schema data. Args: @@ -35,12 +41,34 @@ def get_session_schema_data(request: Request, refresh_token: str = "", expires_a expires_at (str): Expiry date """ ip = get_ip_address(request) - user_agent= request.headers.get("User-Agent") return SessionCreate( ip_address=ip, - location=get_ip_location(ip), - device=user_agent, + location=await get_ip_location(ip), + device=request.headers.get("User-Agent"), is_revoked=False, refresh_token=refresh_token, expires_at=expires_at ) + +async def create_session_for_user( + request: Request, + user_id: str, + refresh_token: str = "", + expires_at: str = "", + ): + """Create session for user. + + Args: + request (Request): Request object + db: Database session + refresh_token (str): Refresh token + expires_at (str): Expiry date + """ + session_data: SessionCreate = await get_session_schema_data( + request, + refresh_token=refresh_token, + expires_at=expires_at, + ) + db = next(get_db()) + session_service = SessionService(db) + session_service.create(schema=session_data, user_id=user_id) diff --git a/api/v1/routes/auth.py b/api/v1/routes/auth.py index 56fe37f89..51979565a 100644 --- a/api/v1/routes/auth.py +++ b/api/v1/routes/auth.py @@ -23,7 +23,7 @@ from api.utils.success_response import auth_response, success_response from api.utils.send_mail import send_magic_link from api.utils.settings import settings -from api.utils.session_helpers import get_session_schema_data +from api.utils.session_helpers import create_session_for_user from api.v1.models import User from api.v1.schemas.user import Token, UserEmailSender from api.v1.schemas.user import ( @@ -34,7 +34,7 @@ UserData2, ) from api.v1.schemas.token import TokenRequest -from api.v1.schemas.session import SessionCreate +# from api.v1.schemas.session import SessionCreate from api.v1.schemas.user import (MagicLinkRequest, ChangePasswordSchema, AuthMeResponse) @@ -53,7 +53,7 @@ ) from api.v1.services.totp import totp_service from api.utils.settings import settings -from api.v1.services.session import SessionService +# from api.v1.services.session import SessionService auth = APIRouter(prefix="/auth", tags=["Authentication"]) @@ -79,14 +79,9 @@ def register( # Create user account user = user_service.create(db=db, schema=user_schema) - verification_token = user_service.create_verification_token(user.id) verification_link = f"{base_url}/api/v1/auth/verify-email?token={verification_token}" - access_token = user_service.create_access_token(user_id=user.id) - refresh_token = user_service.create_refresh_token(user_id=user.id) - cta_link = "https://anchor-python.teams.hng.tech/about-us" - # create an organization for the user org = CreateUpdateOrganisation( name=f"{user.email}'s Organisation", email=user.email @@ -98,15 +93,18 @@ def register( access_token = user_service.create_access_token(user_id=user.id) refresh_token = user_service.create_refresh_token(user_id=user.id) cta_link = f"{settings.ANCHOR_PYTHON_BASE_URL}/about-us" + + # create session for user expires = dt.datetime.now(dt.timezone.utc) + (dt.timedelta( days=settings.JWT_REFRESH_EXPIRY) - dt.timedelta(seconds=1) ) - session_schema: SessionCreate = get_session_schema_data( - request, + background_tasks.add_task( + create_session_for_user, + request=request, + user_id=user.id, refresh_token=refresh_token, - expires_at=expires) - session_service = SessionService(db) - session_service.create(db=db, schema=session_schema, user_id=user.id) + expires_at=expires + ) # Send email in the background background_tasks.add_task( @@ -257,15 +255,18 @@ def login(request: Request, login_request: LoginRequest, background_tasks: Backg # Generate access and refresh tokens access_token = user_service.create_access_token(user_id=user.id) refresh_token = user_service.create_refresh_token(user_id=user.id) + + # create session for user expires = dt.datetime.now(dt.timezone.utc) + (dt.timedelta( days=settings.JWT_REFRESH_EXPIRY) - dt.timedelta(seconds=1) ) - session_schema: SessionCreate = get_session_schema_data( - request, + background_tasks.add_task( + create_session_for_user, + request=request, + user_id=user.id, refresh_token=refresh_token, - expires_at=expires) - session_service = SessionService(db) - session_service.create(db=db, schema=session_schema, user_id=user.id) + expires_at=expires + ) # Background task for email notification logger.info(f"Queueing login notification for {user.email} in the background...") diff --git a/api/v1/services/session.py b/api/v1/services/session.py index d88e2ccb5..dd7b7fa3e 100644 --- a/api/v1/services/session.py +++ b/api/v1/services/session.py @@ -10,6 +10,7 @@ class SessionService: """Session service functionality.""" def __init__(self, db: Session): + """Initialize the service.""" self.db = db def is_revoked_or_expired(self, refresh_token: str): @@ -19,6 +20,8 @@ def is_revoked_or_expired(self, refresh_token: str): return True if isinstance(session.expires_at, str): session.expires_at = datetime.fromisoformat(session.expires_at) + + session.expires_at = session.expires_at.astimezone(timezone.utc) current_time = datetime.now(timezone.utc) if session.is_revoked or (session.expires_at < current_time): return True @@ -46,16 +49,15 @@ def fetch_by_ip_and_user_agent(self, ip_address: str, user_agent: str): ).all() return sessions - def create(self, db: Session, schema: SessionCreate, user_id: str): + def create(self, schema: SessionCreate, user_id: str): """Create a new session.""" sessions = self.fetch_by_ip_and_user_agent(schema.ip_address, schema.device) if sessions: - print(sessions) self.revoke_sessions(sessions) new_session = UserSession(**schema.model_dump(), user_id=user_id) - db.add(new_session) - db.commit() - db.refresh(new_session) + self.db.add(new_session) + self.db.commit() + self.db.refresh(new_session) return new_session def fetch_all(self, user_id): From 705e228375cd06b444b1839e73dc6e62c163adbc Mon Sep 17 00:00:00 2001 From: tha_orakkle Date: Mon, 3 Mar 2025 01:08:56 +0100 Subject: [PATCH 18/29] feat: delete user session when user logout --- api/v1/routes/auth.py | 8 ++++++-- api/v1/services/session.py | 10 ++++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/api/v1/routes/auth.py b/api/v1/routes/auth.py index 51979565a..d2965110f 100644 --- a/api/v1/routes/auth.py +++ b/api/v1/routes/auth.py @@ -34,7 +34,6 @@ UserData2, ) from api.v1.schemas.token import TokenRequest -# from api.v1.schemas.session import SessionCreate from api.v1.schemas.user import (MagicLinkRequest, ChangePasswordSchema, AuthMeResponse) @@ -53,7 +52,7 @@ ) from api.v1.services.totp import totp_service from api.utils.settings import settings -# from api.v1.services.session import SessionService +from api.v1.services.session import SessionService auth = APIRouter(prefix="/auth", tags=["Authentication"]) @@ -307,6 +306,11 @@ def logout( ): """Endpoint to log a user out of their account""" + # logout/delete current user session + current_refresh_token = request.cookies.get("refresh_token") + SessionService.logout_session(db, current_user.id, current_refresh_token) + + response = success_response(status_code=200, message="User logged put successfully") # Delete refresh token from cookies diff --git a/api/v1/services/session.py b/api/v1/services/session.py index dd7b7fa3e..fbea846d6 100644 --- a/api/v1/services/session.py +++ b/api/v1/services/session.py @@ -13,6 +13,16 @@ def __init__(self, db: Session): """Initialize the service.""" self.db = db + @staticmethod + def logout_session(db: Session, user_id: str, refresh_token: str): + """Logout a session.""" + session = db.query(UserSession).filter( + UserSession.refresh_token == refresh_token, UserSession.user_id == user_id).first() + if not session: + return + db.delete(session) + db.commit() + def is_revoked_or_expired(self, refresh_token: str): """Check if a session (refresh token) is revoked.""" session = self.db.query(UserSession).filter(UserSession.refresh_token == refresh_token).first() From d8f4f0f3a68631d89959e53b05716fece6d769c2 Mon Sep 17 00:00:00 2001 From: tha_orakkle Date: Mon, 3 Mar 2025 01:37:06 +0100 Subject: [PATCH 19/29] update: revoke sessions when delete and delete_all methods are called --- api/v1/services/session.py | 26 +++++--------------------- 1 file changed, 5 insertions(+), 21 deletions(-) diff --git a/api/v1/services/session.py b/api/v1/services/session.py index fbea846d6..d758a9722 100644 --- a/api/v1/services/session.py +++ b/api/v1/services/session.py @@ -47,7 +47,7 @@ def revoke_sessions(self, sessions): except Exception as e: self.db.rollback() raise HTTPException( - status_code=400, detail="Could not update session" + status_code=400, detail="Could not revoke session(s)" ) def fetch_by_ip_and_user_agent(self, ip_address: str, user_agent: str): @@ -89,25 +89,9 @@ def delete(self, user_id, session_id): raise HTTPException( status_code=404, detail="Session not found" ) - try: - self.db.delete(session) - self.db.commit() - except Exception as e: - self.db.rollback() - raise HTTPException( - status_code=500, detail="Could not delete session" - ) - + self.revoke_sessions([session]) + def delete_all(self, user_id): - """Delete all sessions associated to a user""" + """Revoke all sessions associated to a user""" sessions = self.fetch_all(user_id) - try: - for session in sessions: - self.db.delete(session) - self.db.commit() - except Exception as e: - self.db.rollback() - raise HTTPException( - status_code=500, - detail="Could not delete sessions" - ) \ No newline at end of file + self.revoke_sessions(sessions) From 2940aa4edc0921bedf036f3a161ecc40d3bfac10 Mon Sep 17 00:00:00 2001 From: tha_orakkle Date: Mon, 3 Mar 2025 01:45:08 +0100 Subject: [PATCH 20/29] update: make celery run task every day by midnight --- api/utils/celery_config.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/api/utils/celery_config.py b/api/utils/celery_config.py index 9fd19b358..f1f3998a0 100644 --- a/api/utils/celery_config.py +++ b/api/utils/celery_config.py @@ -9,12 +9,23 @@ celery_app.conf.beat_schedule = { "clean_db_every_day": { "task": "api.utils.celery_config.clean_expired_and_revoked_tokens_from_sessions_table", - "schedule": crontab(day_of_week=0, hour=0, minute=0)# run every sunday midnight + "schedule": crontab(hour=0, minute=0) # run everyday at midnight } } celery_app.conf.timezone = "UTC" +# Additional configurations +celery_app.conf.update( + task_serializer='json', + result_serializer='json', + accept_content=['json'], + task_acks_late=True, + worker_prefetch_multiplier=1, + task_time_limit=300, + task_soft_time_limit=180, +) + @celery_app.task def clean_expired_and_revoked_tokens_from_sessions_table(): """Clean expired and revoked tokens from the sessions table.""" From 2c79c94c99b586a3c557b7d843745d2eb5085ccc Mon Sep 17 00:00:00 2001 From: tha_orakkle Date: Mon, 3 Mar 2025 07:25:11 +0100 Subject: [PATCH 21/29] update: make session service receive db arg for all its methods --- api/utils/session_helpers.py | 7 ++--- api/v1/routes/auth.py | 6 ++-- api/v1/routes/session.py | 20 +++--------- api/v1/services/session.py | 60 ++++++++++++++++++------------------ api/v1/services/user.py | 5 ++- 5 files changed, 44 insertions(+), 54 deletions(-) diff --git a/api/utils/session_helpers.py b/api/utils/session_helpers.py index ca167d83e..d804c264c 100644 --- a/api/utils/session_helpers.py +++ b/api/utils/session_helpers.py @@ -4,7 +4,7 @@ from api.db.database import get_db from api.v1.schemas.session import SessionCreate -from api.v1.services.session import SessionService +from api.v1.services.session import session_service from api.utils.client_helpers import get_ip_address @@ -51,6 +51,7 @@ async def get_session_schema_data(request: Request, refresh_token: str = "", exp ) async def create_session_for_user( + db: Session, request: Request, user_id: str, refresh_token: str = "", @@ -69,6 +70,4 @@ async def create_session_for_user( refresh_token=refresh_token, expires_at=expires_at, ) - db = next(get_db()) - session_service = SessionService(db) - session_service.create(schema=session_data, user_id=user_id) + session_service.create(db, schema=session_data, user_id=user_id) diff --git a/api/v1/routes/auth.py b/api/v1/routes/auth.py index d2965110f..32b6763db 100644 --- a/api/v1/routes/auth.py +++ b/api/v1/routes/auth.py @@ -52,7 +52,7 @@ ) from api.v1.services.totp import totp_service from api.utils.settings import settings -from api.v1.services.session import SessionService +from api.v1.services.session import session_service auth = APIRouter(prefix="/auth", tags=["Authentication"]) @@ -99,6 +99,7 @@ def register( ) background_tasks.add_task( create_session_for_user, + db=db, request=request, user_id=user.id, refresh_token=refresh_token, @@ -261,6 +262,7 @@ def login(request: Request, login_request: LoginRequest, background_tasks: Backg ) background_tasks.add_task( create_session_for_user, + db=db, request=request, user_id=user.id, refresh_token=refresh_token, @@ -308,7 +310,7 @@ def logout( # logout/delete current user session current_refresh_token = request.cookies.get("refresh_token") - SessionService.logout_session(db, current_user.id, current_refresh_token) + session_service.logout_session(db, current_user.id, current_refresh_token) response = success_response(status_code=200, message="User logged put successfully") diff --git a/api/v1/routes/session.py b/api/v1/routes/session.py index 8ac55f34f..b17b74d0f 100644 --- a/api/v1/routes/session.py +++ b/api/v1/routes/session.py @@ -8,7 +8,7 @@ from api.db.database import get_db from api.utils.success_response import success_response from api.v1.services.user import user_service -from api.v1.services.session import SessionService +from api.v1.services.session import session_service session_router = APIRouter(prefix="/sessions", tags=["sessions"]) @@ -25,8 +25,7 @@ def get_all_sessions( - db: the database session - current_user: current authenticated user """ - session_service = SessionService(db) - sessions = session_service.fetch_all(current_user.id) + sessions = session_service.fetch_all(db, current_user.id) return success_response( status_code=status.HTTP_200_OK, message="Sessions retrieved successfully", @@ -39,11 +38,7 @@ def get_session( db: Session = Depends(get_db), current_user: User = Depends(user_service.get_current_user) ): - session_service = SessionService(db) - session = session_service.fetch( - user_id=current_user.id, - session_id=session_id - ) + session = session_service.fetch(db, current_user.id, session_id) return success_response( status_code=status.HTTP_200_OK, message="Session retrived successfully", @@ -64,16 +59,11 @@ def delete_session( - db: the database session - current_user: current authenticated user """ - session_service = SessionService(db) - return session_service.delete( - user_id=current_user.id, - session_id=session_id - ) + return session_service.delete(db, current_user.id, session_id) @session_router.delete('/', status_code=status.HTTP_204_NO_CONTENT) def delete_all_sessions( db: Session = Depends(get_db), current_user: User = Depends(user_service.get_current_user) ): - session_service = SessionService(db) - return session_service.delete_all(current_user.id) \ No newline at end of file + return session_service.delete_all(db, current_user.id) \ No newline at end of file diff --git a/api/v1/services/session.py b/api/v1/services/session.py index d758a9722..e7fbac0e9 100644 --- a/api/v1/services/session.py +++ b/api/v1/services/session.py @@ -1,20 +1,17 @@ from datetime import datetime, timezone from sqlalchemy.orm import Session +from typing import List from fastapi import HTTPException +from api.db.database import get_db from api.v1.schemas.session import SessionCreate from api.v1.models.session import UserSession class SessionService: """Session service functionality.""" - def __init__(self, db: Session): - """Initialize the service.""" - self.db = db - - @staticmethod - def logout_session(db: Session, user_id: str, refresh_token: str): + def logout_session(self, db: Session, user_id: str, refresh_token: str): """Logout a session.""" session = db.query(UserSession).filter( UserSession.refresh_token == refresh_token, UserSession.user_id == user_id).first() @@ -23,9 +20,9 @@ def logout_session(db: Session, user_id: str, refresh_token: str): db.delete(session) db.commit() - def is_revoked_or_expired(self, refresh_token: str): + def is_revoked_or_expired(self, db: Session, refresh_token: str) -> bool: """Check if a session (refresh token) is revoked.""" - session = self.db.query(UserSession).filter(UserSession.refresh_token == refresh_token).first() + session = db.query(UserSession).filter(UserSession.refresh_token == refresh_token).first() if not session: return True if isinstance(session.expires_at, str): @@ -37,61 +34,64 @@ def is_revoked_or_expired(self, refresh_token: str): return True return False - def revoke_sessions(self, sessions): + def revoke_sessions(self, db: Session, sessions: List[UserSession]): """Revoke sessions associated with IP and user-agent.""" try: for session in sessions: session.is_revoked = True - self.db.commit() - self.db.refresh(session) + db.commit() + db.refresh(session) except Exception as e: - self.db.rollback() + db.rollback() raise HTTPException( status_code=400, detail="Could not revoke session(s)" ) - def fetch_by_ip_and_user_agent(self, ip_address: str, user_agent: str): + def fetch_by_ip_and_user_agent(self, db: Session, ip_address: str, user_agent: str) -> List[UserSession]: """Fetch sessions by IP address and user agent.""" - sessions = self.db.query(UserSession).filter( + sessions = db.query(UserSession).filter( UserSession.ip_address == ip_address, UserSession.device == user_agent, UserSession.is_revoked == False ).all() return sessions - def create(self, schema: SessionCreate, user_id: str): + def create(self, db: Session, schema: SessionCreate, user_id: str) -> UserSession: """Create a new session.""" - sessions = self.fetch_by_ip_and_user_agent(schema.ip_address, schema.device) + sessions = self.fetch_by_ip_and_user_agent(db, schema.ip_address, schema.device) if sessions: - self.revoke_sessions(sessions) + self.revoke_sessions(db, sessions) new_session = UserSession(**schema.model_dump(), user_id=user_id) - self.db.add(new_session) - self.db.commit() - self.db.refresh(new_session) + db.add(new_session) + db.commit() + db.refresh(new_session) return new_session - def fetch_all(self, user_id): + def fetch_all(self, db: Session, user_id: str) -> List[UserSession]: """Fetch all active sessions.""" - sessions = self.db.query(UserSession).filter(UserSession.is_revoked == False, UserSession.user_id == user_id).all() + sessions = db.query(UserSession).filter(UserSession.is_revoked == False, UserSession.user_id == user_id).all() return sessions - def fetch(self, user_id, session_id): + def fetch(self, db: Session, user_id: str, session_id: str) -> UserSession: """Fetch a session by its ID.""" - session = self.db.query(UserSession).filter(UserSession.id == session_id, UserSession.user_id == user_id).first() + session = db.query(UserSession).filter(UserSession.id == session_id, UserSession.user_id == user_id).first() if not session: raise HTTPException(status_code=404, detail="Session not found") return session - def delete(self, user_id, session_id): + def delete(self, db: Session, user_id: str, session_id: str): """Delete a session.""" - session = self.fetch(user_id, session_id) + session = self.fetch(db, user_id, session_id) if not session: raise HTTPException( status_code=404, detail="Session not found" ) - self.revoke_sessions([session]) + self.revoke_sessions(db, [session]) - def delete_all(self, user_id): + def delete_all(self, db: Session, user_id: str): """Revoke all sessions associated to a user""" - sessions = self.fetch_all(user_id) - self.revoke_sessions(sessions) + sessions = self.fetch_all(db, user_id) + self.revoke_sessions(db, sessions) + + +session_service = SessionService() \ No newline at end of file diff --git a/api/v1/services/user.py b/api/v1/services/user.py index c1bbb0058..b8efbf382 100644 --- a/api/v1/services/user.py +++ b/api/v1/services/user.py @@ -25,7 +25,7 @@ from api.v1.schemas import token from api.v1.services.notification_settings import notification_setting_service from api.v1.services.newsletter import NewsletterService, EmailSchema -from api.v1.services.session import SessionService +from api.v1.services.session import session_service oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login") @@ -419,8 +419,7 @@ def verify_refresh_token(self, refresh_token: str, credentials_exception): try: db: Session = next(get_db()) - session_service = SessionService(db) - is_revoked = session_service.is_revoked_or_expired(refresh_token) + is_revoked = session_service.is_revoked_or_expired(db, refresh_token) if is_revoked: raise credentials_exception payload = jwt.decode( From a285ff83a91d338910865a7ee3f0105a0467ae82 Mon Sep 17 00:00:00 2001 From: tha_orakkle Date: Mon, 3 Mar 2025 13:07:48 +0100 Subject: [PATCH 22/29] update: add content to delete endpoints --- api/v1/routes/session.py | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/api/v1/routes/session.py b/api/v1/routes/session.py index b17b74d0f..645cef787 100644 --- a/api/v1/routes/session.py +++ b/api/v1/routes/session.py @@ -41,7 +41,7 @@ def get_session( session = session_service.fetch(db, current_user.id, session_id) return success_response( status_code=status.HTTP_200_OK, - message="Session retrived successfully", + message="Session retrieved successfully", data=jsonable_encoder(session, exclude={"refresh_token"}) ) @@ -59,11 +59,31 @@ def delete_session( - db: the database session - current_user: current authenticated user """ - return session_service.delete(db, current_user.id, session_id) + session_service.delete(db, current_user.id, session_id) + response_data = { + "status": "success", + "status_code": 204, + "message": "Session deleted successfully", + "data": {} + } + return JSONResponse( + status_code=status.HTTP_204_NO_CONTENT, + content=jsonable_encoder(response_data) + ) @session_router.delete('/', status_code=status.HTTP_204_NO_CONTENT) def delete_all_sessions( db: Session = Depends(get_db), current_user: User = Depends(user_service.get_current_user) ): - return session_service.delete_all(db, current_user.id) \ No newline at end of file + session_service.delete_all(db, current_user.id) + response_data = { + "status": "success", + "status_code": 204, + "message": "Sessions deleted successfully", + "data": {} + } + return JSONResponse( + status_code=status.HTTP_204_NO_CONTENT, + content=jsonable_encoder(response_data) + ) \ No newline at end of file From 53177de02c8bcd1c92d0a71cf55c276983d3a713 Mon Sep 17 00:00:00 2001 From: tha_orakkle Date: Mon, 3 Mar 2025 13:09:15 +0100 Subject: [PATCH 23/29] update: add session creation functionality for signin test --- tests/v1/auth/test_signin.py | 50 ++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/tests/v1/auth/test_signin.py b/tests/v1/auth/test_signin.py index e940a79e7..c1f41a277 100644 --- a/tests/v1/auth/test_signin.py +++ b/tests/v1/auth/test_signin.py @@ -1,9 +1,12 @@ import pytest +from datetime import timedelta from fastapi.testclient import TestClient from unittest.mock import MagicMock from main import app from api.v1.models.user import User +from api.v1.models.session import UserSession from api.v1.services.user import user_service +from api.v1.services.session import session_service from api.v1.services.totp import totp_service from uuid_extensions import uuid7 from api.db.database import get_db @@ -31,6 +34,19 @@ def setup(self): created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), ) + self.user_session = UserSession( + id=str(uuid7()), + user_id=self.mock_user.id, + expires_at=datetime.now(timezone.utc) + timedelta(days=1), + ip_address="192.168.1.1", # Mock IP address + location="Lagos, Nigeria", # Mock location + device="test-client", # Mock device + refresh_token=user_service.create_refresh_token(self.mock_user.id), + is_revoked=False, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + self.mock_totp_device = TOTPDevice( user_id=self.mock_user.id, secret=pyotp.random_base32(), @@ -56,6 +72,12 @@ def test_user_login_success_without_2FA(self, monkeypatch): lambda user, db: [] ) + monkeypatch.setattr( + session_service, + "create", + lambda db, schema, user_id: self.user_session + ) + response = self.client.post( "/api/v1/auth/login", json={"email": "testuser1@gmail.com", "password": "Testpassword@123"}, @@ -84,6 +106,11 @@ def test_user_login_success_with_2FA(self, monkeypatch): "api.v1.services.organisation.organisation_service.retrieve_user_organizations", lambda user, db: [] ) + monkeypatch.setattr( + session_service, + "create", + lambda db, schema, user_id: self.user_session + ) response = self.client.post( "/api/v1/auth/login", @@ -116,6 +143,11 @@ def test_user_login_success_with_2FA_disabled(self, monkeypatch): "api.v1.services.organisation.organisation_service.retrieve_user_organizations", lambda user, db: [] ) + monkeypatch.setattr( + session_service, + "create", + lambda db, schema, user_id: self.user_session + ) response = self.client.post( "/api/v1/auth/login", json={"email": "testuser1@gmail.com", "password": "Testpassword@123"}, @@ -214,6 +246,11 @@ def test_inactive_user_login(self, monkeypatch): "api.v1.services.organisation.organisation_service.retrieve_user_organizations", lambda user, db: [] ) + monkeypatch.setattr( + session_service, + "create", + lambda db, schema, user_id: self.user_session + ) response = self.client.post( "/api/v1/auth/login", @@ -279,7 +316,20 @@ def test_user_login(db_session_mock): created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc) ) + user_session = UserSession( + id=str(uuid7()), + user_id=mock_user.id, + expires_at=datetime.now(timezone.utc) + timedelta(days=1), + ip_address="192.168.1.1", # Mock IP address + location="Lagos, Nigeria", # Mock location + device="test-client", # Mock device + refresh_token=user_service.create_refresh_token(mock_user.id), + is_revoked=False, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) db_session_mock.query.return_value.filter.return_value.first.return_value = mock_user + db_session_mock.query.return_value.filter.return_value.first.return_value = user_session # Login with mock user details login = client.post("/api/v1/auth/login", json={ From c30be7146383fcb82848435ab515e678f65a60d8 Mon Sep 17 00:00:00 2001 From: tha_orakkle Date: Mon, 3 Mar 2025 13:10:57 +0100 Subject: [PATCH 24/29] test: add test for user sessions get endpoints --- tests/v1/session/test_get_session.py | 105 +++++++++++++++++++++++++++ 1 file changed, 105 insertions(+) create mode 100644 tests/v1/session/test_get_session.py diff --git a/tests/v1/session/test_get_session.py b/tests/v1/session/test_get_session.py new file mode 100644 index 000000000..697217e7d --- /dev/null +++ b/tests/v1/session/test_get_session.py @@ -0,0 +1,105 @@ +import time +import pytest +from datetime import datetime, timezone, timedelta +from fastapi.testclient import TestClient +from unittest.mock import MagicMock, patch +from uuid_extensions import uuid7 + +from main import app +from api.db.database import get_db +from api.v1.models.user import User +from api.v1.models.session import UserSession +from api.v1.services.user import user_service + +client = TestClient(app) + +@pytest.fixture +def mock_db(): + """Mock database object.""" + mock_db = MagicMock() + yield mock_db + +@pytest.fixture(autouse=True) +def client(mock_db): + """Override the get_db dependency with the mock.""" + def get_db_override(): + yield mock_db + app.dependency_overrides[get_db] = get_db_override + client = TestClient(app) + yield client + +@pytest.fixture(autouse=True) +def mock_user(): + """Mock user object.""" + return User( + id=str(uuid7()), + email="testuser1@gmail.com", + password=user_service.hash_password("Testpassword@123"), + first_name="Test", + last_name="User", + is_active=True, + is_superadmin=False, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc) + ) + +@pytest.fixture(autouse=True) +def mock_session_1(mock_user): + """Mock session object.""" + return UserSession( + id=str(uuid7()), + user_id=mock_user.id, + expires_at=datetime.now(timezone.utc) + timedelta(days=1), + ip_address="127.0.0.1", + location="Lagos, Nigeria", + device="test-client", + refresh_token=user_service.create_refresh_token(mock_user.id), + is_revoked=False, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + +@pytest.fixture(autouse=True) +def mock_session_2(mock_user): + time.sleep(1) + """Mock session object.""" + return UserSession( + id=str(uuid7()), + user_id=mock_user.id, + expires_at=datetime.now(timezone.utc) + timedelta(days=1), + ip_address="127.0.1.1", + location="Lagos, Nigeria", + device="test-client-2", + refresh_token=user_service.create_refresh_token(mock_user.id), + is_revoked=False, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + + +def test_get_session(client, mock_db, mock_user, mock_session_1): + """Test getting a session.""" + app.dependency_overrides[user_service.get_current_user] = lambda: mock_user + mock_db.query().filter().first.return_value = mock_session_1 + response = client.get( + f"/api/v1/sessions/{mock_session_1.id}", + headers={"Authorization": "Beare token"} + ) + + assert response.status_code == 200 + assert response.json()['message'] == "Session retrieved successfully" + assert response.json()['status'] == "success" + +def test_get_sessions(client , mock_db, mock_user, mock_session_1, mock_session_2): + """Test getting all sessions.""" + app.dependency_overrides[user_service.get_current_user] = lambda: mock_user + mock_db.query().filter().all.return_value = [mock_session_1, mock_session_2] + response = client.get( + "/api/v1/sessions/", + headers={"Authorization": "Beare token"} + ) + print(f"Response status code: {response.status_code}") + print(f"Response JSON: {response.json()}") + assert response.status_code == 200 + assert response.json()['message'] == "Sessions retrieved successfully" + assert response.json()['status'] == "success" \ No newline at end of file From 928c6bcb86392f9fc8c54ddae21d3c3037a56b15 Mon Sep 17 00:00:00 2001 From: tha_orakkle Date: Mon, 3 Mar 2025 13:11:30 +0100 Subject: [PATCH 25/29] test: add test for user sessions delete endpoints --- tests/v1/session/test_delete_session.py | 102 ++++++++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 tests/v1/session/test_delete_session.py diff --git a/tests/v1/session/test_delete_session.py b/tests/v1/session/test_delete_session.py new file mode 100644 index 000000000..e5dab6c18 --- /dev/null +++ b/tests/v1/session/test_delete_session.py @@ -0,0 +1,102 @@ +import pytest +import time +from datetime import datetime, timezone, timedelta +from fastapi.testclient import TestClient +from unittest.mock import MagicMock +from uuid_extensions import uuid7 + +from main import app +from api.db.database import get_db +from api.v1.models.user import User +from api.v1.models.session import UserSession +from api.v1.services.user import user_service +from api.v1.services.session import session_service + + + +@pytest.fixture +def mock_db(): + mock_db = MagicMock() + yield mock_db + + +@pytest.fixture(autouse=True) +def client(mock_db): + """Override the get_db dependency with the mock.""" + def get_db_override(): + yield mock_db + app.dependency_overrides[get_db] = get_db_override + client = TestClient(app) + yield client + +@pytest.fixture(autouse=True) +def mock_user(): + """Mock user object.""" + return User( + id=str(uuid7()), + email="testuser1@gmail.com", + password=user_service.hash_password("Testpassword@123"), + first_name="Test", + last_name="User", + is_active=True, + is_superadmin=False, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc) + ) + +@pytest.fixture(autouse=True) +def mock_session_1(mock_user): + """Mock session object.""" + return UserSession( + id=str(uuid7()), + user_id=mock_user.id, + expires_at=datetime.now(timezone.utc) + timedelta(days=1), + ip_address="127.0.0.1", + location="Lagos, Nigeria", + device="test-client", + refresh_token=user_service.create_refresh_token(mock_user.id), + is_revoked=False, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + +@pytest.fixture(autouse=True) +def mock_session_2(mock_user): + time.sleep(1) + """Mock session object.""" + return UserSession( + id=str(uuid7()), + user_id=mock_user.id, + expires_at=datetime.now(timezone.utc) + timedelta(days=1), + ip_address="127.0.1.1", + location="Lagos, Nigeria", + device="test-client-2", + refresh_token=user_service.create_refresh_token(mock_user.id), + is_revoked=False, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + +def test_delete_sessions(client, mock_db, mock_user, mock_session_1, mock_session_2): + """Test deleting all sessions.""" + mock_db.query().filter().all.return_value = [mock_session_1, mock_session_2] + app.dependency_overrides[user_service.get_current_user] = lambda: mock_user + + response = client.delete(f"/api/v1/sessions/", headers={"Authorization": "Bearer token"}) + + assert response.status_code == 204 + assert response.json()['message'] == "Sessions deleted successfully" + assert response.json()['status'] == "success" + +def test_delete_session(client, mock_db, mock_user, mock_session_1): + """Test deleting a session.""" + mock_db.query().filter().first.return_value = mock_session_1 + app.dependency_overrides[user_service.get_current_user] = lambda: mock_user + + response = client.delete( + f"/api/v1/sessions/{mock_session_1.id}", + headers={"Authorization": "Bearer token"}) + + assert response.status_code == 204 + assert response.json()['message'] == "Session deleted successfully" + assert response.json()['status'] == "success" \ No newline at end of file From fe917274b62eda1cc668d6aba062dadce3528867 Mon Sep 17 00:00:00 2001 From: tha_orakkle Date: Mon, 3 Mar 2025 13:12:46 +0100 Subject: [PATCH 26/29] test: test that different refresh token is generated and first is revoked --- .../v1/session/test_user_session_creation.py | 100 ++++++++++++++++++ 1 file changed, 100 insertions(+) create mode 100644 tests/v1/session/test_user_session_creation.py diff --git a/tests/v1/session/test_user_session_creation.py b/tests/v1/session/test_user_session_creation.py new file mode 100644 index 000000000..24bd26a7d --- /dev/null +++ b/tests/v1/session/test_user_session_creation.py @@ -0,0 +1,100 @@ +import time +import pytest +from fastapi.testclient import TestClient +from unittest.mock import MagicMock, patch +from uuid_extensions import uuid7 +from datetime import datetime, timezone, timedelta + +from main import app + +from api.v1.services.totp import totp_service +from api.v1.services.session import get_db +from api.v1.models.user import User +from api.v1.services.user import user_service +from api.v1.models.session import UserSession +from api.v1.services.session import session_service + + +class TestUserSessionCreation: + @pytest.fixture(autouse=True) + def setup(self): + self.mock_db = MagicMock() + app.dependency_overrides[get_db] = lambda: self.mock_db + self.client = TestClient(app) + + # Mock user creation to return a valid user object + self.mock_user = User( + id=str(uuid7()), + email="testuser1@gmail.com", + password=user_service.hash_password("Testpassword@123"), + first_name="Test", + last_name="User", + is_active=True, + is_superadmin=False, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc) + ) + + # Mock session creation + self.mock_user_session = UserSession( + id=str(uuid7()), + user_id=self.mock_user.id, + expires_at=datetime.now(timezone.utc) + timedelta(days=1), + ip_address="192.168.1.1", + location="Lagos, Nigeria", + device="test-client", + refresh_token=user_service.create_refresh_token(self.mock_user.id), + is_revoked=False, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + time.sleep(1) + + def test_previous_token_is_revoked_when_login_from_same_ip_and_user_agent(self, monkeypatch): + """Test that previous token is revoked when login from same IP and user agent.""" + new_mock_session = UserSession( + id=str(uuid7()), + user_id=self.mock_user.id, + expires_at=datetime.now(timezone.utc) + timedelta(days=1), + ip_address="192.168.1.1", + location="Lagos, Nigeria", + device="test-client", + refresh_token=user_service.create_refresh_token(self.mock_user.id), + is_revoked=False, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + monkeypatch.setattr( + user_service, + "authenticate_user", + lambda db, email, password: self.mock_user + ) + monkeypatch.setattr( + "api.v1.services.organisation.organisation_service.retrieve_user_organizations", + lambda user, db: [] + ) + + monkeypatch.setattr( + totp_service, + "check_2fa_status_and_verify", + lambda db, user_id, schema: True + ) + def mock_create_session(db, schema, user_id): + self.mock_user_session.is_revoked = True + return new_mock_session + + monkeypatch.setattr( + session_service, + "create", + mock_create_session + ) + + response = self.client.post( + "/api/v1/auth/login", + json={"email": "testuser1@gmail.com", "password": "Testpassword@123"}, + ) + + assert response.status_code == 200 + assert self.mock_user_session.refresh_token != new_mock_session.refresh_token + assert self.mock_user_session.is_revoked is True + From ef66ce279ca02c7a1f9df63d6030111aa52f3f84 Mon Sep 17 00:00:00 2001 From: tha_orakkle Date: Mon, 3 Mar 2025 13:27:10 +0100 Subject: [PATCH 27/29] test: add session creation to test_signin --- tests/v1/auth/test_signin.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/tests/v1/auth/test_signin.py b/tests/v1/auth/test_signin.py index c1f41a277..9b474fa7b 100644 --- a/tests/v1/auth/test_signin.py +++ b/tests/v1/auth/test_signin.py @@ -34,7 +34,7 @@ def setup(self): created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), ) - self.user_session = UserSession( + self.mock_user_session = UserSession( id=str(uuid7()), user_id=self.mock_user.id, expires_at=datetime.now(timezone.utc) + timedelta(days=1), @@ -75,7 +75,7 @@ def test_user_login_success_without_2FA(self, monkeypatch): monkeypatch.setattr( session_service, "create", - lambda db, schema, user_id: self.user_session + lambda db, schema, user_id: self.mock_user_session ) response = self.client.post( @@ -109,7 +109,7 @@ def test_user_login_success_with_2FA(self, monkeypatch): monkeypatch.setattr( session_service, "create", - lambda db, schema, user_id: self.user_session + lambda db, schema, user_id: self.mock_user_session ) response = self.client.post( @@ -146,7 +146,7 @@ def test_user_login_success_with_2FA_disabled(self, monkeypatch): monkeypatch.setattr( session_service, "create", - lambda db, schema, user_id: self.user_session + lambda db, schema, user_id: self.mock_user_session ) response = self.client.post( "/api/v1/auth/login", @@ -180,6 +180,11 @@ def mock_check_2fa_status_and_verify(): "check_2fa_status_and_verify", lambda db, user_id, schema: mock_check_2fa_status_and_verify() ) + monkeypatch.setattr( + session_service, + "create", + lambda db, schema, user_id: self.mock_user_session + ) response = self.client.post( "/api/v1/auth/login", @@ -249,7 +254,7 @@ def test_inactive_user_login(self, monkeypatch): monkeypatch.setattr( session_service, "create", - lambda db, schema, user_id: self.user_session + lambda db, schema, user_id: self.mock_user_session ) response = self.client.post( @@ -316,7 +321,7 @@ def test_user_login(db_session_mock): created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc) ) - user_session = UserSession( + mock_user_session = UserSession( id=str(uuid7()), user_id=mock_user.id, expires_at=datetime.now(timezone.utc) + timedelta(days=1), @@ -329,7 +334,9 @@ def test_user_login(db_session_mock): updated_at=datetime.now(timezone.utc), ) db_session_mock.query.return_value.filter.return_value.first.return_value = mock_user - db_session_mock.query.return_value.filter.return_value.first.return_value = user_session + + # Mock the SessionService.create method + session_service.create = MagicMock(return_value=mock_user_session) # Login with mock user details login = client.post("/api/v1/auth/login", json={ From 1d8e3f2422a94feadf8c8f2b0b6c5b8fac0b23fe Mon Sep 17 00:00:00 2001 From: tha_orakkle Date: Mon, 3 Mar 2025 14:08:16 +0100 Subject: [PATCH 28/29] test: add user session creation to test_upload_profile_image --- tests/v1/profile/test_upload_profile_image.py | 35 +++++++++++++++++-- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/tests/v1/profile/test_upload_profile_image.py b/tests/v1/profile/test_upload_profile_image.py index 96ecaecd9..636d45d91 100644 --- a/tests/v1/profile/test_upload_profile_image.py +++ b/tests/v1/profile/test_upload_profile_image.py @@ -4,11 +4,13 @@ from main import app from api.v1.models.user import User from api.v1.models.profile import Profile +from api.v1.models.session import UserSession from api.v1.services.user import user_service +from api.v1.services.session import session_service from uuid_extensions import uuid7 from api.db.database import get_db from fastapi import status -from datetime import datetime, timezone +from datetime import datetime, timezone, timedelta client = TestClient(app) @@ -73,7 +75,21 @@ def create_mock_user_profile(mock_user_service, mock_db_session): @pytest.mark.usefixtures("mock_db_session", "mock_user_service") def test_errors(mock_user_service, mock_db_session): """Test for errors in profile creation""" - create_mock_user(mock_user_service, mock_db_session) + mock_user = create_mock_user(mock_user_service, mock_db_session) + mock_user_session = UserSession( + id=str(uuid7()), + user_id=mock_user.id, + expires_at=datetime.now(timezone.utc) + timedelta(days=1), + ip_address="192.168.1.1", # Mock IP address + location="Lagos, Nigeria", # Mock location + device="test-client", # Mock device + refresh_token=user_service.create_refresh_token(mock_user.id), + is_revoked=False, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + session_service.create = MagicMock(return_value=mock_user_session) + login = client.post(LOGIN_ENDPOINT, json={ "email": "testuser@gmail.com", "password": "Testpassword@123" @@ -97,7 +113,20 @@ def test_errors(mock_user_service, mock_db_session): @pytest.mark.usefixtures("mock_db_session", "mock_user_service") def test_user_profile_upload(mock_user_service, mock_db_session): - create_mock_user(mock_user_service, mock_db_session) + mock_user = create_mock_user(mock_user_service, mock_db_session) + mock_user_session = UserSession( + id=str(uuid7()), + user_id=mock_user.id, + expires_at=datetime.now(timezone.utc) + timedelta(days=1), + ip_address="192.168.1.1", # Mock IP address + location="Lagos, Nigeria", # Mock location + device="test-client", # Mock device + refresh_token=user_service.create_refresh_token(mock_user.id), + is_revoked=False, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + session_service.create = MagicMock(return_value=mock_user_session) login = client.post(LOGIN_ENDPOINT, json={ "email": "testuser@gmail.com", "password": "Testpassword@123" From ef916aa75e4bab35426a76c6d077c143cd8ba889 Mon Sep 17 00:00:00 2001 From: tha_orakkle Date: Mon, 3 Mar 2025 14:10:42 +0100 Subject: [PATCH 29/29] fix: fix no fixture session found error --- tests/v1/social_auth/test_facebook_auth.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/v1/social_auth/test_facebook_auth.py b/tests/v1/social_auth/test_facebook_auth.py index cc5a1babe..c57044750 100644 --- a/tests/v1/social_auth/test_facebook_auth.py +++ b/tests/v1/social_auth/test_facebook_auth.py @@ -1,6 +1,6 @@ import pytest from fastapi.testclient import TestClient -from tests.database import session, client +from tests.database import client from api.v1.models import * from api.db.database import get_db from main import app @@ -10,6 +10,11 @@ INVALID_ACCESS_TOKEN = "invalid_token" +@pytest.fixture() +def session(): + from tests.database import session as test_session + return test_session + class MockResponse: """This class will be used to mock the response of the Facebook API.""" @@ -20,7 +25,6 @@ def __init__(self, content, status_code): def json(self): return self.content - class MockTestClient(TestClient): """This class will be used to mock the client of the Facebook API."""