Skip to content

Commit 6f7f229

Browse files
Merge pull request #34 from RandomProgramm3r/develop
refactor(business auth): Consolidate JWT token utilities - Simplify views using `CreateAPIView` and remove redundant mixins - Add `@transaction.atomic` for atomic database operations during company creation - Extract JWT token generation into reusable `generate_company_tokens()` utility - Extract JWT token version bump into reusable `bump_company_token_version()` utility - Validate UUID format rigorously in token refresh flow - Add documentation in views - Remove redundant imports - Add and update test
2 parents 346ebf8 + 669cc9f commit 6f7f229

File tree

6 files changed

+138
-129
lines changed

6 files changed

+138
-129
lines changed

promo_code/business/serializers.py

Lines changed: 37 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import uuid
22

33
import django.contrib.auth.password_validation
4-
import django.core.exceptions
54
import django.core.validators
5+
import django.db.transaction
66
import pycountry
77
import rest_framework.exceptions
88
import rest_framework.serializers
@@ -11,7 +11,10 @@
1111
import rest_framework_simplejwt.tokens
1212

1313
import business.constants
14+
import business.models
1415
import business.models as business_models
16+
import business.utils.auth
17+
import business.utils.tokens
1518
import business.validators
1619

1720

@@ -21,9 +24,9 @@ class CompanySignUpSerializer(rest_framework.serializers.ModelSerializer):
2124
write_only=True,
2225
required=True,
2326
validators=[django.contrib.auth.password_validation.validate_password],
27+
style={'input_type': 'password'},
2428
min_length=business.constants.COMPANY_PASSWORD_MIN_LENGTH,
2529
max_length=business.constants.COMPANY_PASSWORD_MAX_LENGTH,
26-
style={'input_type': 'password'},
2730
)
2831
name = rest_framework.serializers.CharField(
2932
required=True,
@@ -44,30 +47,18 @@ class CompanySignUpSerializer(rest_framework.serializers.ModelSerializer):
4447

4548
class Meta:
4649
model = business_models.Company
47-
fields = (
48-
'id',
49-
'name',
50-
'email',
51-
'password',
52-
)
50+
fields = ('id', 'name', 'email', 'password')
5351

52+
@django.db.transaction.atomic
5453
def create(self, validated_data):
55-
try:
56-
company = business_models.Company.objects.create_company(
57-
email=validated_data['email'],
58-
name=validated_data['name'],
59-
password=validated_data['password'],
60-
)
61-
company.token_version += 1
62-
company.save()
63-
return company
64-
except django.core.exceptions.ValidationError as e:
65-
raise rest_framework.serializers.ValidationError(e.messages)
54+
company = business_models.Company.objects.create_company(
55+
**validated_data,
56+
)
6657

58+
return business.utils.auth.bump_company_token_version(company)
6759

68-
class CompanySignInSerializer(
69-
rest_framework.serializers.Serializer,
70-
):
60+
61+
class CompanySignInSerializer(rest_framework.serializers.Serializer):
7162
email = rest_framework.serializers.EmailField(required=True)
7263
password = rest_framework.serializers.CharField(
7364
required=True,
@@ -80,16 +71,15 @@ def validate(self, attrs):
8071
password = attrs.get('password')
8172

8273
if not email or not password:
83-
raise rest_framework.exceptions.ValidationError(
84-
{'detail': 'Both email and password are required'},
85-
code='required',
74+
raise rest_framework.serializers.ValidationError(
75+
'Both email and password are required.',
8676
)
8777

8878
try:
8979
company = business_models.Company.objects.get(email=email)
9080
except business_models.Company.DoesNotExist:
9181
raise rest_framework.serializers.ValidationError(
92-
'Invalid credentials',
82+
'Invalid credentials.',
9383
)
9484

9585
if not company.is_active or not company.check_password(password):
@@ -98,53 +88,55 @@ def validate(self, attrs):
9888
code='authentication_failed',
9989
)
10090

91+
attrs['company'] = company
10192
return attrs
10293

10394

10495
class CompanyTokenRefreshSerializer(
10596
rest_framework_simplejwt.serializers.TokenRefreshSerializer,
10697
):
10798
def validate(self, attrs):
99+
attrs = super().validate(attrs)
108100
refresh = rest_framework_simplejwt.tokens.RefreshToken(
109101
attrs['refresh'],
110102
)
111-
user_type = refresh.payload.get('user_type', 'user')
103+
company = self.get_active_company_from_token(refresh)
104+
105+
company = business.utils.auth.bump_company_token_version(company)
112106

113-
if user_type != 'company':
107+
return business.utils.tokens.generate_company_tokens(company)
108+
109+
def get_active_company_from_token(self, token):
110+
if token.payload.get('user_type') != 'company':
114111
raise rest_framework_simplejwt.exceptions.InvalidToken(
115112
'This refresh endpoint is for company tokens only',
116113
)
117114

118-
company_id = refresh.payload.get('company_id')
119-
if not company_id:
115+
company_id = token.payload.get('company_id')
116+
try:
117+
company_uuid = uuid.UUID(company_id)
118+
except (TypeError, ValueError):
120119
raise rest_framework_simplejwt.exceptions.InvalidToken(
121-
'Company ID missing in token',
120+
'Invalid or missing company_id in token',
122121
)
123122

124123
try:
125-
company = business_models.Company.objects.get(
126-
id=uuid.UUID(company_id),
124+
company = business.models.Company.objects.get(
125+
id=company_uuid,
126+
is_active=True,
127127
)
128-
except business_models.Company.DoesNotExist:
128+
except business.models.Company.DoesNotExist:
129129
raise rest_framework_simplejwt.exceptions.InvalidToken(
130-
'Company not found',
130+
'Company not found or inactive',
131131
)
132132

133-
token_version = refresh.payload.get('token_version', 0)
133+
token_version = token.payload.get('token_version', 0)
134134
if company.token_version != token_version:
135135
raise rest_framework_simplejwt.exceptions.InvalidToken(
136136
'Token is blacklisted',
137137
)
138138

139-
new_refresh = rest_framework_simplejwt.tokens.RefreshToken()
140-
new_refresh['user_type'] = 'company'
141-
new_refresh['company_id'] = str(company.id)
142-
new_refresh['token_version'] = company.token_version
143-
144-
return {
145-
'access': str(new_refresh.access_token),
146-
'refresh': str(new_refresh),
147-
}
139+
return company
148140

149141

150142
class TargetSerializer(rest_framework.serializers.Serializer):

promo_code/business/tests/auth/test_tokens.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def test_missing_company_id(self):
187187
rest_framework.status.HTTP_401_UNAUTHORIZED,
188188
)
189189
self.assertIn(
190-
'Company ID missing in token',
190+
'Invalid or missing company_id in token',
191191
str(response.content.decode()),
192192
)
193193

promo_code/business/tests/promocodes/test_permissions.py

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,29 +11,50 @@
1111
class TestIsCompanyUserPermission(
1212
business.tests.promocodes.base.BasePromoTestCase,
1313
):
14+
@classmethod
15+
def setUpClass(cls):
16+
super().setUpClass()
17+
18+
cls.unique_payload = {
19+
'description': 'Complimentary Pudge Skin on Registration!',
20+
'target': {},
21+
'max_count': 1,
22+
'mode': 'UNIQUE',
23+
'active_from': '2030-08-08',
24+
'promo_unique': ['dota-arena', 'coda-core', 'warcraft3'],
25+
}
26+
1427
def setUp(self):
1528
self.factory = rest_framework.test.APIRequestFactory()
1629
self.permission = business.permissions.IsCompanyUser()
1730
get_user_model = django.contrib.auth.get_user_model
1831
self.regular_user = get_user_model().objects.create_user(
1932
name='regular',
20-
password='testpass123',
33+
password='SecurePass123!',
2134
surname='adadioa',
2235
2336
)
24-
self.company_user = business.models.Company.objects.create_company(
25-
password='testpass123',
26-
name='Test Company',
27-
37+
38+
def create_promo(self, token, payload):
39+
self.client.credentials(HTTP_AUTHORIZATION='Bearer ' + token)
40+
response = self.client.post(
41+
self.promo_create_url,
42+
payload,
43+
format='json',
44+
)
45+
self.assertEqual(
46+
response.status_code,
47+
rest_framework.status.HTTP_201_CREATED,
2848
)
49+
return response.data['id']
2950

3051
def tearDown(self):
3152
business.models.Company.objects.all().delete()
3253
user.models.User.objects.all().delete()
3354

3455
def test_has_permission_for_company_user(self):
3556
request = self.factory.get(self.promo_create_url)
36-
request.user = self.company_user
57+
request.user = self.company1
3758
self.assertTrue(self.permission.has_permission(request, None))
3859

3960
def test_has_permission_for_regular_user(self):
@@ -45,3 +66,16 @@ def test_has_permission_for_anonymous_user(self):
4566
request = self.factory.get(self.promo_create_url)
4667
request.user = None
4768
self.assertFalse(self.permission.has_permission(request, None))
69+
70+
def test_has_permission_to_foreign_promo(self):
71+
promo_id = self.create_promo(self.company2_token, self.unique_payload)
72+
self.client.credentials(
73+
HTTP_AUTHORIZATION='Bearer ' + self.company1_token,
74+
)
75+
url = self.promo_detail_url(promo_id)
76+
patch_payload = {'description': '100% Cashback'}
77+
response = self.client.patch(url, patch_payload, format='json')
78+
self.assertEqual(
79+
response.status_code,
80+
rest_framework.status.HTTP_403_FORBIDDEN,
81+
)

promo_code/business/utils/auth.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import business.models
2+
3+
4+
def bump_company_token_version(company):
5+
"""
6+
Increment token_version, save it, and return the fresh instance.
7+
"""
8+
company = business.models.Company.objects.select_for_update().get(
9+
id=company.id,
10+
)
11+
company.token_version += 1
12+
company.save(update_fields=['token_version'])
13+
return company

promo_code/business/utils/tokens.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import rest_framework_simplejwt.tokens
2+
3+
4+
def generate_company_tokens(company):
5+
"""
6+
Generate JWT tokens for a company.
7+
"""
8+
refresh = rest_framework_simplejwt.tokens.RefreshToken()
9+
refresh['user_type'] = 'company'
10+
refresh['company_id'] = str(company.id)
11+
refresh['token_version'] = company.token_version
12+
13+
access = refresh.access_token
14+
access['user_type'] = 'company'
15+
access['company_id'] = str(company.id)
16+
17+
return {
18+
'access': str(access),
19+
'refresh': str(refresh),
20+
}

0 commit comments

Comments
 (0)