Skip to content

Commit 2f624d0

Browse files
Merge pull request #3 from RandomProgramm3r/develop
refactor: Rework JWT authentication and add tests.
2 parents b00ffca + e9267f6 commit 2f624d0

File tree

13 files changed

+281
-59
lines changed

13 files changed

+281
-59
lines changed

promo_code/core/tests.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,5 @@
66

77
class StaticURLTests(django.test.TestCase):
88
def test_ping_endpoint(self):
9-
response = self.client.get(django.urls.reverse('core:ping'))
9+
response = self.client.get(django.urls.reverse('api-core:ping'))
1010
self.assertEqual(response.status_code, http.HTTPStatus.OK)

promo_code/core/urls.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import core.views
22
import django.urls
33

4-
app_name = 'core'
4+
app_name = 'api-core'
55

66

77
urlpatterns = [
@@ -10,4 +10,9 @@
1010
core.views.PingView.as_view(),
1111
name='ping',
1212
),
13+
django.urls.path(
14+
'protected-endpoint/',
15+
core.views.MyProtectedView.as_view(),
16+
name='protected',
17+
),
1318
]

promo_code/core/views.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,18 @@
11
import django.http
22
import django.views
3+
import rest_framework.permissions
4+
import rest_framework.response
5+
import rest_framework.views
36

47

58
class PingView(django.views.View):
69
def get(self, request, *args, **kwargs):
710
return django.http.HttpResponse('PROOOOOOOOOOOOOOOOOD', status=200)
11+
12+
13+
class MyProtectedView(rest_framework.views.APIView):
14+
permission_classes = [rest_framework.permissions.IsAuthenticated]
15+
16+
def get(self, request, format=None):
17+
content = {'status': 'request was permitted'}
18+
return rest_framework.response.Response(content)

promo_code/promo_code/settings.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -48,47 +48,55 @@ def load_bool(name, default):
4848
AUTH_USER_MODEL = 'user.User'
4949

5050
REST_FRAMEWORK = {
51-
'DEFAULT_RENDERER_CLASSES': ('rest_framework.renderers.JSONRenderer',),
5251
'DEFAULT_AUTHENTICATION_CLASSES': [
53-
'user.authentication.CustomJWTAuthentication',
52+
'rest_framework_simplejwt.authentication.JWTAuthentication',
5453
],
5554
}
5655

5756
SIMPLE_JWT = {
58-
'ACCESS_TOKEN_LIFETIME': datetime.timedelta(hours=1),
57+
'ACCESS_TOKEN_LIFETIME': datetime.timedelta(minutes=60),
5958
'REFRESH_TOKEN_LIFETIME': datetime.timedelta(days=1),
6059
'ROTATE_REFRESH_TOKENS': True,
6160
'BLACKLIST_AFTER_ROTATION': True,
62-
'UPDATE_LAST_LOGIN': False, # !
63-
#
61+
'UPDATE_LAST_LOGIN': False,
6462
'ALGORITHM': 'HS256',
65-
'SIGNING_KEY': SECRET_KEY,
66-
'VERIFYING_KEY': None,
63+
'VERIFYING_KEY': '',
6764
'AUDIENCE': None,
6865
'ISSUER': None,
6966
'JSON_ENCODER': None,
7067
'JWK_URL': None,
7168
'LEEWAY': 0,
72-
#
7369
'AUTH_HEADER_TYPES': ('Bearer',),
7470
'AUTH_HEADER_NAME': 'HTTP_AUTHORIZATION',
7571
'USER_ID_FIELD': 'id',
7672
'USER_ID_CLAIM': 'user_id',
7773
'USER_AUTHENTICATION_RULE': (
7874
'rest_framework_simplejwt.authentication'
79-
'.default_user_authentication_rule',
75+
'.default_user_authentication_rule'
8076
),
81-
#
77+
'AUTH_TOKEN_CLASSES': ('rest_framework_simplejwt.tokens.AccessToken',),
8278
'TOKEN_TYPE_CLAIM': 'token_type',
8379
'TOKEN_USER_CLASS': 'rest_framework_simplejwt.models.TokenUser',
84-
#
8580
'JTI_CLAIM': 'jti',
86-
#
8781
'SLIDING_TOKEN_REFRESH_EXP_CLAIM': 'refresh_exp',
8882
'SLIDING_TOKEN_LIFETIME': datetime.timedelta(minutes=5),
8983
'SLIDING_TOKEN_REFRESH_LIFETIME': datetime.timedelta(days=1),
90-
#
91-
'ACCESS_TOKEN_CLASS': 'user.tokens.CustomAccessToken',
84+
'TOKEN_OBTAIN_SERIALIZER': 'user.serializers.SignInSerializer',
85+
'TOKEN_REFRESH_SERIALIZER': (
86+
'rest_framework_simplejwt.serializers.TokenRefreshSerializer'
87+
),
88+
'TOKEN_VERIFY_SERIALIZER': (
89+
'rest_framework_simplejwt.serializers.TokenVerifySerializer'
90+
),
91+
'TOKEN_BLACKLIST_SERIALIZER': (
92+
'rest_framework_simplejwt.serializers.TokenBlacklistSerializer'
93+
),
94+
'SLIDING_TOKEN_OBTAIN_SERIALIZER': (
95+
'rest_framework_simplejwt.serializers.TokenObtainSlidingSerializer'
96+
),
97+
'SLIDING_TOKEN_REFRESH_SERIALIZER': (
98+
'rest_framework_simplejwt.serializers.TokenRefreshSlidingSerializer'
99+
),
92100
}
93101

94102
MIDDLEWARE = [
@@ -99,6 +107,7 @@ def load_bool(name, default):
99107
'django.contrib.auth.middleware.AuthenticationMiddleware',
100108
'django.contrib.messages.middleware.MessageMiddleware',
101109
'django.middleware.clickjacking.XFrameOptionsMiddleware',
110+
'user.middleware.TokenVersionMiddleware',
102111
]
103112

104113
ROOT_URLCONF = 'promo_code.urls'

promo_code/user/authentication.py

Lines changed: 0 additions & 20 deletions
This file was deleted.

promo_code/user/middleware.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import django.http
2+
import rest_framework_simplejwt.authentication
3+
4+
5+
class TokenVersionMiddleware:
6+
def __init__(self, get_response):
7+
self.get_response = get_response
8+
9+
def __call__(self, request):
10+
auth = rest_framework_simplejwt.authentication.JWTAuthentication()
11+
auth_result = auth.authenticate(request)
12+
13+
if auth_result is None:
14+
return self.get_response(request)
15+
16+
user, token = auth_result
17+
if user:
18+
token_version = token.payload.get('token_version', 0)
19+
if token_version != user.token_version:
20+
return django.http.JsonResponse(
21+
{'error': 'Token invalid'},
22+
status=401,
23+
)
24+
25+
return self.get_response(request)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Generated by Django 5.2b1 on 2025-03-14 19:46
2+
3+
from django.db import migrations, models
4+
5+
6+
class Migration(migrations.Migration):
7+
8+
dependencies = [
9+
('user', '0001_initial'),
10+
]
11+
12+
operations = [
13+
migrations.AddField(
14+
model_name='user',
15+
name='token_version',
16+
field=models.IntegerField(default=0),
17+
),
18+
]

promo_code/user/models.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ class User(
4646
)
4747
other = django.db.models.JSONField(default=dict)
4848

49+
token_version = django.db.models.IntegerField(default=0)
50+
4951
is_active = django.db.models.BooleanField(default=True)
5052
is_staff = django.db.models.BooleanField(default=False)
5153
last_login = django.db.models.DateTimeField(null=True, blank=True)

promo_code/user/serializers.py

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import rest_framework.exceptions
55
import rest_framework.serializers
66
import rest_framework.status
7+
import rest_framework_simplejwt.serializers
8+
import rest_framework_simplejwt.token_blacklist.models as tb_models
79
import rest_framework_simplejwt.tokens
810

911
import user.models as user_models
@@ -69,7 +71,9 @@ def create(self, validated_data):
6971
raise rest_framework.serializers.ValidationError(e.messages)
7072

7173

72-
class SignInSerializer(rest_framework.serializers.Serializer):
74+
class SignInSerializer(
75+
rest_framework_simplejwt.serializers.TokenObtainPairSerializer,
76+
):
7377
email = rest_framework.serializers.EmailField(required=True)
7478
password = rest_framework.serializers.CharField(
7579
required=True,
@@ -97,10 +101,51 @@ def validate(self, data):
97101
code='authorization',
98102
)
99103

100-
data['user'] = user
104+
authenticate_kwargs = {
105+
self.username_field: data[self.username_field],
106+
'password': data['password'],
107+
}
108+
try:
109+
authenticate_kwargs['request'] = self.context['request']
110+
except KeyError:
111+
pass
112+
113+
self.user = django.contrib.auth.authenticate(**authenticate_kwargs)
114+
115+
if not getattr(self.user, 'is_active', None):
116+
raise rest_framework.exceptions.AuthenticationFailed(
117+
self.error_messages['no_active_account'],
118+
'no_active_account',
119+
)
120+
121+
self.user.token_version += 1
122+
self.user.save()
123+
124+
refresh = self.get_token(self.user)
125+
data = {
126+
'refresh': str(refresh),
127+
'access': str(refresh.access_token),
128+
}
129+
130+
current_jti = refresh['jti']
131+
132+
tokens_qs = tb_models.OutstandingToken.objects.filter(
133+
user=self.user,
134+
)
135+
136+
outstanding_tokens = tokens_qs.exclude(jti=current_jti)
137+
138+
for token in outstanding_tokens:
139+
(
140+
tb_models.BlacklistedToken.objects.get_or_create(
141+
token=token,
142+
)
143+
)
144+
145+
data['token_version'] = self.user.token_version
101146
return data
102147

103-
def get_token(self):
104-
user = self.validated_data['user']
105-
refresh = rest_framework_simplejwt.tokens.RefreshToken.for_user(user)
106-
return {'token': str(refresh.access_token)}
148+
def get_token(self, user):
149+
token = super().get_token(user)
150+
token['token_version'] = user.token_version
151+
return token

0 commit comments

Comments
 (0)