Skip to content

Commit 2ad576d

Browse files
refactor(business auth): Consolidate JWT token utilities
- Simplify views using `CreateAPIView` and remove redundant mixins - Add `@transaction.atomic` for atomic database operations during company creation - Extract JWT token generation into reusable `generate_company_tokens()` utility - Extract JWT token version bump into reusable `bump_company_token_version()` utility - Validate UUID format rigorously in token refresh flow - Add documentation in views - Remove redundant imports
1 parent 09472fb commit 2ad576d

File tree

4 files changed

+97
-122
lines changed

4 files changed

+97
-122
lines changed

promo_code/business/serializers.py

Lines changed: 37 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import uuid
22

33
import django.contrib.auth.password_validation
4-
import django.core.exceptions
54
import django.core.validators
5+
import django.db.transaction
66
import pycountry
77
import rest_framework.exceptions
88
import rest_framework.serializers
@@ -11,7 +11,10 @@
1111
import rest_framework_simplejwt.tokens
1212

1313
import business.constants
14+
import business.models
1415
import business.models as business_models
16+
import business.utils.auth
17+
import business.utils.tokens
1518
import business.validators
1619

1720

@@ -21,9 +24,9 @@ class CompanySignUpSerializer(rest_framework.serializers.ModelSerializer):
2124
write_only=True,
2225
required=True,
2326
validators=[django.contrib.auth.password_validation.validate_password],
27+
style={'input_type': 'password'},
2428
min_length=business.constants.COMPANY_PASSWORD_MIN_LENGTH,
2529
max_length=business.constants.COMPANY_PASSWORD_MAX_LENGTH,
26-
style={'input_type': 'password'},
2730
)
2831
name = rest_framework.serializers.CharField(
2932
required=True,
@@ -44,30 +47,18 @@ class CompanySignUpSerializer(rest_framework.serializers.ModelSerializer):
4447

4548
class Meta:
4649
model = business_models.Company
47-
fields = (
48-
'id',
49-
'name',
50-
'email',
51-
'password',
52-
)
50+
fields = ('id', 'name', 'email', 'password')
5351

52+
@django.db.transaction.atomic
5453
def create(self, validated_data):
55-
try:
56-
company = business_models.Company.objects.create_company(
57-
email=validated_data['email'],
58-
name=validated_data['name'],
59-
password=validated_data['password'],
60-
)
61-
company.token_version += 1
62-
company.save()
63-
return company
64-
except django.core.exceptions.ValidationError as e:
65-
raise rest_framework.serializers.ValidationError(e.messages)
54+
company = business_models.Company.objects.create_company(
55+
**validated_data,
56+
)
6657

58+
return business.utils.auth.bump_company_token_version(company)
6759

68-
class CompanySignInSerializer(
69-
rest_framework.serializers.Serializer,
70-
):
60+
61+
class CompanySignInSerializer(rest_framework.serializers.Serializer):
7162
email = rest_framework.serializers.EmailField(required=True)
7263
password = rest_framework.serializers.CharField(
7364
required=True,
@@ -80,16 +71,15 @@ def validate(self, attrs):
8071
password = attrs.get('password')
8172

8273
if not email or not password:
83-
raise rest_framework.exceptions.ValidationError(
84-
{'detail': 'Both email and password are required'},
85-
code='required',
74+
raise rest_framework.serializers.ValidationError(
75+
'Both email and password are required.',
8676
)
8777

8878
try:
8979
company = business_models.Company.objects.get(email=email)
9080
except business_models.Company.DoesNotExist:
9181
raise rest_framework.serializers.ValidationError(
92-
'Invalid credentials',
82+
'Invalid credentials.',
9383
)
9484

9585
if not company.is_active or not company.check_password(password):
@@ -98,53 +88,55 @@ def validate(self, attrs):
9888
code='authentication_failed',
9989
)
10090

91+
attrs['company'] = company
10192
return attrs
10293

10394

10495
class CompanyTokenRefreshSerializer(
10596
rest_framework_simplejwt.serializers.TokenRefreshSerializer,
10697
):
10798
def validate(self, attrs):
99+
attrs = super().validate(attrs)
108100
refresh = rest_framework_simplejwt.tokens.RefreshToken(
109101
attrs['refresh'],
110102
)
111-
user_type = refresh.payload.get('user_type', 'user')
103+
company = self.get_active_company_from_token(refresh)
104+
105+
company = business.utils.auth.bump_company_token_version(company)
112106

113-
if user_type != 'company':
107+
return business.utils.tokens.generate_company_tokens(company)
108+
109+
def get_active_company_from_token(self, token):
110+
if token.payload.get('user_type') != 'company':
114111
raise rest_framework_simplejwt.exceptions.InvalidToken(
115112
'This refresh endpoint is for company tokens only',
116113
)
117114

118-
company_id = refresh.payload.get('company_id')
119-
if not company_id:
115+
company_id = token.payload.get('company_id')
116+
try:
117+
company_uuid = uuid.UUID(company_id)
118+
except (TypeError, ValueError):
120119
raise rest_framework_simplejwt.exceptions.InvalidToken(
121-
'Company ID missing in token',
120+
'Invalid or missing company_id in token',
122121
)
123122

124123
try:
125-
company = business_models.Company.objects.get(
126-
id=uuid.UUID(company_id),
124+
company = business.models.Company.objects.get(
125+
id=company_uuid,
126+
is_active=True,
127127
)
128-
except business_models.Company.DoesNotExist:
128+
except business.models.Company.DoesNotExist:
129129
raise rest_framework_simplejwt.exceptions.InvalidToken(
130-
'Company not found',
130+
'Company not found or inactive',
131131
)
132132

133-
token_version = refresh.payload.get('token_version', 0)
133+
token_version = token.payload.get('token_version', 0)
134134
if company.token_version != token_version:
135135
raise rest_framework_simplejwt.exceptions.InvalidToken(
136136
'Token is blacklisted',
137137
)
138138

139-
new_refresh = rest_framework_simplejwt.tokens.RefreshToken()
140-
new_refresh['user_type'] = 'company'
141-
new_refresh['company_id'] = str(company.id)
142-
new_refresh['token_version'] = company.token_version
143-
144-
return {
145-
'access': str(new_refresh.access_token),
146-
'refresh': str(new_refresh),
147-
}
139+
return company
148140

149141

150142
class TargetSerializer(rest_framework.serializers.Serializer):

promo_code/business/utils/auth.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import business.models
2+
3+
4+
def bump_company_token_version(company):
5+
"""
6+
Increment token_version, save it, and return the fresh instance.
7+
"""
8+
company = business.models.Company.objects.select_for_update().get(
9+
id=company.id,
10+
)
11+
company.token_version += 1
12+
company.save(update_fields=['token_version'])
13+
return company

promo_code/business/utils/tokens.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import rest_framework_simplejwt.tokens
2+
3+
4+
def generate_company_tokens(company):
5+
"""
6+
Generate JWT tokens for a company.
7+
"""
8+
refresh = rest_framework_simplejwt.tokens.RefreshToken()
9+
refresh['user_type'] = 'company'
10+
refresh['company_id'] = str(company.id)
11+
refresh['token_version'] = company.token_version
12+
13+
access = refresh.access_token
14+
access['user_type'] = 'company'
15+
access['company_id'] = str(company.id)
16+
17+
return {
18+
'access': str(access),
19+
'refresh': str(refresh),
20+
}

promo_code/business/views.py

Lines changed: 27 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -4,111 +4,61 @@
44
import rest_framework.generics
55
import rest_framework.permissions
66
import rest_framework.response
7-
import rest_framework.serializers
87
import rest_framework.status
9-
import rest_framework_simplejwt.exceptions
10-
import rest_framework_simplejwt.tokens
118
import rest_framework_simplejwt.views
129

1310
import business.models
1411
import business.pagination
1512
import business.permissions
1613
import business.serializers
17-
import core.views
14+
import business.utils.auth
15+
import business.utils.tokens
1816

1917

20-
class CompanySignUpView(
21-
core.views.BaseCustomResponseMixin,
22-
rest_framework.generics.CreateAPIView,
23-
):
24-
serializer_class = business.serializers.CompanySignUpSerializer
25-
26-
def post(self, request):
27-
try:
28-
serializer = business.serializers.CompanySignUpSerializer(
29-
data=request.data,
30-
)
31-
serializer.is_valid(raise_exception=True)
32-
except (
33-
rest_framework.serializers.ValidationError,
34-
rest_framework_simplejwt.exceptions.TokenError,
35-
) as e:
36-
if isinstance(e, rest_framework.serializers.ValidationError):
37-
return self.handle_validation_error()
18+
class CompanySignUpView(rest_framework.generics.CreateAPIView):
19+
"""
20+
Company registration endpoint that returns JWT tokens.
21+
"""
3822

39-
raise rest_framework_simplejwt.exceptions.InvalidToken(str(e))
23+
serializer_class = business.serializers.CompanySignUpSerializer
4024

25+
def create(self, request, *args, **kwargs):
26+
serializer = self.get_serializer(data=request.data)
27+
serializer.is_valid(raise_exception=True)
4128
company = serializer.save()
4229

43-
refresh = rest_framework_simplejwt.tokens.RefreshToken()
44-
refresh['user_type'] = 'company'
45-
refresh['company_id'] = str(company.id)
46-
refresh['token_version'] = company.token_version
47-
48-
access_token = refresh.access_token
49-
access_token['user_type'] = 'company'
50-
access_token['company_id'] = str(company.id)
51-
refresh['token_version'] = company.token_version
52-
53-
response_data = {
54-
'access': str(access_token),
55-
'refresh': str(refresh),
56-
}
57-
5830
return rest_framework.response.Response(
59-
response_data,
31+
business.utils.tokens.generate_company_tokens(company),
6032
status=rest_framework.status.HTTP_200_OK,
6133
)
6234

6335

64-
class CompanySignInView(
65-
core.views.BaseCustomResponseMixin,
66-
rest_framework_simplejwt.views.TokenObtainPairView,
67-
):
68-
serializer_class = business.serializers.CompanySignInSerializer
69-
70-
def post(self, request):
71-
try:
72-
serializer = business.serializers.CompanySignInSerializer(
73-
data=request.data,
74-
)
75-
serializer.is_valid(raise_exception=True)
76-
except (
77-
rest_framework.serializers.ValidationError,
78-
rest_framework_simplejwt.exceptions.TokenError,
79-
) as e:
80-
if isinstance(e, rest_framework.serializers.ValidationError):
81-
return self.handle_validation_error()
82-
83-
raise rest_framework_simplejwt.exceptions.InvalidToken(str(e))
84-
85-
company = business.models.Company.objects.get(
86-
email=serializer.validated_data['email'],
87-
)
88-
company.token_version += 1
89-
company.save()
36+
class CompanySignInView(rest_framework.generics.GenericAPIView):
37+
"""
38+
Company authentication endpoint that issues new JWT tokens
39+
and bumps token_version.
40+
"""
9041

91-
refresh = rest_framework_simplejwt.tokens.RefreshToken()
92-
refresh['user_type'] = 'company'
93-
refresh['company_id'] = str(company.id)
94-
refresh['token_version'] = company.token_version
42+
serializer_class = business.serializers.CompanySignInSerializer
9543

96-
access_token = refresh.access_token
97-
access_token['user_type'] = 'company'
98-
access_token['company_id'] = str(company.id)
44+
def post(self, request, *args, **kwargs):
45+
serializer = self.get_serializer(data=request.data)
46+
serializer.is_valid(raise_exception=True)
9947

100-
response_data = {
101-
'access': str(access_token),
102-
'refresh': str(refresh),
103-
}
48+
company = serializer.validated_data['company']
49+
company = business.utils.auth.bump_company_token_version(company)
10450

10551
return rest_framework.response.Response(
106-
response_data,
52+
business.utils.tokens.generate_company_tokens(company),
10753
status=rest_framework.status.HTTP_200_OK,
10854
)
10955

11056

11157
class CompanyTokenRefreshView(rest_framework_simplejwt.views.TokenRefreshView):
58+
"""
59+
Refresh endpoint for company tokens only.
60+
"""
61+
11262
serializer_class = business.serializers.CompanyTokenRefreshSerializer
11363

11464

0 commit comments

Comments
 (0)