Skip to content

Commit

Permalink
Optimized user auth for auth service (#472)
Browse files Browse the repository at this point in the history
  • Loading branch information
wu-clan authored Dec 7, 2024
1 parent 8c04bcb commit 3b2b45d
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 39 deletions.
4 changes: 2 additions & 2 deletions backend/app/admin/crud/crud_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ async def create(self, db: AsyncSession, obj: RegisterUserParam, *, social: bool
"""
if not social:
salt = bcrypt.gensalt()
obj.password = get_hash_password(f'{obj.password}', salt)
obj.password = get_hash_password(obj.password, salt)
dict_obj = obj.model_dump()
dict_obj.update({'is_staff': True, 'salt': salt})
else:
Expand All @@ -90,7 +90,7 @@ async def add(self, db: AsyncSession, obj: AddUserParam) -> None:
:return:
"""
salt = bcrypt.gensalt()
obj.password = get_hash_password(f'{obj.password}', salt)
obj.password = get_hash_password(obj.password, salt)
dict_obj = obj.model_dump(exclude={'roles'})
dict_obj.update({'salt': salt})
new_user = self.model(**dict_obj)
Expand Down
69 changes: 34 additions & 35 deletions backend/app/admin/service/auth_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# -*- coding: utf-8 -*-
from fastapi import Request, Response
from fastapi.security import HTTPBasicCredentials
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.background import BackgroundTask, BackgroundTasks

from backend.app.admin.conf import admin_settings
Expand Down Expand Up @@ -29,42 +30,40 @@

class AuthService:
@staticmethod
async def swagger_login(*, obj: HTTPBasicCredentials) -> tuple[str, User]:
async def user_verify(db: AsyncSession, username: str, password: str) -> User:
user = await user_dao.get_by_username(db, username)
if not user:
raise errors.NotFoundError(msg='用户名或密码有误')
elif not password_verify(password, user.password):
raise errors.AuthorizationError(msg='用户名或密码有误')
elif not user.status:
raise errors.AuthorizationError(msg='用户已被锁定, 请联系统管理员')
return user

async def swagger_login(self, *, obj: HTTPBasicCredentials) -> tuple[str, User]:
async with async_db_session.begin() as db:
current_user = await user_dao.get_by_username(db, obj.username)
if not current_user:
raise errors.NotFoundError(msg='用户名或密码有误')
elif not password_verify(f'{obj.password}', current_user.password):
raise errors.AuthorizationError(msg='用户名或密码有误')
elif not current_user.status:
raise errors.AuthorizationError(msg='用户已被锁定, 请联系统管理员')
access_token = await create_access_token(str(current_user.id), current_user.is_multi_login)
user = await self.user_verify(db, obj.username, obj.password)
user_id = user.id
a_token = await create_access_token(str(user_id), user.is_multi_login)
await user_dao.update_login_time(db, obj.username)
return access_token.access_token, current_user
return a_token.access_token, user

@staticmethod
async def login(
*, request: Request, response: Response, obj: AuthLoginParam, background_tasks: BackgroundTasks
self, *, request: Request, response: Response, obj: AuthLoginParam, background_tasks: BackgroundTasks
) -> GetLoginToken:
async with async_db_session.begin() as db:
try:
current_user = await user_dao.get_by_username(db, obj.username)
if not current_user:
raise errors.NotFoundError(msg='用户名或密码有误')
user_uuid = current_user.uuid
username = current_user.username
if not password_verify(obj.password, current_user.password):
raise errors.AuthorizationError(msg='用户名或密码有误')
elif not current_user.status:
raise errors.AuthorizationError(msg='用户已被锁定, 请联系统管理员')
user = await self.user_verify(db, obj.username, obj.password)
user_id = user.id
user_uuid = user.uuid
username = user.username
captcha_code = await redis_client.get(f'{admin_settings.CAPTCHA_LOGIN_REDIS_PREFIX}:{request.state.ip}')
if not captcha_code:
raise errors.AuthorizationError(msg='验证码失效,请重新获取')
if captcha_code.lower() != obj.captcha.lower():
raise errors.CustomError(error=CustomErrorCode.CAPTCHA_ERROR)
current_user_id = current_user.id
access_token = await create_access_token(str(current_user_id), current_user.is_multi_login)
refresh_token = await create_refresh_token(str(current_user_id), current_user.is_multi_login)
a_token = await create_access_token(str(user_id), user.is_multi_login)
r_token = await create_refresh_token(str(user_id), user.is_multi_login)
except errors.NotFoundError as e:
raise errors.NotFoundError(msg=e.msg)
except (errors.AuthorizationError, errors.CustomError) as e:
Expand Down Expand Up @@ -100,16 +99,16 @@ async def login(
await user_dao.update_login_time(db, obj.username)
response.set_cookie(
key=settings.COOKIE_REFRESH_TOKEN_KEY,
value=refresh_token.refresh_token,
value=r_token.refresh_token,
max_age=settings.COOKIE_REFRESH_TOKEN_EXPIRE_SECONDS,
expires=timezone.f_utc(refresh_token.refresh_token_expire_time),
expires=timezone.f_utc(r_token.refresh_token_expire_time),
httponly=True,
)
await db.refresh(current_user)
await db.refresh(user)
data = GetLoginToken(
access_token=access_token.access_token,
access_token_expire_time=access_token.access_token_expire_time,
user=current_user, # type: ignore
access_token=a_token.access_token,
access_token_expire_time=a_token.access_token_expire_time,
user=user, # type: ignore
)
return data

Expand All @@ -125,17 +124,17 @@ async def new_token(*, request: Request, response: Response) -> GetNewToken:
if request.user.id != user_id:
raise errors.TokenError(msg='Refresh Token 无效')
async with async_db_session() as db:
current_user = await user_dao.get(db, user_id)
if not current_user:
user = await user_dao.get(db, user_id)
if not user:
raise errors.NotFoundError(msg='用户名或密码有误')
elif not current_user.status:
elif not user.status:
raise errors.AuthorizationError(msg='用户已被锁定, 请联系统管理员')
current_token = get_token(request)
new_token = await create_new_token(
sub=str(current_user.id),
sub=str(user.id),
token=current_token,
refresh_token=refresh_token,
multi_login=current_user.is_multi_login,
multi_login=user.is_multi_login,
)
response.set_cookie(
key=settings.COOKIE_REFRESH_TOKEN_KEY,
Expand Down
4 changes: 2 additions & 2 deletions backend/app/admin/service/user_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,13 @@ async def add(*, request: Request, obj: AddUserParam) -> None:
async def pwd_reset(*, request: Request, obj: ResetPasswordParam) -> int:
async with async_db_session.begin() as db:
user = await user_dao.get(db, request.user.id)
if not password_verify(f'{obj.old_password}', user.password):
if not password_verify(obj.old_password, user.password):
raise errors.ForbiddenError(msg='原密码错误')
np1 = obj.new_password
np2 = obj.confirm_password
if np1 != np2:
raise errors.ForbiddenError(msg='密码输入不一致')
new_pwd = get_hash_password(f'{obj.new_password}', user.salt)
new_pwd = get_hash_password(obj.new_password, user.salt)
count = await user_dao.reset_password(db, request.user.id, new_pwd)
key_prefix = [
f'{settings.TOKEN_REDIS_PREFIX}:{request.user.id}',
Expand Down

0 comments on commit 3b2b45d

Please sign in to comment.