Skip to content

Commit 656dcd4

Browse files
feat: Add the ability to update the token for the company.
1 parent 0b8b083 commit 656dcd4

File tree

6 files changed

+55
-3
lines changed

6 files changed

+55
-3
lines changed

promo_code/business/serializers.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@
66
import rest_framework.exceptions
77
import rest_framework.serializers
88
import rest_framework.status
9+
import rest_framework_simplejwt.exceptions
10+
import rest_framework_simplejwt.serializers
11+
import rest_framework_simplejwt.tokens
12+
import rest_framework_simplejwt.views
913

1014

1115
class CompanySignUpSerializer(rest_framework.serializers.ModelSerializer):
@@ -90,3 +94,39 @@ def validate(self, attrs):
9094
)
9195

9296
return attrs
97+
98+
99+
class CompanyTokenRefreshSerializer(
100+
rest_framework_simplejwt.serializers.TokenRefreshSerializer,
101+
):
102+
def validate(self, attrs):
103+
refresh = rest_framework_simplejwt.tokens.RefreshToken(
104+
attrs['refresh'],
105+
)
106+
user_type = refresh.payload.get('user_type', 'user')
107+
108+
if user_type != 'company':
109+
raise rest_framework_simplejwt.exceptions.InvalidToken(
110+
'This refresh endpoint is for company tokens only',
111+
)
112+
113+
company_id = refresh.payload.get('company_id')
114+
if not company_id:
115+
raise rest_framework_simplejwt.exceptions.InvalidToken(
116+
'Company ID missing in token',
117+
)
118+
119+
try:
120+
company = business_models.Company.objects.get(id=company_id)
121+
except business_models.Company.DoesNotExist:
122+
raise rest_framework_simplejwt.exceptions.InvalidToken(
123+
'Company not found',
124+
)
125+
126+
token_version = refresh.payload.get('token_version', 0)
127+
if company.token_version != token_version:
128+
raise rest_framework_simplejwt.exceptions.InvalidToken(
129+
'Token is blacklisted',
130+
)
131+
132+
return super().validate(attrs)

promo_code/business/tests/auth/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,12 @@ class BaseBusinessAuthTestCase(rest_framework.test.APITestCase):
1010
def setUpTestData(cls):
1111
super().setUpTestData()
1212
cls.client = rest_framework.test.APIClient()
13+
cls.company_refresh_url = django.urls.reverse(
14+
'api-business:company-token-refresh',
15+
)
16+
cls.protected_url = django.urls.reverse('api-core:protected')
1317
cls.signup_url = django.urls.reverse('api-business:company-sign-up')
1418
cls.signin_url = django.urls.reverse('api-business:company-sign-in')
15-
cls.protected_url = django.urls.reverse('api-core:protected')
1619
cls.valid_data = {
1720
'name': 'Digital Marketing Solutions Inc.',
1821
'email': '[email protected]',

promo_code/business/urls.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,9 @@
1515
business.views.CompanySignInView.as_view(),
1616
name='company-sign-in',
1717
),
18+
django.urls.path(
19+
'token/refresh',
20+
business.views.CompanyTokenRefreshView.as_view(),
21+
name='company-token-refresh',
22+
),
1823
]

promo_code/business/views.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,7 @@ def post(self, request):
9797
response_data,
9898
status=rest_framework.status.HTTP_200_OK,
9999
)
100+
101+
102+
class CompanyTokenRefreshView(rest_framework_simplejwt.views.TokenRefreshView):
103+
serializer_class = business.serializers.CompanyTokenRefreshSerializer

promo_code/user/tests/auth/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def setUpTestData(cls):
1111
super().setUpTestData()
1212
cls.client = rest_framework.test.APIClient()
1313
cls.protected_url = django.urls.reverse('api-core:protected')
14-
cls.refresh_url = django.urls.reverse('api-user:token_refresh')
14+
cls.refresh_url = django.urls.reverse('api-user:user-token-refresh')
1515
cls.signup_url = django.urls.reverse('api-user:sign-up')
1616
cls.signin_url = django.urls.reverse('api-user:sign-in')
1717

promo_code/user/urls.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,6 @@
2020
django.urls.path(
2121
'token/refresh/',
2222
rest_framework_simplejwt.views.TokenRefreshView.as_view(),
23-
name='token_refresh',
23+
name='user-token-refresh',
2424
),
2525
]

0 commit comments

Comments
 (0)