Skip to content

Commit 48327b9

Browse files
refactor: Add some changes to registration and authentication, add a test.
- Now the token version is increased during registration. - Now the token version is not visible to the user and is left only on the server side. - Add a test that checks the invalidation of the token received during registration after authentication. - Simplify validate method in SignInSerializer.
2 parents 4ce2d22 + 0aaf118 commit 48327b9

File tree

3 files changed

+88
-71
lines changed

3 files changed

+88
-71
lines changed

promo_code/user/serializers.py

Lines changed: 32 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ def create(self, validated_data):
6666
other=validated_data['other'],
6767
password=validated_data['password'],
6868
)
69+
user.token_version += 1
70+
user.save()
6971
return user
7072
except django.core.exceptions.ValidationError as e:
7173
raise rest_framework.serializers.ValidationError(e.messages)
@@ -80,13 +82,26 @@ class SignInSerializer(
8082
write_only=True,
8183
)
8284

83-
def validate(self, data):
84-
email = data.get('email')
85-
password = data.get('password')
85+
def validate(self, attrs):
86+
user = self.authenticate_user(attrs)
87+
88+
self.update_token_version(user)
89+
90+
data = super().validate(attrs)
91+
92+
refresh = rest_framework_simplejwt.tokens.RefreshToken(data['refresh'])
93+
94+
self.invalidate_previous_tokens(user, refresh['jti'])
95+
96+
return data
97+
98+
def authenticate_user(self, attrs):
99+
email = attrs.get('email')
100+
password = attrs.get('password')
86101

87102
if not email or not password:
88-
raise rest_framework.serializers.ValidationError(
89-
{'status': 'error', 'message': 'Both fields are required.'},
103+
raise rest_framework.exceptions.ValidationError(
104+
{'detail': 'Both email and password are required'},
90105
code='required',
91106
)
92107

@@ -95,55 +110,26 @@ def validate(self, data):
95110
email=email,
96111
password=password,
97112
)
98-
if not user:
99-
raise rest_framework.exceptions.AuthenticationFailed(
100-
{'status': 'error', 'message': 'Invalid email or password.'},
101-
code='authorization',
102-
)
103113

104-
authenticate_kwargs = {
105-
self.username_field: data[self.username_field],
106-
'password': data['password'],
107-
}
108-
try:
109-
authenticate_kwargs['request'] = self.context['request']
110-
except KeyError:
111-
pass
112-
113-
self.user = django.contrib.auth.authenticate(**authenticate_kwargs)
114-
115-
if not getattr(self.user, 'is_active', None):
114+
if not user or not user.is_active:
116115
raise rest_framework.exceptions.AuthenticationFailed(
117-
self.error_messages['no_active_account'],
118-
'no_active_account',
116+
{'detail': 'Invalid credentials or inactive account'},
117+
code='authentication_failed',
119118
)
120119

121-
self.user.token_version += 1
122-
self.user.save()
120+
return user
123121

124-
refresh = self.get_token(self.user)
125-
data = {
126-
'refresh': str(refresh),
127-
'access': str(refresh.access_token),
128-
}
129-
130-
current_jti = refresh['jti']
131-
132-
tokens_qs = tb_models.OutstandingToken.objects.filter(
133-
user=self.user,
134-
)
135-
136-
outstanding_tokens = tokens_qs.exclude(jti=current_jti)
122+
def invalidate_previous_tokens(self, user, current_jti):
123+
outstanding_tokens = tb_models.OutstandingToken.objects.filter(
124+
user=user,
125+
).exclude(jti=current_jti)
137126

138127
for token in outstanding_tokens:
139-
(
140-
tb_models.BlacklistedToken.objects.get_or_create(
141-
token=token,
142-
)
143-
)
128+
tb_models.BlacklistedToken.objects.get_or_create(token=token)
144129

145-
data['token_version'] = self.user.token_version
146-
return data
130+
def update_token_version(self, user):
131+
user.token_version += 1
132+
user.save()
147133

148134
def get_token(self, user):
149135
token = super().get_token(user)

promo_code/user/tests.py

Lines changed: 44 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ def test_valid_registration(self):
324324
response.status_code,
325325
rest_framework.status.HTTP_200_OK,
326326
)
327-
self.assertIn('token', response.data)
327+
self.assertIn('access', response.data)
328328
self.assertTrue(
329329
user.models.User.objects.filter(
330330
@@ -391,7 +391,7 @@ def test_signin_success(self):
391391

392392
class JWTTests(rest_framework.test.APITestCase):
393393
def setUp(self):
394-
394+
self.signup_url = django.urls.reverse('api-user:sign-up')
395395
self.signin_url = django.urls.reverse('api-user:sign-in')
396396
self.protected_url = django.urls.reverse('api-core:protected')
397397
self.refresh_url = django.urls.reverse('api-user:token_refresh')
@@ -428,13 +428,54 @@ def test_access_protected_view_with_valid_token(self):
428428
self.assertEqual(response.status_code, 200)
429429
self.assertEqual(response.data['status'], 'request was permitted')
430430

431-
def test_refresh_token_invalidation_after_new_login(self):
431+
def test_registration_token_invalid_after_login(self):
432+
data = {
433+
'email': '[email protected]',
434+
'password': 'StrongPass123!cd',
435+
'name': 'John',
436+
'surname': 'Doe',
437+
'other': {'age': 22, 'country': 'us'},
438+
}
439+
response = self.client.post(
440+
self.signup_url,
441+
data,
442+
format='json',
443+
)
444+
reg_access_token = response.data['access']
445+
446+
self.client.credentials(
447+
HTTP_AUTHORIZATION=f'Bearer {reg_access_token}',
448+
)
449+
response = self.client.get(self.protected_url)
450+
self.assertEqual(response.status_code, 200)
451+
452+
login_data = {'email': data['email'], 'password': data['password']}
453+
response = self.client.post(
454+
self.signin_url,
455+
login_data,
456+
format='json',
457+
)
458+
login_access_token = response.data['access']
459+
460+
self.client.credentials(
461+
HTTP_AUTHORIZATION=f'Bearer {reg_access_token}',
462+
)
463+
response = self.client.get(self.protected_url)
464+
self.assertEqual(response.status_code, 401)
465+
466+
self.client.credentials(
467+
HTTP_AUTHORIZATION=f'Bearer {login_access_token}',
468+
)
469+
response = self.client.get(self.protected_url)
470+
self.assertEqual(response.status_code, 200)
432471

472+
def test_refresh_token_invalidation_after_new_login(self):
433473
first_login_response = self.client.post(
434474
self.signin_url,
435475
self.user_data,
436476
format='json',
437477
)
478+
438479
refresh_token_v1 = first_login_response.data['refresh']
439480

440481
second_login_response = self.client.post(
@@ -493,21 +534,3 @@ def test_blacklist_storage(self):
493534
(tb_models.OutstandingToken.objects.count()),
494535
2,
495536
)
496-
497-
def test_token_version_increment(self):
498-
response1 = self.client.post(
499-
self.signin_url,
500-
self.user_data,
501-
format='json',
502-
)
503-
self.assertEqual(response1.data['token_version'], 1)
504-
505-
response2 = self.client.post(
506-
self.signin_url,
507-
self.user_data,
508-
format='json',
509-
)
510-
self.assertEqual(response2.data['token_version'], 2)
511-
512-
user_ = user.models.User.objects.get(email=self.user_data['email'])
513-
self.assertEqual(user_.token_version, 2)

promo_code/user/views.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,13 @@ def create(self, request, *args, **kwargs):
3535
return self.handle_validation_error()
3636

3737
user = serializer.save()
38+
3839
refresh = rest_framework_simplejwt.tokens.RefreshToken.for_user(user)
40+
refresh['token_version'] = user.token_version
41+
access_token = refresh.access_token
42+
3943
return rest_framework.response.Response(
40-
{'token': str(refresh.access_token)},
44+
{'access': str(access_token), 'refresh': str(refresh)},
4145
status=rest_framework.status.HTTP_200_OK,
4246
)
4347

@@ -49,11 +53,10 @@ class SignInView(
4953
serializer_class = user.serializers.SignInSerializer
5054

5155
def post(self, request, *args, **kwargs):
52-
serializer = self.get_serializer(data=request.data)
5356

5457
try:
58+
serializer = self.get_serializer(data=request.data)
5559
serializer.is_valid(raise_exception=True)
56-
response = super().post(request, *args, **kwargs)
5760
except (
5861
rest_framework.serializers.ValidationError,
5962
rest_framework_simplejwt.exceptions.TokenError,
@@ -63,7 +66,12 @@ def post(self, request, *args, **kwargs):
6366

6467
raise rest_framework_simplejwt.exceptions.InvalidToken(str(e))
6568

69+
response_data = {
70+
'access': serializer.validated_data['access'],
71+
'refresh': serializer.validated_data['refresh'],
72+
}
73+
6674
return rest_framework.response.Response(
67-
response,
75+
response_data,
6876
status=rest_framework.status.HTTP_200_OK,
6977
)

0 commit comments

Comments
 (0)