1
1
import re
2
2
3
3
import django .db .models
4
- import pycountry
5
- import rest_framework .exceptions
6
4
import rest_framework .generics
7
5
import rest_framework .permissions
8
6
import rest_framework .response
9
7
import rest_framework .serializers
10
8
import rest_framework .status
11
- import rest_framework .views
12
9
import rest_framework_simplejwt .exceptions
13
10
import rest_framework_simplejwt .tokens
14
11
import rest_framework_simplejwt .views
@@ -155,14 +152,21 @@ class CompanyPromoListView(rest_framework.generics.ListAPIView):
155
152
serializer_class = business .serializers .PromoReadOnlySerializer
156
153
pagination_class = business .pagination .CustomLimitOffsetPagination
157
154
155
+ def initial (self , request , * args , ** kwargs ):
156
+ super ().initial (request , * args , ** kwargs )
157
+
158
+ serializer = business .serializers .PromoListQuerySerializer (
159
+ data = request .query_params ,
160
+ )
161
+ serializer .is_valid (raise_exception = True )
162
+ request .validated_query_params = serializer .validated_data
163
+
158
164
def get_queryset (self ):
165
+ params = self .request .validated_query_params
166
+ countries = [c .upper () for c in params .get ('countries' , [])]
167
+ sort_by = params .get ('sort_by' )
168
+
159
169
queryset = business .models .Promo .objects .for_company (self .request .user )
160
- countries = [
161
- country .strip ()
162
- for group in self .request .query_params .getlist ('country' , [])
163
- for country in group .split (',' )
164
- if country .strip ()
165
- ]
166
170
167
171
if countries :
168
172
regex_pattern = r'(' + '|' .join (map (re .escape , countries )) + ')'
@@ -171,115 +175,9 @@ def get_queryset(self):
171
175
| django .db .models .Q (target__country__isnull = True ),
172
176
)
173
177
174
- sort_by = self .request .query_params .get ('sort_by' )
175
- if sort_by in ['active_from' , 'active_until' ]:
176
- queryset = queryset .order_by (f'-{ sort_by } ' )
177
- else :
178
- queryset = queryset .order_by ('-created_at' ) # noqa: R504
179
-
180
- return queryset # noqa: R504
181
-
182
- def list (self , request , * args , ** kwargs ):
183
- try :
184
- self .validate_query_params ()
185
- except rest_framework .exceptions .ValidationError as e :
186
- return rest_framework .response .Response (
187
- e .detail ,
188
- status = rest_framework .status .HTTP_400_BAD_REQUEST ,
189
- )
190
-
191
- return super ().list (request , * args , ** kwargs )
192
-
193
- def validate_query_params (self ):
194
- self ._validate_allowed_params ()
195
- errors = {}
196
- self ._validate_countries (errors )
197
- self ._validate_sort_by (errors )
198
- self ._validate_offset ()
199
- self ._validate_limit ()
200
- if errors :
201
- raise rest_framework .exceptions .ValidationError (errors )
202
-
203
- def _validate_allowed_params (self ):
204
- allowed_params = {'country' , 'limit' , 'offset' , 'sort_by' }
205
- unexpected_params = (
206
- set (self .request .query_params .keys ()) - allowed_params
207
- )
208
-
209
- if unexpected_params :
210
- raise rest_framework .exceptions .ValidationError ('Invalid params.' )
211
-
212
- def _validate_countries (self , errors ):
213
- countries = self .request .query_params .getlist ('country' , [])
214
- country_list = []
215
-
216
- for country_group in countries :
217
- parts = [part .strip () for part in country_group .split (',' )]
218
-
219
- if any (part == '' for part in parts ):
220
- raise rest_framework .exceptions .ValidationError (
221
- 'Invalid country format.' ,
222
- )
223
-
224
- country_list .extend (parts )
225
-
226
- country_list = [c .strip ().upper () for c in country_list if c .strip ()]
227
-
228
- invalid_countries = []
229
-
230
- for code in country_list :
231
- if len (code ) != 2 :
232
- invalid_countries .append (code )
233
- continue
234
-
235
- try :
236
- pycountry .countries .lookup (code )
237
- except LookupError :
238
- invalid_countries .append (code )
239
-
240
- if invalid_countries :
241
- errors ['country' ] = (
242
- f'Invalid country codes: { ", " .join (invalid_countries )} '
243
- )
244
-
245
- def _validate_sort_by (self , errors ):
246
- sort_by = self .request .query_params .get ('sort_by' )
247
- if sort_by and sort_by not in ['active_from' , 'active_until' ]:
248
- errors ['sort_by' ] = (
249
- 'Invalid sort_by parameter. '
250
- 'Available values: active_from, active_until'
251
- )
178
+ ordering = f'-{ sort_by } ' if sort_by else '-created_at'
252
179
253
- def _validate_offset (self ):
254
- offset = self .request .query_params .get ('offset' )
255
- if offset is not None :
256
- try :
257
- offset = int (offset )
258
- except (TypeError , ValueError ):
259
- raise rest_framework .exceptions .ValidationError (
260
- 'Invalid offset format.' ,
261
- )
262
-
263
- if offset < 0 :
264
- raise rest_framework .exceptions .ValidationError (
265
- 'Offset cannot be negative.' ,
266
- )
267
-
268
- def _validate_limit (self ):
269
- limit = self .request .query_params .get ('limit' )
270
-
271
- if limit is not None :
272
- try :
273
- limit = int (limit )
274
- except (TypeError , ValueError ):
275
- raise rest_framework .exceptions .ValidationError (
276
- 'Invalid limit format.' ,
277
- )
278
-
279
- if limit < 0 :
280
- raise rest_framework .exceptions .ValidationError (
281
- 'Limit cannot be negative.' ,
282
- )
180
+ return queryset .order_by (ordering )
283
181
284
182
285
183
class CompanyPromoDetailView (rest_framework .generics .RetrieveUpdateAPIView ):
0 commit comments