Skip to content

Commit 09472fb

Browse files
refactor: Refactor query parameter handling and pagination
- Added PromoListQuerySerializer to validate query parameters. - CompanyPromoListView: moved query parameter checking to PromoListQuerySerializer. - CustomLimitOffsetPagination: removed unreachable max_limit check, dropped try/except in get_limit since serializer now guarantees numeric input.
1 parent dba090d commit 09472fb

File tree

3 files changed

+138
-127
lines changed

3 files changed

+138
-127
lines changed

promo_code/business/pagination.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,15 @@ class CustomLimitOffsetPagination(
99
max_limit = 100
1010

1111
def get_limit(self, request):
12-
param_limit = request.query_params.get(self.limit_query_param)
13-
if param_limit is not None:
14-
limit = int(param_limit)
12+
raw_limit = request.query_params.get(self.limit_query_param)
1513

16-
if limit == 0:
17-
return 0
14+
if raw_limit is None:
15+
return self.default_limit
1816

19-
if self.max_limit:
20-
return min(limit, self.max_limit)
17+
limit = int(raw_limit)
2118

22-
return limit
23-
24-
return self.default_limit
19+
# Allow 0, otherwise cut by max_limit
20+
return 0 if limit == 0 else min(limit, self.max_limit)
2521

2622
def get_paginated_response(self, data):
2723
response = rest_framework.response.Response(data)

promo_code/business/serializers.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,123 @@ def to_representation(self, instance):
282282
return data
283283

284284

285+
class PromoListQuerySerializer(rest_framework.serializers.Serializer):
286+
"""
287+
Serializer for validating query parameters of promo list requests.
288+
"""
289+
290+
limit = rest_framework.serializers.CharField(
291+
required=False,
292+
allow_blank=True,
293+
)
294+
offset = rest_framework.serializers.CharField(
295+
required=False,
296+
allow_blank=True,
297+
)
298+
sort_by = rest_framework.serializers.ChoiceField(
299+
choices=['active_from', 'active_until'],
300+
required=False,
301+
)
302+
country = rest_framework.serializers.CharField(
303+
required=False,
304+
allow_blank=True,
305+
)
306+
307+
_allowed_params = None
308+
309+
def get_allowed_params(self):
310+
if self._allowed_params is None:
311+
self._allowed_params = set(self.fields.keys())
312+
return self._allowed_params
313+
314+
def validate(self, attrs):
315+
query_params = self.initial_data
316+
allowed_params = self.get_allowed_params()
317+
318+
unexpected_params = set(query_params.keys()) - allowed_params
319+
if unexpected_params:
320+
raise rest_framework.exceptions.ValidationError('Invalid params.')
321+
322+
field_errors = {}
323+
324+
attrs = self._validate_int_field('limit', attrs, field_errors)
325+
attrs = self._validate_int_field('offset', attrs, field_errors)
326+
327+
self._validate_country(query_params, attrs, field_errors)
328+
329+
if field_errors:
330+
raise rest_framework.exceptions.ValidationError(field_errors)
331+
332+
return attrs
333+
334+
def _validate_int_field(self, field_name, attrs, field_errors):
335+
value_str = self.initial_data.get(field_name)
336+
if value_str is None:
337+
return attrs
338+
339+
if value_str == '':
340+
raise rest_framework.exceptions.ValidationError(
341+
f'Invalid {field_name} format.',
342+
)
343+
344+
try:
345+
value_int = int(value_str)
346+
if value_int < 0:
347+
raise rest_framework.exceptions.ValidationError(
348+
f'{field_name.capitalize()} cannot be negative.',
349+
)
350+
attrs[field_name] = value_int
351+
except (ValueError, TypeError):
352+
raise rest_framework.exceptions.ValidationError(
353+
f'Invalid {field_name} format.',
354+
)
355+
356+
return attrs
357+
358+
def _validate_country(self, query_params, attrs, field_errors):
359+
countries_raw = query_params.getlist('country', [])
360+
361+
if '' in countries_raw:
362+
raise rest_framework.exceptions.ValidationError(
363+
'Invalid country format.',
364+
)
365+
366+
country_codes = []
367+
invalid_codes = []
368+
369+
for country_group in countries_raw:
370+
if not country_group.strip():
371+
continue
372+
373+
parts = [part.strip() for part in country_group.split(',')]
374+
375+
if '' in parts:
376+
raise rest_framework.exceptions.ValidationError(
377+
'Invalid country format.',
378+
)
379+
380+
country_codes.extend(parts)
381+
382+
country_codes_upper = [c.upper() for c in country_codes]
383+
384+
for code in country_codes_upper:
385+
if len(code) != 2:
386+
invalid_codes.append(code)
387+
continue
388+
try:
389+
pycountry.countries.lookup(code)
390+
except LookupError:
391+
invalid_codes.append(code)
392+
393+
if invalid_codes:
394+
field_errors['country'] = (
395+
f'Invalid country codes: {", ".join(invalid_codes)}'
396+
)
397+
398+
attrs['countries'] = country_codes
399+
attrs.pop('country', None)
400+
401+
285402
class PromoReadOnlySerializer(rest_framework.serializers.ModelSerializer):
286403
promo_id = rest_framework.serializers.UUIDField(
287404
source='id',

promo_code/business/views.py

Lines changed: 15 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
11
import re
22

33
import django.db.models
4-
import pycountry
5-
import rest_framework.exceptions
64
import rest_framework.generics
75
import rest_framework.permissions
86
import rest_framework.response
97
import rest_framework.serializers
108
import rest_framework.status
11-
import rest_framework.views
129
import rest_framework_simplejwt.exceptions
1310
import rest_framework_simplejwt.tokens
1411
import rest_framework_simplejwt.views
@@ -155,14 +152,21 @@ class CompanyPromoListView(rest_framework.generics.ListAPIView):
155152
serializer_class = business.serializers.PromoReadOnlySerializer
156153
pagination_class = business.pagination.CustomLimitOffsetPagination
157154

155+
def initial(self, request, *args, **kwargs):
156+
super().initial(request, *args, **kwargs)
157+
158+
serializer = business.serializers.PromoListQuerySerializer(
159+
data=request.query_params,
160+
)
161+
serializer.is_valid(raise_exception=True)
162+
request.validated_query_params = serializer.validated_data
163+
158164
def get_queryset(self):
165+
params = self.request.validated_query_params
166+
countries = [c.upper() for c in params.get('countries', [])]
167+
sort_by = params.get('sort_by')
168+
159169
queryset = business.models.Promo.objects.for_company(self.request.user)
160-
countries = [
161-
country.strip()
162-
for group in self.request.query_params.getlist('country', [])
163-
for country in group.split(',')
164-
if country.strip()
165-
]
166170

167171
if countries:
168172
regex_pattern = r'(' + '|'.join(map(re.escape, countries)) + ')'
@@ -171,115 +175,9 @@ def get_queryset(self):
171175
| django.db.models.Q(target__country__isnull=True),
172176
)
173177

174-
sort_by = self.request.query_params.get('sort_by')
175-
if sort_by in ['active_from', 'active_until']:
176-
queryset = queryset.order_by(f'-{sort_by}')
177-
else:
178-
queryset = queryset.order_by('-created_at') # noqa: R504
179-
180-
return queryset # noqa: R504
181-
182-
def list(self, request, *args, **kwargs):
183-
try:
184-
self.validate_query_params()
185-
except rest_framework.exceptions.ValidationError as e:
186-
return rest_framework.response.Response(
187-
e.detail,
188-
status=rest_framework.status.HTTP_400_BAD_REQUEST,
189-
)
190-
191-
return super().list(request, *args, **kwargs)
192-
193-
def validate_query_params(self):
194-
self._validate_allowed_params()
195-
errors = {}
196-
self._validate_countries(errors)
197-
self._validate_sort_by(errors)
198-
self._validate_offset()
199-
self._validate_limit()
200-
if errors:
201-
raise rest_framework.exceptions.ValidationError(errors)
202-
203-
def _validate_allowed_params(self):
204-
allowed_params = {'country', 'limit', 'offset', 'sort_by'}
205-
unexpected_params = (
206-
set(self.request.query_params.keys()) - allowed_params
207-
)
208-
209-
if unexpected_params:
210-
raise rest_framework.exceptions.ValidationError('Invalid params.')
211-
212-
def _validate_countries(self, errors):
213-
countries = self.request.query_params.getlist('country', [])
214-
country_list = []
215-
216-
for country_group in countries:
217-
parts = [part.strip() for part in country_group.split(',')]
218-
219-
if any(part == '' for part in parts):
220-
raise rest_framework.exceptions.ValidationError(
221-
'Invalid country format.',
222-
)
223-
224-
country_list.extend(parts)
225-
226-
country_list = [c.strip().upper() for c in country_list if c.strip()]
227-
228-
invalid_countries = []
229-
230-
for code in country_list:
231-
if len(code) != 2:
232-
invalid_countries.append(code)
233-
continue
234-
235-
try:
236-
pycountry.countries.lookup(code)
237-
except LookupError:
238-
invalid_countries.append(code)
239-
240-
if invalid_countries:
241-
errors['country'] = (
242-
f'Invalid country codes: {", ".join(invalid_countries)}'
243-
)
244-
245-
def _validate_sort_by(self, errors):
246-
sort_by = self.request.query_params.get('sort_by')
247-
if sort_by and sort_by not in ['active_from', 'active_until']:
248-
errors['sort_by'] = (
249-
'Invalid sort_by parameter. '
250-
'Available values: active_from, active_until'
251-
)
178+
ordering = f'-{sort_by}' if sort_by else '-created_at'
252179

253-
def _validate_offset(self):
254-
offset = self.request.query_params.get('offset')
255-
if offset is not None:
256-
try:
257-
offset = int(offset)
258-
except (TypeError, ValueError):
259-
raise rest_framework.exceptions.ValidationError(
260-
'Invalid offset format.',
261-
)
262-
263-
if offset < 0:
264-
raise rest_framework.exceptions.ValidationError(
265-
'Offset cannot be negative.',
266-
)
267-
268-
def _validate_limit(self):
269-
limit = self.request.query_params.get('limit')
270-
271-
if limit is not None:
272-
try:
273-
limit = int(limit)
274-
except (TypeError, ValueError):
275-
raise rest_framework.exceptions.ValidationError(
276-
'Invalid limit format.',
277-
)
278-
279-
if limit < 0:
280-
raise rest_framework.exceptions.ValidationError(
281-
'Limit cannot be negative.',
282-
)
180+
return queryset.order_by(ordering)
283181

284182

285183
class CompanyPromoDetailView(rest_framework.generics.RetrieveUpdateAPIView):

0 commit comments

Comments
 (0)