1
1
import uuid
2
2
3
3
import django .contrib .auth .password_validation
4
- import django .core .exceptions
5
4
import django .core .validators
5
+ import django .db .transaction
6
6
import pycountry
7
7
import rest_framework .exceptions
8
8
import rest_framework .serializers
11
11
import rest_framework_simplejwt .tokens
12
12
13
13
import business .constants
14
+ import business .models
14
15
import business .models as business_models
16
+ import business .utils .auth
17
+ import business .utils .tokens
15
18
import business .validators
16
19
17
20
@@ -21,9 +24,9 @@ class CompanySignUpSerializer(rest_framework.serializers.ModelSerializer):
21
24
write_only = True ,
22
25
required = True ,
23
26
validators = [django .contrib .auth .password_validation .validate_password ],
27
+ style = {'input_type' : 'password' },
24
28
min_length = business .constants .COMPANY_PASSWORD_MIN_LENGTH ,
25
29
max_length = business .constants .COMPANY_PASSWORD_MAX_LENGTH ,
26
- style = {'input_type' : 'password' },
27
30
)
28
31
name = rest_framework .serializers .CharField (
29
32
required = True ,
@@ -44,30 +47,18 @@ class CompanySignUpSerializer(rest_framework.serializers.ModelSerializer):
44
47
45
48
class Meta :
46
49
model = business_models .Company
47
- fields = (
48
- 'id' ,
49
- 'name' ,
50
- 'email' ,
51
- 'password' ,
52
- )
50
+ fields = ('id' , 'name' , 'email' , 'password' )
53
51
52
+ @django .db .transaction .atomic
54
53
def create (self , validated_data ):
55
- try :
56
- company = business_models .Company .objects .create_company (
57
- email = validated_data ['email' ],
58
- name = validated_data ['name' ],
59
- password = validated_data ['password' ],
60
- )
61
- company .token_version += 1
62
- company .save ()
63
- return company
64
- except django .core .exceptions .ValidationError as e :
65
- raise rest_framework .serializers .ValidationError (e .messages )
54
+ company = business_models .Company .objects .create_company (
55
+ ** validated_data ,
56
+ )
66
57
58
+ return business .utils .auth .bump_company_token_version (company )
67
59
68
- class CompanySignInSerializer (
69
- rest_framework .serializers .Serializer ,
70
- ):
60
+
61
+ class CompanySignInSerializer (rest_framework .serializers .Serializer ):
71
62
email = rest_framework .serializers .EmailField (required = True )
72
63
password = rest_framework .serializers .CharField (
73
64
required = True ,
@@ -80,16 +71,15 @@ def validate(self, attrs):
80
71
password = attrs .get ('password' )
81
72
82
73
if not email or not password :
83
- raise rest_framework .exceptions .ValidationError (
84
- {'detail' : 'Both email and password are required' },
85
- code = 'required' ,
74
+ raise rest_framework .serializers .ValidationError (
75
+ 'Both email and password are required.' ,
86
76
)
87
77
88
78
try :
89
79
company = business_models .Company .objects .get (email = email )
90
80
except business_models .Company .DoesNotExist :
91
81
raise rest_framework .serializers .ValidationError (
92
- 'Invalid credentials' ,
82
+ 'Invalid credentials. ' ,
93
83
)
94
84
95
85
if not company .is_active or not company .check_password (password ):
@@ -98,53 +88,55 @@ def validate(self, attrs):
98
88
code = 'authentication_failed' ,
99
89
)
100
90
91
+ attrs ['company' ] = company
101
92
return attrs
102
93
103
94
104
95
class CompanyTokenRefreshSerializer (
105
96
rest_framework_simplejwt .serializers .TokenRefreshSerializer ,
106
97
):
107
98
def validate (self , attrs ):
99
+ attrs = super ().validate (attrs )
108
100
refresh = rest_framework_simplejwt .tokens .RefreshToken (
109
101
attrs ['refresh' ],
110
102
)
111
- user_type = refresh .payload .get ('user_type' , 'user' )
103
+ company = self .get_active_company_from_token (refresh )
104
+
105
+ company = business .utils .auth .bump_company_token_version (company )
112
106
113
- if user_type != 'company' :
107
+ return business .utils .tokens .generate_company_tokens (company )
108
+
109
+ def get_active_company_from_token (self , token ):
110
+ if token .payload .get ('user_type' ) != 'company' :
114
111
raise rest_framework_simplejwt .exceptions .InvalidToken (
115
112
'This refresh endpoint is for company tokens only' ,
116
113
)
117
114
118
- company_id = refresh .payload .get ('company_id' )
119
- if not company_id :
115
+ company_id = token .payload .get ('company_id' )
116
+ try :
117
+ company_uuid = uuid .UUID (company_id )
118
+ except (TypeError , ValueError ):
120
119
raise rest_framework_simplejwt .exceptions .InvalidToken (
121
- 'Company ID missing in token' ,
120
+ 'Invalid or missing company_id in token' ,
122
121
)
123
122
124
123
try :
125
- company = business_models .Company .objects .get (
126
- id = uuid .UUID (company_id ),
124
+ company = business .models .Company .objects .get (
125
+ id = company_uuid ,
126
+ is_active = True ,
127
127
)
128
- except business_models .Company .DoesNotExist :
128
+ except business . models .Company .DoesNotExist :
129
129
raise rest_framework_simplejwt .exceptions .InvalidToken (
130
- 'Company not found' ,
130
+ 'Company not found or inactive ' ,
131
131
)
132
132
133
- token_version = refresh .payload .get ('token_version' , 0 )
133
+ token_version = token .payload .get ('token_version' , 0 )
134
134
if company .token_version != token_version :
135
135
raise rest_framework_simplejwt .exceptions .InvalidToken (
136
136
'Token is blacklisted' ,
137
137
)
138
138
139
- new_refresh = rest_framework_simplejwt .tokens .RefreshToken ()
140
- new_refresh ['user_type' ] = 'company'
141
- new_refresh ['company_id' ] = str (company .id )
142
- new_refresh ['token_version' ] = company .token_version
143
-
144
- return {
145
- 'access' : str (new_refresh .access_token ),
146
- 'refresh' : str (new_refresh ),
147
- }
139
+ return company
148
140
149
141
150
142
class TargetSerializer (rest_framework .serializers .Serializer ):
0 commit comments