Skip to content

Commit 5ad3fbf

Browse files
test: Add tests for updating the refresh token for the company.
1 parent 656dcd4 commit 5ad3fbf

File tree

1 file changed

+193
-0
lines changed

1 file changed

+193
-0
lines changed

promo_code/business/tests/auth/test_tokens.py

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
import business.tests.auth.base
33
import rest_framework.status
44
import rest_framework.test
5+
import rest_framework_simplejwt.tokens
6+
7+
import user.models
58

69

710
class JWTTests(business.tests.auth.base.BaseBusinessAuthTestCase):
@@ -82,3 +85,193 @@ def test_registration_token_invalid_after_login(self):
8285
response.status_code,
8386
rest_framework.status.HTTP_200_OK,
8487
)
88+
89+
90+
class TestCompanyTokenRefresh(
91+
business.tests.auth.base.BaseBusinessAuthTestCase,
92+
):
93+
def setUp(self):
94+
super().setUp()
95+
96+
self.company = business.models.Company.objects.create_company(
97+
name='Digital Marketing Solutions Inc.',
98+
99+
password='SuperStrongPassword2000!',
100+
token_version=1,
101+
)
102+
103+
self.company_data = {
104+
'email': '[email protected]',
105+
'password': 'SuperStrongPassword2000!',
106+
}
107+
108+
self.company_refresh = rest_framework_simplejwt.tokens.RefreshToken()
109+
self.company_refresh.payload.update(
110+
{
111+
'user_type': 'company',
112+
'company_id': self.company.id,
113+
'token_version': self.company.token_version,
114+
},
115+
)
116+
117+
self.user = user.models.User.objects.create_user(
118+
119+
name='Steve',
120+
surname='Jobs',
121+
password='SuperStrongPassword2000!',
122+
other={'age': 23, 'country': 'gb'},
123+
)
124+
self.user_refresh = (
125+
rest_framework_simplejwt.tokens.RefreshToken.for_user(self.user)
126+
)
127+
self.user_refresh.payload['user_type'] = 'user'
128+
129+
def test_successful_company_token_refresh(self):
130+
response = self.client.post(
131+
self.company_refresh_url,
132+
{'refresh': str(self.company_refresh)},
133+
)
134+
135+
self.assertEqual(
136+
response.status_code,
137+
rest_framework.status.HTTP_200_OK,
138+
)
139+
self.assertIn('access', response.data)
140+
self.assertIn('refresh', response.data)
141+
142+
self.assertNotEqual(self.company_refresh, response.data['refresh'])
143+
144+
def test_reject_user_tokens(self):
145+
response = self.client.post(
146+
self.company_refresh_url,
147+
{'refresh': str(self.user_refresh)},
148+
)
149+
150+
self.assertEqual(
151+
response.status_code,
152+
rest_framework.status.HTTP_401_UNAUTHORIZED,
153+
)
154+
self.assertIn(
155+
'This refresh endpoint is for company tokens only',
156+
str(response.content),
157+
)
158+
159+
def test_token_version_mismatch(self):
160+
self.company.token_version = 2
161+
self.company.save()
162+
163+
response = self.client.post(
164+
self.company_refresh_url,
165+
{'refresh': str(self.company_refresh)},
166+
)
167+
168+
self.assertEqual(
169+
response.status_code,
170+
rest_framework.status.HTTP_401_UNAUTHORIZED,
171+
)
172+
self.assertIn('Token is blacklisted', str(response.content))
173+
174+
def test_missing_company_id(self):
175+
invalid_refresh = rest_framework_simplejwt.tokens.RefreshToken()
176+
invalid_refresh.payload.update(
177+
{'user_type': 'company', 'token_version': 1},
178+
)
179+
180+
response = self.client.post(
181+
self.company_refresh_url,
182+
{'refresh': str(invalid_refresh)},
183+
)
184+
185+
self.assertEqual(
186+
response.status_code,
187+
rest_framework.status.HTTP_401_UNAUTHORIZED,
188+
)
189+
self.assertIn(
190+
'Company ID missing in token',
191+
str(response.content.decode()),
192+
)
193+
194+
def test_company_not_found(self):
195+
invalid_refresh = rest_framework_simplejwt.tokens.RefreshToken()
196+
invalid_refresh.payload.update(
197+
{'user_type': 'company', 'company_id': 999, 'token_version': 1},
198+
)
199+
200+
response = self.client.post(
201+
self.company_refresh_url,
202+
{'refresh': str(invalid_refresh)},
203+
)
204+
205+
self.assertEqual(
206+
response.status_code,
207+
rest_framework.status.HTTP_401_UNAUTHORIZED,
208+
)
209+
self.assertIn('Company not found', str(response.content))
210+
211+
def test_refresh_token_invalidation_after_new_login(self):
212+
first_login_response = self.client.post(
213+
self.signin_url,
214+
self.company_data,
215+
format='json',
216+
)
217+
refresh_token_v1 = first_login_response.data['refresh']
218+
219+
second_login_response = self.client.post(
220+
self.signin_url,
221+
self.company_data,
222+
format='json',
223+
)
224+
refresh_token_v2 = second_login_response.data['refresh']
225+
226+
refresh_response_v1 = self.client.post(
227+
self.company_refresh_url,
228+
{'refresh': refresh_token_v1},
229+
format='json',
230+
)
231+
self.assertEqual(
232+
refresh_response_v1.status_code,
233+
rest_framework.status.HTTP_401_UNAUTHORIZED,
234+
)
235+
self.assertEqual(refresh_response_v1.data['code'], 'token_not_valid')
236+
self.assertEqual(
237+
str(refresh_response_v1.data['detail']),
238+
'Token is blacklisted',
239+
)
240+
241+
refresh_response_v2 = self.client.post(
242+
self.company_refresh_url,
243+
{'refresh': refresh_token_v2},
244+
format='json',
245+
)
246+
self.assertEqual(
247+
refresh_response_v2.status_code,
248+
rest_framework.status.HTTP_200_OK,
249+
)
250+
self.assertIn('access', refresh_response_v2.data)
251+
252+
self.client.credentials(
253+
HTTP_AUTHORIZATION='Bearer ' + first_login_response.data['access'],
254+
)
255+
protected_response = self.client.get(self.protected_url)
256+
self.assertEqual(
257+
protected_response.status_code,
258+
rest_framework.status.HTTP_401_UNAUTHORIZED,
259+
)
260+
261+
def test_default_user_type_handling(self):
262+
refresh = rest_framework_simplejwt.tokens.RefreshToken.for_user(
263+
self.user,
264+
)
265+
response = self.client.post(
266+
self.company_refresh_url,
267+
{'refresh': str(refresh)},
268+
)
269+
270+
self.assertEqual(
271+
response.status_code,
272+
rest_framework.status.HTTP_401_UNAUTHORIZED,
273+
)
274+
self.assertIn(
275+
'This refresh endpoint is for company tokens only',
276+
str(response.content),
277+
)

0 commit comments

Comments
 (0)