diff --git a/longclaw/basket/api.py b/longclaw/basket/api.py index 5e6a1dc..d2cee41 100644 --- a/longclaw/basket/api.py +++ b/longclaw/basket/api.py @@ -6,6 +6,9 @@ from longclaw.basket.serializers import BasketItemSerializer from longclaw.basket import utils from longclaw.utils import ProductVariant +from .signals import basket_modified + + class BasketViewSet(viewsets.ModelViewSet): """ Viewset for interacting with a sessions 'basket' - @@ -44,6 +47,8 @@ class BasketViewSet(viewsets.ModelViewSet): serializer = BasketItemSerializer(self.get_queryset(request), many=True) response = Response(data=serializer.data, status=status.HTTP_201_CREATED) + + basket_modified.send(sender=BasketItem, basket_id=bid) else: response = Response( @@ -66,24 +71,33 @@ class BasketViewSet(viewsets.ModelViewSet): serializer = BasketItemSerializer(self.get_queryset(request), many=True) response = Response(data=serializer.data, status=status.HTTP_200_OK) + + basket_modified.send(sender=BasketItem, basket_id=bid) + return response def destroy(self, request, variant_id=None): """ Remove an item from the basket """ + bid = utils.basket_id(request) + variant = ProductVariant.objects.get(id=variant_id) quantity = int(request.data.get("quantity", 1)) try: item = BasketItem.objects.get( - basket_id=utils.basket_id(request), variant=variant) + basket_id=bid, variant=variant) item.decrease_quantity(quantity) except BasketItem.DoesNotExist: pass serializer = BasketItemSerializer(self.get_queryset(request), many=True) - return Response(data=serializer.data, + response = Response(data=serializer.data, status=status.HTTP_200_OK) + + basket_modified.send(sender=BasketItem, basket_id=bid) + + return response @action(detail=False, methods=['get']) def total_items(self, request): diff --git a/longclaw/basket/signals.py b/longclaw/basket/signals.py new file mode 100644 index 0000000..9ecb6d9 --- /dev/null +++ b/longclaw/basket/signals.py @@ -0,0 +1,3 @@ +import django.dispatch + +basket_modified = django.dispatch.Signal(providing_args=['basket_id']) diff --git a/longclaw/basket/tests.py b/longclaw/basket/tests.py index aff5bb0..a6c4ac4 100644 --- a/longclaw/basket/tests.py +++ b/longclaw/basket/tests.py @@ -1,3 +1,4 @@ +import mock from django.test.client import RequestFactory from django.test import TestCase try: @@ -7,10 +8,13 @@ except ImportError: from django.core.management import call_command from django.utils.six import StringIO -from longclaw.tests.utils import LongclawTestCase, BasketItemFactory, ProductVariantFactory +from longclaw.tests.utils import LongclawTestCase, BasketItemFactory, ProductVariantFactory, catch_signal from longclaw.basket.utils import basket_id from longclaw.basket.templatetags import basket_tags from longclaw.basket.context_processors import stripe_key +from longclaw.basket.models import BasketItem + +from .signals import basket_modified class CommandTests(TestCase): @@ -83,6 +87,59 @@ class BasketTest(LongclawTestCase): self.assertIn('STRIPE_KEY', stripe_key(None)) +class BasketModifiedSignalTest(LongclawTestCase): + """Round trip API tests + """ + def setUp(self): + """Create a basket with things in it + """ + request = RequestFactory().get('/') + request.session = {} + bid = basket_id(request) + self.item = BasketItemFactory(basket_id=bid) + BasketItemFactory(basket_id=bid) + + def test_create_basket_item(self): + """ + Test creating a new basket item + """ + with catch_signal(basket_modified) as handler: + variant = ProductVariantFactory() + self.post_test({'variant_id': variant.id}, 'longclaw_basket_list') + + handler.assert_called_once_with( + basket_id=mock.ANY, # TODO: CHECK CORRECT BASKET ID IS SENT + sender=BasketItem, + signal=basket_modified, + ) + + def test_increase_basket_item(self): + """ + Test increasing quantity of basket item + """ + with catch_signal(basket_modified) as handler: + self.post_test({'variant_id': self.item.variant.id}, 'longclaw_basket_list') + + handler.assert_called_once_with( + basket_id=mock.ANY, # TODO: CHECK CORRECT BASKET ID IS SENT + sender=BasketItem, + signal=basket_modified, + ) + + def test_remove_item(self): + """ + Test removing an item from the basket + """ + with catch_signal(basket_modified) as handler: + self.del_test('longclaw_basket_detail', {'variant_id': self.item.variant.id}) + + handler.assert_called_once_with( + basket_id=mock.ANY, # TODO: CHECK CORRECT BASKET ID IS SENT + sender=BasketItem, + signal=basket_modified, + ) + + class BasketModelTest(TestCase): def setUp(self): diff --git a/longclaw/checkout/tests.py b/longclaw/checkout/tests.py index dc9fda4..91979ae 100644 --- a/longclaw/checkout/tests.py +++ b/longclaw/checkout/tests.py @@ -1,5 +1,8 @@ +import uuid +from django.utils.encoding import force_text from django.test import TestCase from django.test.client import RequestFactory +from wagtail.core.models import Site try: from django.urls import reverse_lazy except ImportError: @@ -12,6 +15,7 @@ from longclaw.tests.utils import ( CountryFactory, OrderFactory ) +from longclaw.shipping.models import ShippingRate from longclaw.checkout.utils import create_order from longclaw.checkout.forms import CheckoutForm from longclaw.checkout.views import CheckoutView @@ -80,6 +84,73 @@ class CheckoutApiTest(LongclawTestCase): self.get_test('longclaw_checkout_token') +class CheckoutApiShippingTest(LongclawTestCase): + def setUp(self): + self.shipping_address = AddressFactory() + self.billing_address = AddressFactory() + self.email = "test@test.com" + self.request = RequestFactory().get('/') + self.request.session = {} + self.request.site = Site.find_for_request(self.request) + self.basket_id = basket_id(self.request) + BasketItemFactory(basket_id=self.basket_id) + + def test_create_order_with_basket_shipping_option(self): + amount = 11 + rate = ShippingRate.objects.create( + name=force_text(uuid.uuid4()), + rate=amount, + carrier=force_text(uuid.uuid4()), + description=force_text(uuid.uuid4()), + basket_id=self.basket_id, + ) + order = create_order( + self.email, + self.request, + shipping_address=self.shipping_address, + billing_address=self.billing_address, + shipping_option=rate.name, + ) + self.assertEqual(order.shipping_rate, amount) + + def test_create_order_with_address_shipping_option(self): + amount = 12 + rate = ShippingRate.objects.create( + name=force_text(uuid.uuid4()), + rate=amount, + carrier=force_text(uuid.uuid4()), + description=force_text(uuid.uuid4()), + destination=self.shipping_address, + ) + order = create_order( + self.email, + self.request, + shipping_address=self.shipping_address, + billing_address=self.billing_address, + shipping_option=rate.name, + ) + self.assertEqual(order.shipping_rate, amount) + + def test_create_order_with_address_and_basket_shipping_option(self): + amount = 13 + rate = ShippingRate.objects.create( + name=force_text(uuid.uuid4()), + rate=amount, + carrier=force_text(uuid.uuid4()), + description=force_text(uuid.uuid4()), + destination=self.shipping_address, + basket_id=self.basket_id, + ) + order = create_order( + self.email, + self.request, + shipping_address=self.shipping_address, + billing_address=self.billing_address, + shipping_option=rate.name, + ) + self.assertEqual(order.shipping_rate, amount) + + class CheckoutTest(TestCase): def test_checkout_form(self): diff --git a/longclaw/checkout/utils.py b/longclaw/checkout/utils.py index 3158384..f17be96 100644 --- a/longclaw/checkout/utils.py +++ b/longclaw/checkout/utils.py @@ -22,7 +22,7 @@ def create_order(email, """ Create an order from a basket and customer infomation """ - basket_items, _ = get_basket_items(request) + basket_items, current_basket_id = get_basket_items(request) if addresses: # Longclaw < 0.2 used 'shipping_name', longclaw > 0.2 uses a consistent # prefix (shipping_address_xxxx) @@ -68,7 +68,10 @@ def create_order(email, shipping_rate = get_shipping_cost( site_settings, shipping_address.country.pk, - shipping_option)['rate'] + shipping_option, + basket_id=current_basket_id, + destination=shipping_address, + )['rate'] else: shipping_rate = Decimal(0) diff --git a/longclaw/configuration/migrations/0002_configuration_shipping_origin.py b/longclaw/configuration/migrations/0002_configuration_shipping_origin.py new file mode 100644 index 0000000..dc8f42b --- /dev/null +++ b/longclaw/configuration/migrations/0002_configuration_shipping_origin.py @@ -0,0 +1,20 @@ +# Generated by Django 2.1.7 on 2019-03-22 22:30 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('shipping', '0003_auto_20190322_1429'), + ('configuration', '0001_initial'), + ] + + operations = [ + migrations.AddField( + model_name='configuration', + name='shipping_origin', + field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.PROTECT, to='shipping.Address'), + ), + ] diff --git a/longclaw/configuration/models.py b/longclaw/configuration/models.py index fd914cd..9e7e216 100644 --- a/longclaw/configuration/models.py +++ b/longclaw/configuration/models.py @@ -3,8 +3,11 @@ Admin confiurable settings for longclaw apps """ from wagtail.contrib.settings.models import BaseSetting, register_setting from wagtail.admin.edit_handlers import FieldPanel +from wagtail.snippets.edit_handlers import SnippetChooserPanel from django.db import models +from longclaw.shipping.models import Address + @register_setting class Configuration(BaseSetting): @@ -24,6 +27,8 @@ class Configuration(BaseSetting): help_text=('Whether to enable default shipping.' ' This essentially means you ship to all countries,' ' not only those with configured shipping rates')) + + shipping_origin = models.ForeignKey(Address, blank=True, null=True, on_delete=models.PROTECT) currency_html_code = models.CharField( max_length=12, @@ -40,6 +45,8 @@ class Configuration(BaseSetting): FieldPanel('default_shipping_rate'), FieldPanel('default_shipping_carrier'), FieldPanel('default_shipping_enabled'), + SnippetChooserPanel('shipping_origin'), + FieldPanel('currency_html_code'), FieldPanel('currency_html_code'), FieldPanel('currency') ) diff --git a/longclaw/shipping/api.py b/longclaw/shipping/api.py index 41b3dbf..6a2a516 100644 --- a/longclaw/shipping/api.py +++ b/longclaw/shipping/api.py @@ -1,8 +1,13 @@ +from django.db.models import Q from rest_framework.decorators import api_view, permission_classes from rest_framework import permissions, status, viewsets from rest_framework.response import Response from longclaw.shipping import models, utils, serializers from longclaw.configuration.models import Configuration +from longclaw.basket.utils import basket_id + +from .models import ShippingRateProcessor +from .signals import address_modified class AddressViewSet(viewsets.ModelViewSet): """ @@ -10,6 +15,46 @@ class AddressViewSet(viewsets.ModelViewSet): """ queryset = models.Address.objects.all() serializer_class = serializers.AddressSerializer + + def perform_create(self, serializer): + output = super().perform_create(serializer) + instance = serializer.instance + address_modified.send(sender=models.Address, instance=instance) + + def perform_update(self, serializer): + output = super().perform_update(serializer) + instance = serializer.instance + address_modified.send(sender=models.Address, instance=instance) + + def perform_destroy(self, instance): + output = super().perform_destroy(instance) + address_modified.send(sender=models.Address, instance=instance) + + +def get_shipping_cost_kwargs(request, country=None): + country_code = request.query_params.get('country_code', None) + if country: + if country_code is not None: + raise utils.InvalidShippingCountry("Cannot specify country and country_code") + country_code = country + + destination = request.query_params.get('destination', None) + if destination: + try: + destination = models.Address.objects.get(pk=destination) + except models.Address.DoesNotExist: + raise utils.InvalidShippingDestination("Address not found") + elif not country_code: + raise utils.InvalidShippingCountry("No country code supplied") + + if not country_code: + country_code = destination.country.pk + + bid = basket_id(request) + option = request.query_params.get('shipping_rate_name', 'standard') + settings = Configuration.for_site(request.site) + + return dict(country_code=country_code, destination=destination, basket_id=bid, settings=settings, name=option) @api_view(['GET']) @@ -20,25 +65,26 @@ def shipping_cost(request): fallback to the default shipping cost if it has been enabled in the app settings """ + status_code = status.HTTP_400_BAD_REQUEST try: - code = request.query_params.get('country_code') - except AttributeError: - return Response(data={"message": "No country code supplied"}, - status=status.HTTP_400_BAD_REQUEST) + kwargs = get_shipping_cost_kwargs(request) + except (utils.InvalidShippingCountry, utils.InvalidShippingDestination) as e: + data = {'message': e.message} + else: + try: + data = utils.get_shipping_cost(**kwargs) + except utils.InvalidShippingRate: + data = { + "message": "Shipping option {} is invalid".format(kwargs['name']) + } + except utils.InvalidShippingCountry: + data = { + "message": "Shipping to {} is not available".format(kwargs['country_code']) + } + else: + status_code = status.HTTP_200_OK - option = request.query_params.get('shipping_rate_name', 'standard') - try: - settings = Configuration.for_site(request.site) - data = utils.get_shipping_cost(settings, code, option) - response = Response(data=data, status=status.HTTP_200_OK) - except utils.InvalidShippingRate: - response = Response(data={"message": "Shipping option {} is invalid".format(option)}, - status=status.HTTP_400_BAD_REQUEST) - except utils.InvalidShippingCountry: - response = Response(data={"message": "Shipping to {} is not available".format(code)}, - status=status.HTTP_400_BAD_REQUEST) - - return response + return Response(data=data, status=status_code) @api_view(["GET"]) @@ -52,11 +98,39 @@ def shipping_countries(request): @api_view(["GET"]) @permission_classes([permissions.AllowAny]) -def shipping_options(request, country): +def shipping_options(request, country=None): """ Get the shipping options for a given country """ - qrs = models.ShippingRate.objects.filter(countries__in=[country]) + try: + kwargs = get_shipping_cost_kwargs(request, country=country) + except (utils.InvalidShippingCountry, utils.InvalidShippingDestination) as e: + return Response(data={'message': e.message}, status=status.HTTP_400_BAD_REQUEST) + + country_code = kwargs['country_code'] + settings = kwargs['settings'] + bid = kwargs['basket_id'] + destination = kwargs['destination'] + + processors = ShippingRateProcessor.objects.filter(countries__in=[country_code]) + if processors: + if not destination: + return Response( + data={ + "message": "Destination address is required for rates to {}.".format(country_code) + }, + status=status.HTTP_400_BAD_REQUEST + ) + for processor in processors: + processor.get_rates(settings=settings, basket_id=bid, destination=destination) + + q = Q(countries__in=[country_code]) | Q(basket_id=bid, destination=None) + + if destination: + q.add(Q(destination=destination, basket_id=''), Q.OR) + q.add(Q(destination=destination, basket_id=bid), Q.OR) + + qrs = models.ShippingRate.objects.filter(q) serializer = serializers.ShippingRateSerializer(qrs, many=True) return Response( data=serializer.data, diff --git a/longclaw/shipping/migrations/0002_auto_20190318_1237.py b/longclaw/shipping/migrations/0002_auto_20190318_1237.py new file mode 100644 index 0000000..11ddeaa --- /dev/null +++ b/longclaw/shipping/migrations/0002_auto_20190318_1237.py @@ -0,0 +1,24 @@ +# Generated by Django 2.1.7 on 2019-03-18 17:37 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('shipping', '0001_initial'), + ] + + operations = [ + migrations.AddField( + model_name='shippingrate', + name='destination', + field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.PROTECT, to='shipping.Address'), + ), + migrations.AddField( + model_name='shippingrate', + name='basket_id', + field=models.CharField(blank=True, db_index=True, max_length=32), + ), + ] diff --git a/longclaw/shipping/migrations/0003_auto_20190322_1429.py b/longclaw/shipping/migrations/0003_auto_20190322_1429.py new file mode 100644 index 0000000..44f019b --- /dev/null +++ b/longclaw/shipping/migrations/0003_auto_20190322_1429.py @@ -0,0 +1,32 @@ +# Generated by Django 2.1.7 on 2019-03-22 19:29 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('contenttypes', '0002_remove_content_type_name'), + ('shipping', '0002_auto_20190318_1237'), + ] + + operations = [ + migrations.CreateModel( + name='ShippingRateProcessor', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('countries', models.ManyToManyField(to='shipping.Country')), + ('polymorphic_ctype', models.ForeignKey(editable=False, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='polymorphic_shipping.shippingrateprocessor_set+', to='contenttypes.ContentType')), + ], + options={ + 'base_manager_name': 'objects', + 'abstract': False, + }, + ), + migrations.AddField( + model_name='shippingrate', + name='processor', + field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.PROTECT, to='shipping.ShippingRateProcessor'), + ), + ] diff --git a/longclaw/shipping/models/__init__.py b/longclaw/shipping/models/__init__.py new file mode 100644 index 0000000..0db7149 --- /dev/null +++ b/longclaw/shipping/models/__init__.py @@ -0,0 +1,3 @@ +from .locations import * +from .processors import * +from .rates import * diff --git a/longclaw/shipping/models.py b/longclaw/shipping/models/locations.py similarity index 74% rename from longclaw/shipping/models.py rename to longclaw/shipping/models/locations.py index 5e58d25..86c09b6 100644 --- a/longclaw/shipping/models.py +++ b/longclaw/shipping/models/locations.py @@ -25,31 +25,6 @@ class Address(models.Model): def __str__(self): return "{}, {}, {}".format(self.name, self.city, self.country) -class ShippingRate(models.Model): - """ - An individual shipping rate. This can be applied to - multiple countries. - """ - name = models.CharField( - max_length=32, - unique=True, - help_text="Unique name to refer to this shipping rate by" - ) - rate = models.DecimalField(max_digits=12, decimal_places=2) - carrier = models.CharField(max_length=64) - description = models.CharField(max_length=128) - countries = models.ManyToManyField('shipping.Country') - - panels = [ - FieldPanel('name'), - FieldPanel('rate'), - FieldPanel('carrier'), - FieldPanel('description'), - FieldPanel('countries') - ] - - def __str__(self): - return self.name class Country(models.Model): """ diff --git a/longclaw/shipping/models/processors.py b/longclaw/shipping/models/processors.py new file mode 100644 index 0000000..2f62a45 --- /dev/null +++ b/longclaw/shipping/models/processors.py @@ -0,0 +1,67 @@ +import json +import hashlib + +from django.utils.encoding import force_bytes, force_text +from django.core.cache import cache +from django.core.serializers.json import DjangoJSONEncoder +from django.db import models, transaction +from django.dispatch import receiver + +from longclaw.basket.models import BasketItem +from longclaw.basket.signals import basket_modified +from polymorphic.models import PolymorphicModel +from wagtail.admin.edit_handlers import FieldPanel + +from ..serializers.locations import AddressSerializer +from ..signals import address_modified + + +class ShippingRateProcessor(PolymorphicModel): + countries = models.ManyToManyField('shipping.Country') + + rates_cache_timeout = 300 + def get_rates(self, settings=None, basket_id=None, destination=None): + kwargs = dict(settings=settings, basket_id=basket_id, destination=destination) + key = self.get_rates_cache_key(**kwargs) + rates = cache.get(key) + if rates is None: + with transaction.atomic(): + rates = self.process_rates(**kwargs) + if rates is not None: + cache.set(key, rates, self.rates_cache_timeout) + return rates + + def get_rates_cache_key(self, **kwargs): + from longclaw.basket.serializers import BasketItemSerializer + + settings = kwargs['settings'] + origin = settings.shipping_origin + destination = kwargs['destination'] + basket_id = kwargs['basket_id'] + + items = BasketItem.objects.filter(basket_id=basket_id) + serialized_items = BasketItemSerializer(items, many=True) + + serialized_origin = AddressSerializer(origin) or None + serialized_destination = AddressSerializer(destination) or None + + data = { + "items": serialized_items.data, + "origin": serialized_origin.data, + "destination": serialized_destination.data, + } + + raw_key = json.dumps( + data, + sort_keys=True, + indent=4, + separators=(',', ': '), + cls=DjangoJSONEncoder, + ) + + hashed_key = hashlib.sha1(force_bytes(raw_key)).hexdigest() + + return force_text(hashed_key) + + def process_rates(self, **kwargs): + raise NotImplementedError() diff --git a/longclaw/shipping/models/rates.py b/longclaw/shipping/models/rates.py new file mode 100644 index 0000000..5c8a24e --- /dev/null +++ b/longclaw/shipping/models/rates.py @@ -0,0 +1,47 @@ +from django.db import models +from django.dispatch import receiver + +from longclaw.basket.signals import basket_modified +from wagtail.admin.edit_handlers import FieldPanel + +from ..signals import address_modified + + +class ShippingRate(models.Model): + """ + An individual shipping rate. This can be applied to + multiple countries. + """ + name = models.CharField( + max_length=32, + unique=True, + help_text="Unique name to refer to this shipping rate by" + ) + rate = models.DecimalField(max_digits=12, decimal_places=2) + carrier = models.CharField(max_length=64) + description = models.CharField(max_length=128) + countries = models.ManyToManyField('shipping.Country') + basket_id = models.CharField(blank=True, db_index=True, max_length=32) + destination = models.ForeignKey('shipping.Address', blank=True, null=True, on_delete=models.PROTECT) + processor = models.ForeignKey('shipping.ShippingRateProcessor', blank=True, null=True, on_delete=models.PROTECT) + + panels = [ + FieldPanel('name'), + FieldPanel('rate'), + FieldPanel('carrier'), + FieldPanel('description'), + FieldPanel('countries') + ] + + def __str__(self): + return self.name + + +@receiver(address_modified) +def clear_address_rates(sender, instance, **kwargs): + ShippingRate.objects.filter(destination=instance).delete() + + +@receiver(basket_modified) +def clear_basket_rates(sender, basket_id, **kwargs): + ShippingRate.objects.filter(basket_id=basket_id).delete() diff --git a/longclaw/shipping/serializers/__init__.py b/longclaw/shipping/serializers/__init__.py new file mode 100644 index 0000000..db5e329 --- /dev/null +++ b/longclaw/shipping/serializers/__init__.py @@ -0,0 +1,2 @@ +from .locations import * +from .rates import * diff --git a/longclaw/shipping/serializers.py b/longclaw/shipping/serializers/locations.py similarity index 64% rename from longclaw/shipping/serializers.py rename to longclaw/shipping/serializers/locations.py index a1fc66a..3364b3e 100644 --- a/longclaw/shipping/serializers.py +++ b/longclaw/shipping/serializers/locations.py @@ -1,6 +1,6 @@ from rest_framework import serializers -from longclaw.shipping.models import Address, ShippingRate, Country +from longclaw.shipping.models.locations import Address, Country class AddressSerializer(serializers.ModelSerializer): country = serializers.PrimaryKeyRelatedField(queryset=Country.objects.all()) @@ -8,11 +8,6 @@ class AddressSerializer(serializers.ModelSerializer): model = Address fields = "__all__" -class ShippingRateSerializer(serializers.ModelSerializer): - class Meta: - model = ShippingRate - fields = "__all__" - class CountrySerializer(serializers.ModelSerializer): class Meta: model = Country diff --git a/longclaw/shipping/serializers/rates.py b/longclaw/shipping/serializers/rates.py new file mode 100644 index 0000000..89ea388 --- /dev/null +++ b/longclaw/shipping/serializers/rates.py @@ -0,0 +1,8 @@ +from rest_framework import serializers + +from longclaw.shipping.models.rates import ShippingRate + +class ShippingRateSerializer(serializers.ModelSerializer): + class Meta: + model = ShippingRate + fields = "__all__" diff --git a/longclaw/shipping/signals.py b/longclaw/shipping/signals.py new file mode 100644 index 0000000..feb97ad --- /dev/null +++ b/longclaw/shipping/signals.py @@ -0,0 +1,3 @@ +import django.dispatch + +address_modified = django.dispatch.Signal(providing_args=['instance']) diff --git a/longclaw/shipping/tests.py b/longclaw/shipping/tests.py index de90747..3bd3888 100644 --- a/longclaw/shipping/tests.py +++ b/longclaw/shipping/tests.py @@ -1,14 +1,133 @@ +import uuid +import mock +from decimal import Decimal + +from django.utils.encoding import force_text from django.test import TestCase +from django.test.client import RequestFactory from django.forms.models import model_to_dict -from longclaw.tests.utils import LongclawTestCase, AddressFactory, CountryFactory, ShippingRateFactory +from longclaw.tests.utils import LongclawTestCase, AddressFactory, CountryFactory, ShippingRateFactory, BasketItemFactory, catch_signal +from longclaw.shipping.api import get_shipping_cost_kwargs from longclaw.shipping.forms import AddressForm -from longclaw.shipping.utils import get_shipping_cost +from longclaw.shipping.models import Address, Country +from longclaw.shipping.utils import get_shipping_cost, InvalidShippingCountry from longclaw.shipping.templatetags import longclawshipping_tags from longclaw.configuration.models import Configuration +from longclaw.basket.signals import basket_modified +from longclaw.basket.utils import basket_id +from rest_framework import status +from rest_framework.views import APIView +from wagtail.core.models import Site + +from .models import Address, ShippingRate, clear_basket_rates, clear_address_rates, ShippingRateProcessor +from .signals import address_modified +from .serializers import AddressSerializer, ShippingRateSerializer + + +def upgrade_to_api_request(request): + # This extra step is required until https://github.com/encode/django-rest-framework/issues/6488 + # is resolved + class DummyGenericViewsetLike(APIView): + lookup_field = 'test' + + def reverse_action(view, *args, **kwargs): + self.assertEqual(kwargs['kwargs']['test'], 1) + return '/example/' + + response = DummyGenericViewsetLike.as_view()(request) + view = response.renderer_context['view'] + view.request.site = Site.objects.first() + return view.request + class ShippingTests(LongclawTestCase): def setUp(self): self.country = CountryFactory() + + def test_get_shipping_cost_kwargs_country_and_code(self): + request = RequestFactory().get('/', { 'country_code': 'US' }) + api_request = upgrade_to_api_request(request) + with self.assertRaises(InvalidShippingCountry): + get_shipping_cost_kwargs(api_request, country=self.country.pk) + + def test_get_shipping_cost_kwargs_destination_does_not_exist(self): + non_existant_pk = 2147483647 + self.assertFalse(Address.objects.filter(pk=non_existant_pk).exists()) + request = RequestFactory().get('/', { 'country_code': 'US', 'destination': str(non_existant_pk) }) + api_request = upgrade_to_api_request(request) + with self.assertRaises(InvalidShippingCountry): + get_shipping_cost_kwargs(api_request, country=self.country.pk) + + def test_get_shipping_cost_kwargs_no_country_or_code(self): + request = RequestFactory().get('/') + api_request = upgrade_to_api_request(request) + with self.assertRaises(InvalidShippingCountry): + get_shipping_cost_kwargs(api_request) + + def test_get_shipping_cost_kwargs_only_country_code(self): + request = RequestFactory().get('/', { 'country_code': 'US' }) + api_request = upgrade_to_api_request(request) + result = get_shipping_cost_kwargs(api_request) + self.assertEqual(result['country_code'], 'US') + self.assertEqual(result['destination'], None) + self.assertEqual(result['basket_id'], basket_id(api_request)) + self.assertEqual(result['settings'], Configuration.for_site(api_request.site)) + self.assertEqual(result['name'], 'standard') + + def test_get_shipping_cost_kwargs_country_code_and_shipping_rate_name(self): + request = RequestFactory().get('/', { 'country_code': 'US', 'shipping_rate_name': 'foo' }) + api_request = upgrade_to_api_request(request) + result = get_shipping_cost_kwargs(api_request) + self.assertEqual(result['country_code'], 'US') + self.assertEqual(result['destination'], None) + self.assertEqual(result['basket_id'], basket_id(api_request)) + self.assertEqual(result['settings'], Configuration.for_site(api_request.site)) + self.assertEqual(result['name'], 'foo') + + def test_get_shipping_cost_kwargs_only_country(self): + request = RequestFactory().get('/') + api_request = upgrade_to_api_request(request) + result = get_shipping_cost_kwargs(api_request, country=self.country.pk) + self.assertEqual(result['country_code'], self.country.pk) + self.assertEqual(result['destination'], None) + self.assertEqual(result['basket_id'], basket_id(api_request)) + self.assertEqual(result['settings'], Configuration.for_site(api_request.site)) + self.assertEqual(result['name'], 'standard') + + def test_get_shipping_cost_kwargs_only_country_known_iso(self): + request = RequestFactory().get('/') + api_request = upgrade_to_api_request(request) + country = Country.objects.create(iso='ZZ', name_official='foo', name='foo') + result = get_shipping_cost_kwargs(api_request, country=country.pk) + self.assertEqual(result['country_code'], 'ZZ') + self.assertEqual(result['destination'], None) + self.assertEqual(result['basket_id'], basket_id(api_request)) + self.assertEqual(result['settings'], Configuration.for_site(api_request.site)) + self.assertEqual(result['name'], 'standard') + + def test_get_shipping_cost_kwargs_with_destination(self): + destination = AddressFactory() + request = RequestFactory().get('/', { 'destination': destination.pk }) + api_request = upgrade_to_api_request(request) + result = get_shipping_cost_kwargs(api_request) + self.assertEqual(result['country_code'], destination.country.pk) + self.assertEqual(result['destination'], destination) + self.assertEqual(result['basket_id'], basket_id(api_request)) + self.assertEqual(result['settings'], Configuration.for_site(api_request.site)) + self.assertEqual(result['name'], 'standard') + + def test_get_shipping_cost_kwargs_with_destination_and_country_code(self): + destination = AddressFactory() + request = RequestFactory().get('/', { 'destination': destination.pk, 'country_code': '11' }) + api_request = upgrade_to_api_request(request) + result = get_shipping_cost_kwargs(api_request) + self.assertNotEqual(str(destination.country.pk), '11') + self.assertEqual(result['country_code'], '11') + self.assertEqual(result['destination'], destination) + self.assertEqual(result['basket_id'], basket_id(api_request)) + self.assertEqual(result['settings'], Configuration.for_site(api_request.site)) + self.assertEqual(result['name'], 'standard') + def test_create_address(self): """ Test creating an address object via the api @@ -39,6 +158,195 @@ class ShippingTests(LongclawTestCase): self.assertEqual(ls.default_shipping_rate, result["rate"]) +class ShippingBasketTests(LongclawTestCase): + def setUp(self): + """Create a basket with things in it + """ + request = RequestFactory().get('/') + request.session = {} + self.bid = bid = basket_id(request) + self.item = BasketItemFactory(basket_id=bid) + BasketItemFactory(basket_id=bid) + + self.address = address = AddressFactory() + + self.rate1 = ShippingRate.objects.create( + name='98d17c43-7e20-42bd-b603-a4c83c829c5a', + rate=99, + carrier='8717ca67-4691-4dff-96ec-c43cccd15241', + description='313037e1-644a-4570-808a-f9ba82ecfb34', + basket_id=bid, + ) + + self.rate2 = ShippingRate.objects.create( + name='8e721550-594c-482b-b512-54dc1744dff8', + rate=97, + carrier='4f4cca35-1a7a-47ec-ab38-a9918e0c04af', + description='eacb446d-eb17-4ea7-82c1-ac2f62a53a7d', + basket_id=bid, + destination=address, + ) + + self.rate3 = ShippingRate.objects.create( + name='72991859-dc0b-463e-821a-bf8b04aaed2c', + rate=95, + carrier='0aa3c318-b045-4a96-a456-69b4cc71d46a', + description='78b03c47-b20f-4f91-8161-47340367fb34', + destination=address, + ) + + def test_basket_rate(self): + # this tests that we get a basket rate that is just tied to the basket and nothing else + # (i.e. this basket qualifies for free shipping or something like that) + result = get_shipping_cost(Configuration(), name='98d17c43-7e20-42bd-b603-a4c83c829c5a', basket_id=self.bid) + self.assertEqual(result["rate"], 99) + self.assertEqual(result["description"], '313037e1-644a-4570-808a-f9ba82ecfb34') + + def test_basket_address_rate(self): + # this tests that we get a rate tied to a particular basket and a particular address + result = get_shipping_cost( + Configuration(), + name='8e721550-594c-482b-b512-54dc1744dff8', + basket_id=self.bid, + destination=self.address, + ) + self.assertEqual(result["rate"], 97) + self.assertEqual(result["description"], 'eacb446d-eb17-4ea7-82c1-ac2f62a53a7d') + + def test_address_rate(self): + # this tests that we get a rate tied to a particular address + result = get_shipping_cost( + Configuration(), + name='72991859-dc0b-463e-821a-bf8b04aaed2c', + destination=self.address, + ) + self.assertEqual(result["rate"], 95) + self.assertEqual(result["description"], '78b03c47-b20f-4f91-8161-47340367fb34') + + def test_clear_basket_rates_is_connected(self): + result = basket_modified.disconnect(clear_basket_rates) + self.assertTrue(result) + basket_modified.connect(clear_basket_rates) + + def test_clear_basket_rates(self): + self.assertTrue(ShippingRate.objects.filter(pk__in=[self.rate1.pk, self.rate2.pk, self.rate3.pk]).exists()) + clear_basket_rates(sender=ShippingRate, basket_id=self.bid) + self.assertFalse(ShippingRate.objects.filter(pk__in=[self.rate1.pk, self.rate2.pk]).exists()) + self.assertTrue(ShippingRate.objects.filter(pk__in=[self.rate3.pk]).exists()) + + +class AddressModifiedSignalTest(LongclawTestCase): + """Round trip API tests + """ + def setUp(self): + self.country = CountryFactory() + self.address = AddressFactory() + self.address_data = { + 'name': 'JANE DOE', + 'line_1': '1600 Pennsylvania Ave NW', + 'city': 'DC', + 'postcode': '20500', + 'country': self.country.pk, + } + + request = RequestFactory().get('/') + request.session = {} + self.bid = bid = basket_id(request) + self.item = BasketItemFactory(basket_id=bid) + BasketItemFactory(basket_id=bid) + + self.ratedAddress = address = AddressFactory() + + self.rate1 = ShippingRate.objects.create( + name='98d17c43-7e20-42bd-b603-a4c83c829c5a', + rate=99, + carrier='8717ca67-4691-4dff-96ec-c43cccd15241', + description='313037e1-644a-4570-808a-f9ba82ecfb34', + basket_id=bid, + ) + + self.rate2 = ShippingRate.objects.create( + name='8e721550-594c-482b-b512-54dc1744dff8', + rate=97, + carrier='4f4cca35-1a7a-47ec-ab38-a9918e0c04af', + description='eacb446d-eb17-4ea7-82c1-ac2f62a53a7d', + basket_id=bid, + destination=address, + ) + + self.rate3 = ShippingRate.objects.create( + name='72991859-dc0b-463e-821a-bf8b04aaed2c', + rate=95, + carrier='0aa3c318-b045-4a96-a456-69b4cc71d46a', + description='78b03c47-b20f-4f91-8161-47340367fb34', + destination=address, + ) + + def test_clear_address_rates_is_connected(self): + result = address_modified.disconnect(clear_address_rates) + self.assertTrue(result) + address_modified.connect(clear_address_rates) + + def test_clear_address_rates(self): + self.assertTrue(ShippingRate.objects.filter(pk__in=[self.rate1.pk, self.rate2.pk, self.rate3.pk]).exists()) + clear_address_rates(sender=ShippingRate, instance=self.ratedAddress) + self.assertTrue(ShippingRate.objects.filter(pk__in=[self.rate1.pk]).exists()) + self.assertFalse(ShippingRate.objects.filter(pk__in=[self.rate2.pk, self.rate3.pk]).exists()) + + def test_create_address_sends_signal(self): + with catch_signal(address_modified) as handler: + self.post_test(self.address_data, 'longclaw_address_list') + + handler.assert_called_once_with( + instance=mock.ANY, + sender=Address, + signal=address_modified, + ) + + def test_put_address_sends_signal(self): + serializer = AddressSerializer(self.address) + data = {} + data.update(serializer.data) + data.update(self.address_data) + + self.assertNotEqual(self.address.postcode, '20500') + + with catch_signal(address_modified) as handler: + response = self.put_test(data, 'longclaw_address_detail', urlkwargs={'pk': self.address.pk}) + + self.assertEqual('20500', response.data['postcode']) + + handler.assert_called_once_with( + instance=self.address, + sender=Address, + signal=address_modified, + ) + + def test_patch_address_sends_signal(self): + self.assertNotEqual(self.address.postcode, '20500') + + with catch_signal(address_modified) as handler: + response = self.patch_test(self.address_data, 'longclaw_address_detail', urlkwargs={'pk': self.address.pk}) + + self.assertEqual('20500', response.data['postcode']) + + handler.assert_called_once_with( + instance=self.address, + sender=Address, + signal=address_modified, + ) + + def test_delete_address_sends_signal(self): + with catch_signal(address_modified) as handler: + self.del_test('longclaw_address_detail', urlkwargs={'pk': self.address.pk}) + + handler.assert_called_once_with( + instance=mock.ANY, + sender=Address, + signal=address_modified, + ) + + class AddressFormTest(TestCase): def setUp(self): @@ -48,3 +356,337 @@ class AddressFormTest(TestCase): form = AddressForm(data=model_to_dict(self.address)) self.assertTrue(form.is_valid(), form.errors.as_json()) + +@mock.patch('longclaw.shipping.api.basket_id', return_value='foo') +class ShippingCostEndpointTest(LongclawTestCase): + def setUp(self): + self.country = CountryFactory() + self.address = AddressFactory() + + request = RequestFactory().get('/') + request.session = {} + + self.basket_id = 'foo' + BasketItemFactory(basket_id=self.basket_id) + BasketItemFactory(basket_id=self.basket_id) + + self.rate1 = ShippingRate.objects.create( + name='rate1', + rate=99, + carrier='rate1c', + description='rate1d', + basket_id=self.basket_id, + ) + + self.rate2 = ShippingRate.objects.create( + name='rate2', + rate=97, + carrier='rate2c', + description='rate2d', + basket_id=self.basket_id, + destination=self.address, + ) + + self.rate3 = ShippingRate.objects.create( + name='rate3', + rate=95, + carrier='rate3c', + description='rate3d', + destination=self.address, + ) + + self.rate4 = ShippingRate.objects.create( + name='rate4', + rate=93, + carrier='rate4c', + description='rate4d', + ) + self.rate4.countries.add(self.country) + + def test_get_rate1_cost(self, basket_id_func): + params = dict( + country_code=self.country.pk, + shipping_rate_name='rate1', + ) + response = self.get_test('longclaw_shipping_cost', params=params) + self.assertEqual(response.data, {'description': 'rate1d', 'rate': Decimal('99.00'), 'carrier': 'rate1c'}) + + def test_get_rate2_cost(self, basket_id_func): + params = dict( + destination=self.address.pk, + shipping_rate_name='rate2', + ) + response = self.get_test('longclaw_shipping_cost', params=params) + self.assertEqual(response.data, {'description': 'rate2d', 'rate': Decimal('97.00'), 'carrier': 'rate2c'}) + + def test_get_rate3_cost(self, basket_id_func): + params = dict( + destination=self.address.pk, + shipping_rate_name='rate3', + ) + response = self.get_test('longclaw_shipping_cost', params=params) + self.assertEqual(response.data, {'description': 'rate3d', 'rate': Decimal('95.00'), 'carrier': 'rate3c'}) + + def test_get_rate4_cost(self, basket_id_func): + # + # destination + # + params = dict( + country_code=self.country.pk, + shipping_rate_name='rate4', + ) + response = self.get_test('longclaw_shipping_cost', params=params) + self.assertEqual(response.data, {'description': 'rate4d', 'rate': Decimal('93.00'), 'carrier': 'rate4c'}) + + +class ShippingRateProcessorTest(LongclawTestCase): + def setUp(self): + pass + + def test_process_rates_not_implemented(self): + with self.assertRaises(NotImplementedError): + ShippingRateProcessor().process_rates() + + def test_get_rates_cache(self): + rates = [ + ShippingRate(pk=1), + ShippingRate(pk=2), + ShippingRate(pk=3), + ] + + rates_alt = [ + ShippingRate(pk=4), + ShippingRate(pk=5), + ShippingRate(pk=6), + ] + + self.assertNotEqual(rates, rates_alt) + + processor = ShippingRateProcessor() + processor.process_rates = lambda **kwargs: rates + processor.get_rates_cache_key = lambda **kwargs: force_text('foo') + + self.assertEqual(processor.get_rates(), rates) + + processor.process_rates = lambda **kwargs: rates_alt + + self.assertEqual(processor.get_rates(), rates) + + processor.get_rates_cache_key = lambda **kwargs: force_text('bar') + + self.assertEqual(processor.get_rates(), rates_alt) + + +class ShippingRateProcessorAPITest(LongclawTestCase): + def setUp(self): + self.country = CountryFactory() + self.country.iso = '11' + self.country.save() + + self.address = AddressFactory() + self.address.country = self.country + self.address.save() + + self.processor = ShippingRateProcessor() + self.processor.save() + self.processor.countries.add(self.country) + + def test_shipping_option_endpoint_without_destination(self): + params = { + 'country_code': self.country.pk, + } + response = self.get_test('longclaw_applicable_shipping_rate_list', params=params, success_expected=False) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.data['message'], "Destination address is required for rates to 11.") + + def test_shipping_option_endpoint_gets_processor_rates(self): + params = { + 'destination': self.address.pk, + } + with mock.patch('longclaw.shipping.api.ShippingRateProcessor.get_rates') as mocked_get_rates: + mocked_get_rates.return_value = [] + + response = self.get_test('longclaw_applicable_shipping_rate_list', params=params) + self.assertTrue(mocked_get_rates.called) + self.assertEqual(mocked_get_rates.call_count, 1) + + processor = ShippingRateProcessor() + processor.save() + processor.countries.add(self.country) + + response = self.get_test('longclaw_applicable_shipping_rate_list', params=params) + self.assertEqual(mocked_get_rates.call_count, 3) + + +class ShippingOptionEndpointTest(LongclawTestCase): + def setUp(self): + self.country = CountryFactory() + self.country2 = CountryFactory() + self.address = AddressFactory() + self.address2 = AddressFactory() + self.address2.country = self.country2 + self.address2.save() + + self.assertNotEqual(self.country.pk, self.country2.pk, 'Try again. Random got you!') + + + request = RequestFactory().get('/') + request.session = {} + + self.basket_id = 'bar' + BasketItemFactory(basket_id=self.basket_id) + BasketItemFactory(basket_id=self.basket_id) + + self.rate1 = ShippingRate.objects.create( + name='rate1', + rate=99, + carrier='rate1c', + description='rate1d', + basket_id=self.basket_id, + ) + + self.rate2 = ShippingRate.objects.create( + name='rate2', + rate=97, + carrier='rate2c', + description='rate2d', + basket_id=self.basket_id, + destination=self.address, + ) + + self.rate3 = ShippingRate.objects.create( + name='rate3', + rate=95, + carrier='rate3c', + description='rate3d', + destination=self.address, + ) + + self.rate4 = ShippingRate.objects.create( + name='rate4', + rate=93, + carrier='rate4c', + description='rate4d', + ) + self.rate4.countries.add(self.country) + + self.rate5 = ShippingRate.objects.create( + name='rate5', + rate=95, + carrier='rate5c', + description='rate5d', + destination=self.address2, + ) + + @mock.patch('longclaw.shipping.api.basket_id', return_value='bar') + def test_get_rate1rate4_option_urlkwargs(self, basket_id_func): + """ + We expect rate1 because of the basket id. + We expect rate4 because of the country. + """ + expected_pks = [self.rate1.pk, self.rate4.pk] + serializer = ShippingRateSerializer(ShippingRate.objects.filter(pk__in=expected_pks), many=True) + response = self.get_test('longclaw_shipping_options', urlkwargs={'country': self.country.pk}) + self.assertEqual(len(response.data), len(expected_pks)) + self.assertEqual(response.data, serializer.data) + + @mock.patch('longclaw.shipping.api.basket_id', return_value='bar') + def test_get_rate1rate4_option(self, basket_id_func): + """ + We expect rate1 because of the basket id. + We expect rate4 because of the country. + """ + expected_pks = [self.rate1.pk, self.rate4.pk] + serializer = ShippingRateSerializer(ShippingRate.objects.filter(pk__in=expected_pks), many=True) + params = { + 'country_code': self.country.pk, + } + response = self.get_test('longclaw_applicable_shipping_rate_list', params=params) + self.assertEqual(len(response.data), len(expected_pks)) + self.assertEqual(response.data, serializer.data) + + @mock.patch('longclaw.shipping.api.basket_id', return_value='bar') + def test_get_rate1rate2rate3_option(self, basket_id_func): + """ + We expect rate1 because of the basket id. + We expect rate2 because of the destination address and basket id. + We expect rate3 because of the destination address. + """ + expected_pks = [self.rate1.pk, self.rate2.pk, self.rate3.pk] + serializer = ShippingRateSerializer(ShippingRate.objects.filter(pk__in=expected_pks), many=True) + params = { + 'destination': self.address.pk, + } + response = self.get_test('longclaw_applicable_shipping_rate_list', params=params) + self.assertEqual(len(response.data), len(expected_pks)) + self.assertEqual(response.data, serializer.data) + + def test_get_rate5_option(self): + """ + We expect rate5 because of the destination address. + """ + expected_pks = [self.rate5.pk] + serializer = ShippingRateSerializer(ShippingRate.objects.filter(pk__in=expected_pks), many=True) + params = { + 'destination': self.address2.pk, + } + response = self.get_test('longclaw_applicable_shipping_rate_list', params=params) + self.assertEqual(len(response.data), len(expected_pks)) + self.assertEqual(response.data, serializer.data) + + def test_get_rate4_option(self): + """ + We expect rate4 because of the country. + """ + expected_pks = [self.rate4.pk] + serializer = ShippingRateSerializer(ShippingRate.objects.filter(pk__in=expected_pks), many=True) + params = { + 'country_code': self.country.pk, + } + response = self.get_test('longclaw_applicable_shipping_rate_list', params=params) + self.assertEqual(len(response.data), len(expected_pks)) + self.assertEqual(response.data, serializer.data) + + def test_get_rate4_option_urlkwargs(self): + """ + We expect rate4 because of the country. + """ + expected_pks = [self.rate4.pk] + serializer = ShippingRateSerializer(ShippingRate.objects.filter(pk__in=expected_pks), many=True) + response = self.get_test('longclaw_shipping_options', urlkwargs={'country': self.country.pk}) + self.assertEqual(len(response.data), len(expected_pks)) + self.assertEqual(response.data, serializer.data) + + @mock.patch('longclaw.shipping.api.basket_id', return_value='bar') + def test_get_rate1_option(self, basket_id_func): + """ + We expect rate1 because of the basket. + """ + expected_pks = [self.rate1.pk] + serializer = ShippingRateSerializer(ShippingRate.objects.filter(pk__in=expected_pks), many=True) + params = { + 'country_code': self.country2.pk, + } + response = self.get_test('longclaw_applicable_shipping_rate_list', params=params) + self.assertEqual(len(response.data), len(expected_pks)) + self.assertEqual(response.data, serializer.data) + + @mock.patch('longclaw.shipping.api.basket_id', return_value='bar') + def test_get_rate6_option(self, basket_id_func): + """ + We expect rate6 because of the basket id and address. + """ + expected_pks = [self.rate1.pk] + serializer = ShippingRateSerializer(ShippingRate.objects.filter(pk__in=expected_pks), many=True) + params = { + 'country_code': self.country2.pk, + } + response = self.get_test('longclaw_applicable_shipping_rate_list', params=params) + self.assertEqual(len(response.data), len(expected_pks)) + self.assertEqual(response.data, serializer.data) + + + + + + diff --git a/longclaw/shipping/urls.py b/longclaw/shipping/urls.py index cec6f73..4c14139 100644 --- a/longclaw/shipping/urls.py +++ b/longclaw/shipping/urls.py @@ -28,5 +28,8 @@ urlpatterns = [ name='longclaw_shipping_countries'), url(API_URL_PREFIX + r'shipping/countries/(?P[a-zA-Z]+)/$', api.shipping_options, - name='longclaw_shipping_options') + name='longclaw_shipping_options'), + url(API_URL_PREFIX + r'shipping/options/$', + api.shipping_options, + name='longclaw_applicable_shipping_rate_list') ] diff --git a/longclaw/shipping/utils.py b/longclaw/shipping/utils.py index f3d2bea..7c16dcd 100644 --- a/longclaw/shipping/utils.py +++ b/longclaw/shipping/utils.py @@ -1,3 +1,5 @@ +from django.db.models import Q + from longclaw.shipping import models @@ -8,10 +10,19 @@ class InvalidShippingRate(Exception): class InvalidShippingCountry(Exception): pass -def get_shipping_cost(settings, country_code=None, name=None): + +class InvalidShippingDestination(Exception): + pass + + +def get_shipping_cost(settings, country_code=None, name=None, basket_id=None, destination=None): """Return the shipping cost for a given country code and shipping option (shipping rate name) """ + if not country_code and destination: + country_code = destination.country.pk + shipping_rate = None + invalid_country = False if settings.default_shipping_enabled: shipping_rate = { "rate": settings.default_shipping_rate, @@ -19,17 +30,42 @@ def get_shipping_cost(settings, country_code=None, name=None): "carrier": settings.default_shipping_carrier } elif not country_code: - raise InvalidShippingCountry + invalid_country = True if country_code: qrs = models.ShippingRate.objects.filter(countries__in=[country_code], name=name) count = qrs.count() if count == 1: shipping_rate_qrs = qrs[0] - else: - raise InvalidShippingRate() - shipping_rate = { - "rate": shipping_rate_qrs.rate, - "description": shipping_rate_qrs.description, - "carrier": shipping_rate_qrs.carrier} + shipping_rate = { + "rate": shipping_rate_qrs.rate, + "description": shipping_rate_qrs.description, + "carrier": shipping_rate_qrs.carrier} + + if basket_id or destination: + q = Q() + + if destination and basket_id: + q.add(Q(destination=destination, basket_id=basket_id), Q.OR) + + if destination: + q.add(Q(destination=destination, basket_id=''), Q.OR) + + if basket_id: + q.add(Q(destination=None, basket_id=basket_id), Q.OR) + + qrs = models.ShippingRate.objects.filter(name=name).filter(q) + count = qrs.count() + if count == 1: + shipping_rate_qrs = qrs[0] + shipping_rate = { + "rate": shipping_rate_qrs.rate, + "description": shipping_rate_qrs.description, + "carrier": shipping_rate_qrs.carrier} + + if not shipping_rate: + if invalid_country: + raise InvalidShippingCountry + raise InvalidShippingRate() + return shipping_rate diff --git a/longclaw/tests/settings.py b/longclaw/tests/settings.py index 7c33e4f..35ec3d0 100644 --- a/longclaw/tests/settings.py +++ b/longclaw/tests/settings.py @@ -56,6 +56,7 @@ INSTALLED_APPS = [ 'longclaw.stats', 'longclaw.contrib.productrequests', 'longclaw.tests.testproducts', + 'longclaw.tests.trivialrates', ] SITE_ID = 1 diff --git a/longclaw/tests/trivialrates/__init__.py b/longclaw/tests/trivialrates/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/longclaw/tests/trivialrates/migrations/0001_initial.py b/longclaw/tests/trivialrates/migrations/0001_initial.py new file mode 100644 index 0000000..5aea0d9 --- /dev/null +++ b/longclaw/tests/trivialrates/migrations/0001_initial.py @@ -0,0 +1,27 @@ +# Generated by Django 2.1.7 on 2019-03-23 17:15 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ('shipping', '0003_auto_20190322_1429'), + ] + + operations = [ + migrations.CreateModel( + name='TrivialShippingRateProcessor', + fields=[ + ('shippingrateprocessor_ptr', models.OneToOneField(auto_created=True, on_delete=django.db.models.deletion.CASCADE, parent_link=True, primary_key=True, serialize=False, to='shipping.ShippingRateProcessor')), + ], + options={ + 'abstract': False, + 'base_manager_name': 'objects', + }, + bases=('shipping.shippingrateprocessor',), + ), + ] diff --git a/longclaw/tests/trivialrates/migrations/__init__.py b/longclaw/tests/trivialrates/migrations/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/longclaw/tests/trivialrates/models.py b/longclaw/tests/trivialrates/models.py new file mode 100644 index 0000000..c93a725 --- /dev/null +++ b/longclaw/tests/trivialrates/models.py @@ -0,0 +1,48 @@ +import hashlib + +from django.utils.encoding import force_bytes, force_text +from longclaw.shipping.models import ShippingRateProcessor, ShippingRate +from longclaw.basket.models import BasketItem + + +class TrivialShippingRateProcessor(ShippingRateProcessor): + def process_rates(self, **kwargs): + destination = kwargs['destination'] + basket_id = kwargs['basket_id'] + + item_count = BasketItem.objects.filter(basket_id=basket_id).count() + + rates = [] + + quotes = [] + + if 0 < item_count: + quotes.append((item_count * 2, 'turtle')) + + if 1 < item_count: + quotes.append((item_count * 4, 'rabbit')) + + if 2 < item_count: + quotes.append((item_count * 16, 'cheetah')) + + for amount, speed in quotes: + name = self.get_processed_rate_name(destination, basket_id, speed) + lookups = dict(name=name) + values = dict( + rate=amount, + carrier='TrivialShippingRateProcessor', + description='Delivered with {} speed'.format(speed), + basket_id=basket_id, + destination=destination, + processor=self, + ) + + rate = ShippingRate.objects.update_or_create(defaults=values, **lookups) + rates.append(rate) + + return rates + + def get_processed_rate_name(self, destination, basket_id, speed): + name_long = 'TrivialShippingRateProcessor-{}-{}-{}'.format(destination.pk, basket_id, speed) + name = hashlib.md5(force_bytes(name_long)).hexdigest() + return force_text(name) diff --git a/longclaw/tests/trivialrates/tests.py b/longclaw/tests/trivialrates/tests.py new file mode 100644 index 0000000..b2a4150 --- /dev/null +++ b/longclaw/tests/trivialrates/tests.py @@ -0,0 +1,117 @@ +import mock + +from longclaw.tests.utils import LongclawTestCase, AddressFactory, CountryFactory, BasketItemFactory +from longclaw.shipping.models import Address, ShippingRate, ShippingRateProcessor + +from .models import TrivialShippingRateProcessor + + +@mock.patch('longclaw.shipping.api.basket_id', return_value='foo') +class TrivialShippingRateProcessorAPITest(LongclawTestCase): + def setUp(self): + self.country = CountryFactory() + self.country.iso = '11' + self.country.save() + + self.address = AddressFactory() + self.address.country = self.country + self.address.save() + + self.processor = TrivialShippingRateProcessor() + self.processor.save() + self.processor.countries.add(self.country) + + def add_item_to_basket(self): + BasketItemFactory(basket_id='foo') + + def assert_contains_turtle(self, response): + self.assertContains(response, 'turtle') + + def assert_contains_rabbit(self, response): + self.assertContains(response, 'rabbit') + + def assert_contains_cheetah(self, response): + self.assertContains(response, 'cheetah') + + def assert_not_contains_turtle(self, response): + self.assertNotContains(response, 'turtle') + + def assert_not_contains_rabbit(self, response): + self.assertNotContains(response, 'rabbit') + + def assert_not_contains_cheetah(self, response): + self.assertNotContains(response, 'cheetah') + + def test_zero_rates(self, m1): + params = { + 'destination': self.address.pk, + } + + response = self.get_test('longclaw_applicable_shipping_rate_list', params=params) + + self.assertEqual(len(response.data), 0) + self.assert_not_contains_turtle(response) + self.assert_not_contains_rabbit(response) + self.assert_not_contains_cheetah(response) + + def test_one_rate(self, m1): + self.add_item_to_basket() + + params = { + 'destination': self.address.pk, + } + + response = self.get_test('longclaw_applicable_shipping_rate_list', params=params) + + self.assertEqual(len(response.data), 1, response.content) + self.assert_contains_turtle(response) + self.assert_not_contains_rabbit(response) + self.assert_not_contains_cheetah(response) + + return response + + def test_one_rate_cost(self, m1): + rate_list_response = self.test_one_rate() + rate = rate_list_response.data[0] + self.assertIn('name', rate) + name = rate['name'] + + params = dict( + destination=self.address.pk, + shipping_rate_name=name, + ) + response = self.get_test('longclaw_shipping_cost', params=params) + self.assert_contains_turtle(response) + self.assertIn('rate', response.data) + self.assertEqual(response.data['rate'], 2) + + def test_two_rates(self, m1): + self.add_item_to_basket() + self.add_item_to_basket() + + params = { + 'destination': self.address.pk, + } + + response = self.get_test('longclaw_applicable_shipping_rate_list', params=params) + + self.assertEqual(len(response.data), 2, response.content) + self.assert_contains_turtle(response) + self.assert_contains_rabbit(response) + self.assert_not_contains_cheetah(response) + + def test_three_rates(self, m1): + self.add_item_to_basket() + self.add_item_to_basket() + self.add_item_to_basket() + + params = { + 'destination': self.address.pk, + } + + response = self.get_test('longclaw_applicable_shipping_rate_list', params=params) + + self.assertEqual(len(response.data), 3, response.content) + self.assert_contains_turtle(response) + self.assert_contains_rabbit(response) + self.assert_contains_cheetah(response) diff --git a/longclaw/tests/utils.py b/longclaw/tests/utils.py index 433b855..9f11d71 100644 --- a/longclaw/tests/utils.py +++ b/longclaw/tests/utils.py @@ -1,3 +1,6 @@ +from unittest import mock +from contextlib import contextmanager + import factory from django.urls import reverse_lazy @@ -12,6 +15,19 @@ from longclaw.orders.models import Order from longclaw.shipping.models import Address, Country, ShippingRate from longclaw.utils import ProductVariant, maybe_get_product_model + +@contextmanager +def catch_signal(signal): + """ + Catch django signal and return the mocked call. + https://medium.freecodecamp.org/how-to-testing-django-signals-like-a-pro-c7ed74279311 + """ + handler = mock.Mock() + signal.connect(handler) + yield handler + signal.disconnect(handler) + + class OrderFactory(factory.django.DjangoModelFactory): class Meta: model = Order @@ -89,18 +105,20 @@ class BasketItemFactory(factory.django.DjangoModelFactory): class LongclawTestCase(APITestCase): - def get_test(self, urlname, urlkwargs=None, **kwargs): + def get_test(self, urlname, urlkwargs=None, params=None, success_expected=True, **kwargs): """ Submit a GET request and assert the response status code is 200 Arguments: urlname (str): The url name to pass to the 'reverse_lazy' function urlkwargs (dict): The `kwargs` parameter to pass to the `reverse_lazy` function """ - response = self.client.get(reverse_lazy(urlname, kwargs=urlkwargs), **kwargs) - self.assertEqual(response.status_code, status.HTTP_200_OK) + params = params or {} + response = self.client.get(reverse_lazy(urlname, kwargs=urlkwargs), params, **kwargs) + if success_expected: + self.assertTrue(status.is_success(response.status_code), response.content) return response - def post_test(self, data, urlname, urlkwargs=None, **kwargs): + def post_test(self, data, urlname, urlkwargs=None, success_expected=True, **kwargs): """ Submit a POST request and assert the response status code is 201 Arguments: @@ -109,23 +127,26 @@ class LongclawTestCase(APITestCase): urlkwargs (dict): The `kwargs` parameter to pass to the `reverse_lazy` function """ response = self.client.post(reverse_lazy(urlname, kwargs=urlkwargs), data, **kwargs) - self.assertIn(response.status_code, - (status.HTTP_201_CREATED, status.HTTP_200_OK, status.HTTP_204_NO_CONTENT)) + if success_expected: + self.assertTrue(status.is_success(response.status_code), response.content) return response - def patch_test(self, data, urlname, urlkwargs=None, **kwargs): + def patch_test(self, data, urlname, urlkwargs=None, success_expected=True, **kwargs): """ Submit a PATCH request and assert the response status code is 200 """ response = self.client.patch(reverse_lazy(urlname, kwargs=urlkwargs), data, **kwargs) - self.assertEqual(response.status_code, status.HTTP_200_OK) + if success_expected: + self.assertTrue(status.is_success(response.status_code), response.content) return response - def put_test(self, data, urlname, urlkwargs=None, **kwargs): + def put_test(self, data, urlname, urlkwargs=None, success_expected=True, **kwargs): response = self.client.put(reverse_lazy(urlname, kwargs=urlkwargs), data, **kwargs) - self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED) + if success_expected: + self.assertTrue(status.is_success(response.status_code), response.content) return response - def del_test(self, urlname, urlkwargs=None, **kwargs): + def del_test(self, urlname, urlkwargs=None, success_expected=True, **kwargs): response = self.client.delete(reverse_lazy(urlname, kwargs=urlkwargs), **kwargs) - self.assertEqual(response.status_code, status.HTTP_200_OK) + if success_expected: + self.assertTrue(status.is_success(response.status_code), response.content) return response diff --git a/setup.py b/setup.py index 577da17..bd6669d 100755 --- a/setup.py +++ b/setup.py @@ -92,7 +92,8 @@ setup( 'django-countries==5.5', 'django-extensions==2.2.1', 'djangorestframework==3.10.3', - 'django-ipware==2.1.0' + 'django-ipware==2.1.0', + 'django-polymorphic==2.0.3', ], license="MIT", zip_safe=False, diff --git a/vagrant/.gitignore b/vagrant/.gitignore new file mode 100644 index 0000000..a977916 --- /dev/null +++ b/vagrant/.gitignore @@ -0,0 +1 @@ +.vagrant/