diff --git a/ecommerce/baskets/views.py b/ecommerce/baskets/views.py index ee32571..8ae2762 100644 --- a/ecommerce/baskets/views.py +++ b/ecommerce/baskets/views.py @@ -1,9 +1,17 @@ -from rest_framework import viewsets +from decimal import Decimal + +from django.db.transaction import atomic +from rest_framework import viewsets, status, mixins +from rest_framework.decorators import action +from rest_framework.permissions import AllowAny +from rest_framework.response import Response from baskets.filters import BasketItemFilter, BasketFilter from baskets.models import BasketItem, Basket -from baskets.serializers import BasketItemSerializer, BasketSerializer, BasketItemDetailedSerializer, BasketDetailedSerializer +from baskets.serializers import BasketItemSerializer, BasketSerializer, \ + BasketItemDetailedSerializer, BasketDetailedSerializer from core.mixins import DetailedViewSetMixin +from products.models import Product class BasketItemViewSet(DetailedViewSetMixin, viewsets.ModelViewSet): @@ -15,12 +23,77 @@ class BasketItemViewSet(DetailedViewSetMixin, viewsets.ModelViewSet): "detailed": BasketItemDetailedSerializer, } + def get_queryset(self): + queryset = super().get_queryset() + user_id = self.request.user.id + return queryset.filter(basket__customer__id=user_id) + class BasketViewSet(DetailedViewSetMixin, viewsets.ModelViewSet): + # lookup_field = 'basket' + permission_classes = () + http_method_names = ["get", "delete"] queryset = Basket.objects.all() serializer_class = BasketSerializer filterset_class = BasketFilter serializer_action_classes = { "detailed_list": BasketDetailedSerializer, "detailed": BasketDetailedSerializer, + "add_basket_item": BasketItemSerializer, + "get_basket_items": BasketItemSerializer, } + + def get_queryset(self): + queryset = super().get_queryset() + user = self.request.user + return queryset.filter(customer=user) + + @action(detail=False, methods=['get'], http_method_names=['get']) + def get_basket_items(self, request): + """ + An endpoint to show all items under the basket. + """ + user = self.request.user + basket_items_list = self.get_queryset().filter( + customer__id=user.id).first().basketitem_set.all() + serializer = self.get_serializer(basket_items_list, many=True) + return Response(serializer.data, status=status.HTTP_200_OK) + + @action(detail=True, methods=['post'], http_method_names=['post']) + @atomic() + def add_basket_item(self, request, pk=None): + """ + Adds an item to the basket and updates stocks. + """ + pk = None + user = self.request.user + + basket = self.get_object() + try: + product = Product.objects.get(id=request.data["product"]) + quantity = int(request.data["quantity"]) + price = Decimal(request.data["price"]) + except Exception as e: + return Response({'status': 'fail'}) + + stock = product.stock + if stock.quantity <= 0 or stock.quantity < quantity: + return Response({'status': 'fail'}) + + item_in_basket = BasketItem.objects.filter(basket=basket, + product=product, + price=price).first() + if item_in_basket: + item_in_basket.quantity += quantity + stock.quantity -= quantity + stock.save() + item_in_basket.save() + else: + new_item = BasketItem(basket=basket, product=product, + quantity=quantity, price=price) + stock.quantity -= quantity + stock.save() + new_item.save() + + serializer = BasketSerializer(basket) + return Response(serializer.data, status=status.HTTP_200_OK) diff --git a/ecommerce/customers/admin.py b/ecommerce/customers/admin.py index 3332089..71375fa 100644 --- a/ecommerce/customers/admin.py +++ b/ecommerce/customers/admin.py @@ -62,7 +62,7 @@ class AddressAdmin(admin.ModelAdmin): """ list_display = ("customer", "name", "city") list_filter = ("city",) - search_fields = ("line_1", "line_2", "city") + search_fields = ("line_1", "line_2", "city__name") # inlines = (CityInline, CountryInline) #TODO: Add inlines diff --git a/ecommerce/customers/serializers.py b/ecommerce/customers/serializers.py index ffad735..64904c7 100644 --- a/ecommerce/customers/serializers.py +++ b/ecommerce/customers/serializers.py @@ -1,7 +1,10 @@ +from django.contrib.auth.password_validation import validate_password +from django.core.validators import EmailValidator from django.utils.translation import gettext_lazy as _ from django.db.transaction import atomic from rest_framework import serializers from rest_framework.exceptions import ValidationError +from rest_framework.validators import UniqueValidator from customers.models import Customer, Address, City, Country @@ -68,3 +71,39 @@ class AddressDetailedSerializer(AddressSerializer): class CityDetailedSerializer(CitySerializer): country = CountrySerializer() + + + +class CustomerRegisterSerializer(serializers.ModelSerializer): + """ + To allow fast register, First Name and Last Name fields are not required. + """ + email = serializers.EmailField(required=True, + validators=[EmailValidator(), + UniqueValidator( + queryset=Customer.objects.all())]) + password = serializers.CharField(write_only=True, + required=True, + validators=[validate_password]) + password_repeat = serializers.CharField(write_only=True, required=True, ) + + class Meta: + model = Customer + fields = ("id", "email", "password", "password_repeat", "first_name", "last_name") + + def validate(self, attrs): + if attrs["password"] != attrs["password_repeat"]: + raise serializers.ValidationError(detail=_("Passwords must be identical")) + return attrs + + def create(self, validated_data): + customer = Customer.objects.create_user( + email=validated_data['email'], + first_name=validated_data.get('first_name', ""), + last_name=validated_data.get('last_name', ""), + ) + + customer.set_password(validated_data["password"]) + customer.save() + + return customer diff --git a/ecommerce/customers/views.py b/ecommerce/customers/views.py index e2d05b6..06808ec 100644 --- a/ecommerce/customers/views.py +++ b/ecommerce/customers/views.py @@ -1,5 +1,6 @@ from django.shortcuts import get_object_or_404 from rest_framework import viewsets, permissions, mixins +from rest_framework.permissions import AllowAny from rest_framework.viewsets import GenericViewSet from core.mixins import DetailedViewSetMixin @@ -8,7 +9,7 @@ from customers.models import Customer, Address, City, Country from customers.serializers import CustomerSerializer, AddressSerializer, CitySerializer, \ CountrySerializer, \ - AddressDetailedSerializer, CityDetailedSerializer, ProfileSerializer + AddressDetailedSerializer, CityDetailedSerializer, ProfileSerializer, CustomerRegisterSerializer class AdminCustomerViewSet(viewsets.ModelViewSet): @@ -65,4 +66,7 @@ def get_queryset(self): return queryset.filter(customer=user) - +class CustomerRegisterViewSet(mixins.CreateModelMixin, GenericViewSet): + queryset = Customer.objects.all() + permission_classes = (AllowAny,) + serializer_class = CustomerRegisterSerializer diff --git a/ecommerce/ecommerce/urls.py b/ecommerce/ecommerce/urls.py index 28278ef..05a8d9b 100644 --- a/ecommerce/ecommerce/urls.py +++ b/ecommerce/ecommerce/urls.py @@ -21,7 +21,7 @@ from baskets.views import BasketItemViewSet, BasketViewSet from core.views import APITokenObtainPairView from customers.views import AddressViewSet, CityViewSet, \ - CountryViewSet, AdminCustomerViewSet, MyProfileViewSet + CountryViewSet, AdminCustomerViewSet, MyProfileViewSet, CustomerRegisterViewSet from ecommerce.router import router from orders.views import OrderItemViewSet, OrderViewSet, BillingAddressViewSet, ShippingAddressViewSet, \ OrderBankAccountViewSet @@ -46,6 +46,7 @@ router.register("banks", BankViewSet) router.register("admin-products", AdminProductViewSet, basename="admin-product") router.register("admin-customers", AdminCustomerViewSet, basename="admin-customer") +router.register("register", CustomerRegisterViewSet, basename="register") urlpatterns = [