diff --git a/src/teuthology_api/routes/login.py b/src/teuthology_api/routes/login.py index c4916a2..3deafc7 100644 --- a/src/teuthology_api/routes/login.py +++ b/src/teuthology_api/routes/login.py @@ -1,17 +1,14 @@ import logging import os -from fastapi import APIRouter, HTTPException, Request +from fastapi import APIRouter, HTTPException, Request, Depends from fastapi.responses import RedirectResponse from dotenv import load_dotenv -import httpx +from teuthology_api.services.auth import get_github_auth_service, AuthService load_dotenv() GH_CLIENT_ID = os.getenv("GH_CLIENT_ID") -GH_CLIENT_SECRET = os.getenv("GH_CLIENT_SECRET") GH_AUTHORIZATION_BASE_URL = os.getenv("GH_AUTHORIZATION_BASE_URL") -GH_TOKEN_URL = os.getenv("GH_TOKEN_URL") -GH_FETCH_MEMBERSHIP_URL = os.getenv("GH_FETCH_MEMBERSHIP_URL") PULPITO_URL = os.getenv("PULPITO_URL") log = logging.getLogger(__name__) @@ -38,56 +35,28 @@ async def github_login(): @router.get("/callback", status_code=200) -async def handle_callback(code: str, request: Request): +async def handle_callback(code: str, request: Request, auth_service: AuthService = Depends(get_github_auth_service)): """ Call back route after user login & authorize the app for access. """ - params = { - "client_id": GH_CLIENT_ID, - "client_secret": GH_CLIENT_SECRET, - "code": code, - } - headers = {"Accept": "application/json"} - async with httpx.AsyncClient() as client: - response_token = await client.post( - url=GH_TOKEN_URL, params=params, headers=headers - ) - log.info(response_token.json()) - response_token_dic = dict(response_token.json()) - token = response_token_dic.get("access_token") - if response_token_dic.get("error") or not token: - log.error("The code is incorrect or expired.") - raise HTTPException( - status_code=401, detail="The code is incorrect or expired." - ) - headers = {"Authorization": "token " + token} - response_org = await client.get(url=GH_FETCH_MEMBERSHIP_URL, headers=headers) - log.info(response_org.json()) - if response_org.status_code == 404: - log.error("User is not part of the Ceph Organization") - raise HTTPException( - status_code=404, - detail="User is not part of the Ceph Organization, please contact ", - ) - if response_org.status_code == 403: - log.error("The application doesn't have permission to view github org") - raise HTTPException( - status_code=403, - detail="The application doesn't have permission to view github org", - ) - response_org_dic = dict(response_org.json()) - data = { - "id": response_org_dic.get("user", {}).get("id"), - "username": response_org_dic.get("user", {}).get("login"), - "state": response_org_dic.get("state"), - "role": response_org_dic.get("role"), + token = await auth_service._get_token(status_code=code) + + response_org_dict = await auth_service._get_org(token=token) + + data = { + "id": response_org_dict.get("user", {}).get("id"), + "username": response_org_dict.get("user", {}).get("login"), + "state": response_org_dict.get("state"), + "role": response_org_dict.get("role"), "access_token": token, } - request.session["user"] = data + + request.session["user"] = data + cookie_data = { "username": data["username"], - "avatar_url": response_org_dic.get("user", {}).get("avatar_url"), + "avatar_url": response_org_dict.get("user", {}).get("avatar_url"), } cookie = "; ".join( [f"{str(key)}={str(value)}" for key, value in cookie_data.items()] diff --git a/src/teuthology_api/services/auth.py b/src/teuthology_api/services/auth.py new file mode 100644 index 0000000..5cd67de --- /dev/null +++ b/src/teuthology_api/services/auth.py @@ -0,0 +1,114 @@ +import abc +import os +import httpx +import logging +from dotenv import load_dotenv +from fastapi import HTTPException + +load_dotenv() +log = logging.getLogger(__name__) + +class AuthService(abc.ABC): + @abc.abstractmethod + async def _get_token(self, status_code: int) -> dict: + """Returns a dict of response JSON from GH.""" + pass + + @abc.abstractmethod + async def _get_org(self, token: str) -> dict: + """Returns org info of user.""" + pass + +class AuthServiceGH(AuthService): + + def __init__(self): + self.GH_CLIENT_ID = os.getenv("GH_CLIENT_ID") + self.GH_CLIENT_SECRET = os.getenv("GH_CLIENT_SECRET") + self.GH_AUTHORIZATION_BASE_URL = os.getenv("GH_AUTHORIZATION_BASE_URL") + self.GH_TOKEN_URL = os.getenv("GH_TOKEN_URL") + self.GH_FETCH_MEMBERSHIP_URL = os.getenv("GH_FETCH_MEMBERSHIP_URL") + self.PULPITO_URL = os.getenv("PULPITO_URL") + + async def _get_token(self, status_code: int) -> str: + params = { + "client_id": self.GH_CLIENT_ID, + "client_secret": self.GH_CLIENT_SECRET, + "code": status_code, + } + headers = {"Accept": "application/json"} + async with httpx.AsyncClient as client: + response_token = await client.post( + url=self.GH_TOKEN_URL, params=params, headers=headers + ) + log.info(response_token.json()) + response_token_dict = dict(response_token.json()) + token = response_token_dict.get("access_token") + if response_token_dict.get("error") or not token: + log.error("The code is incorrect or expired.") + raise HTTPException( + status_code=401, detail="The code is incorrect or expired." + ) + return token + + async def _get_org(self, token: str) -> dict: + headers = {"Authorization": "token " + token} + async with httpx.AsyncClient as client: + response_org = await client.get(url=self.GH_FETCH_MEMBERSHIP_URL, headers=headers) + response_org_dict = dict(response_org.json()) + log.info(response_org) + if response_org.status_code == 404: + log.error("User is not part of the Ceph Organization") + raise HTTPException( + status_code=404, + detail="User is not part of the Ceph Organization, please contact .", + ) + if response_org.status_code == 403: + log.error("The application doesn't have permission to view github org.") + raise HTTPException( + status_code=403, + detail="The application doesn't have permission to view github org.", + ) + return response_org_dict + +class AuthServiceMock(AuthService): + async def _get_token(self, status_code: int) -> dict: + if status_code == 200: + return "admin" + elif status_code == 403: + return "user" + elif status_code == 404: + return "" + elif status_code == 500: + raise HTTPException( + status_code=401, detail="The code is incorrect or expired." + ) + else: + return "" + + async def _get_org(self, token: str) -> dict: + if token == "admin": + return { + "id": "admin_id", + "username": "admin", + "state": "state", + "role": "admin" + } + elif token == "": + log.error("The application doesn't have permission to view github org.") + raise HTTPException( + status_code=403, + detail="The application doesn't have permission to view github org.", + ) + else: + log.error("User is not part of the Ceph Organization") + raise HTTPException( + status_code=404, + detail="User is not part of the Ceph Organization, please contact .", + ) + + +def get_github_auth_service(): + return AuthServiceGH() + +def get_mock_auth_service(): + return AuthServiceMock()