From ef9d5b852d49d0963c9eb9f3863ff066fe494092 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 20 Jul 2015 11:46:46 +0100 Subject: [PATCH 01/22] Use ViewSet as base class for API endpoints. --- setup.py | 1 + wagtail/contrib/wagtailapi/endpoints.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 73849d4880..bd8fb70fc8 100644 --- a/setup.py +++ b/setup.py @@ -28,6 +28,7 @@ install_requires = [ "django-modelcluster>=0.6", "django-taggit>=0.13.0", "django-treebeard==3.0", + "djangorestframework==3.1.3", "Pillow>=2.6.1", "beautifulsoup4>=4.3.2", "html5lib==0.999", diff --git a/wagtail/contrib/wagtailapi/endpoints.py b/wagtail/contrib/wagtailapi/endpoints.py index 318c2c3c47..c50d7eed69 100644 --- a/wagtail/contrib/wagtailapi/endpoints.py +++ b/wagtail/contrib/wagtailapi/endpoints.py @@ -10,6 +10,7 @@ from django.utils.encoding import force_text from django.shortcuts import get_object_or_404 from django.conf.urls import url from django.conf import settings +from rest_framework.viewsets import ViewSet from wagtail.wagtailcore.models import Page from wagtail.wagtailimages.models import get_image_model @@ -94,7 +95,7 @@ def get_api_data(obj, fields): continue -class BaseAPIEndpoint(object): +class BaseAPIEndpoint(ViewSet): known_query_parameters = frozenset([ 'limit', 'offset', From b6a4318379ac90e5b62efa26361c9933b28cea8d Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 20 Jul 2015 13:40:19 +0100 Subject: [PATCH 02/22] Refactor to use Responses and Renderers. --- wagtail/contrib/wagtailapi/api.py | 86 ++----------------------- wagtail/contrib/wagtailapi/endpoints.py | 77 +++++++++++++--------- wagtail/contrib/wagtailapi/renderers.py | 49 ++++++++++++++ wagtail/contrib/wagtailapi/utils.py | 22 +++++++ 4 files changed, 122 insertions(+), 112 deletions(-) create mode 100644 wagtail/contrib/wagtailapi/renderers.py diff --git a/wagtail/contrib/wagtailapi/api.py b/wagtail/contrib/wagtailapi/api.py index 586e3d5249..33877429f6 100644 --- a/wagtail/contrib/wagtailapi/api.py +++ b/wagtail/contrib/wagtailapi/api.py @@ -1,95 +1,21 @@ -import json -from functools import wraps - from django.conf.urls import url, include -from django.http import HttpResponse, HttpResponseBadRequest, HttpResponseNotFound, Http404 -from django.core.serializers.json import DjangoJSONEncoder -from django.core.urlresolvers import reverse -from taggit.managers import _TaggableManager -from taggit.models import Tag - -from wagtail.utils.urlpatterns import decorate_urlpatterns -from wagtail.wagtailcore.blocks import StreamValue - -from .endpoints import URLPath, ObjectDetailURL, PagesAPIEndpoint, ImagesAPIEndpoint, DocumentsAPIEndpoint -from .utils import BadRequestError, get_base_url - - -def get_full_url(request, path): - base_url = get_base_url(request) or '' - return base_url + path +from .endpoints import PagesAPIEndpoint, ImagesAPIEndpoint, DocumentsAPIEndpoint class API(object): def __init__(self, endpoints): self.endpoints = endpoints - def find_model_detail_view(self, model): - for endpoint_name, endpoint in self.endpoints.items(): - if endpoint.has_model(model): - return 'wagtailapi_v1:%s:detail' % endpoint_name - - def make_response(self, request, data, response_cls=HttpResponse): - api = self - - class WagtailAPIJSONEncoder(DjangoJSONEncoder): - def default(self, o): - if isinstance(o, _TaggableManager): - return list(o.all()) - elif isinstance(o, Tag): - return o.name - elif isinstance(o, URLPath): - return get_full_url(request, o.path) - elif isinstance(o, ObjectDetailURL): - view = api.find_model_detail_view(o.model) - - if view: - return get_full_url(request, reverse(view, args=(o.pk, ))) - else: - return None - elif isinstance(o, StreamValue): - return o.stream_block.get_prep_value(o) - else: - return super(WagtailAPIJSONEncoder, self).default(o) - - return response_cls( - json.dumps(data, indent=4, cls=WagtailAPIJSONEncoder), - content_type='application/json' - ) - - def api_view(self, view): - """ - This is a decorator that is applied to all API views. - - It is responsible for serialising the responses from the endpoints - and handling errors. - """ - @wraps(view) - def wrapper(request, *args, **kwargs): - # Catch exceptions and format them as JSON documents - try: - return self.make_response(request, view(request, *args, **kwargs)) - except Http404 as e: - return self.make_response(request, { - 'message': str(e) - }, response_cls=HttpResponseNotFound) - except BadRequestError as e: - return self.make_response(request, { - 'message': str(e) - }, response_cls=HttpResponseBadRequest) - - return wrapper - def get_urlpatterns(self): - return decorate_urlpatterns([ + return [ url(r'^%s/' % name, include(endpoint.get_urlpatterns(), namespace=name)) for name, endpoint in self.endpoints.items() - ], self.api_view) + ] v1 = API({ - 'pages': PagesAPIEndpoint(), - 'images': ImagesAPIEndpoint(), - 'documents': DocumentsAPIEndpoint(), + 'pages': PagesAPIEndpoint, + 'images': ImagesAPIEndpoint, + 'documents': DocumentsAPIEndpoint, }) diff --git a/wagtail/contrib/wagtailapi/endpoints.py b/wagtail/contrib/wagtailapi/endpoints.py index c50d7eed69..4d6e15b725 100644 --- a/wagtail/contrib/wagtailapi/endpoints.py +++ b/wagtail/contrib/wagtailapi/endpoints.py @@ -10,6 +10,10 @@ from django.utils.encoding import force_text from django.shortcuts import get_object_or_404 from django.conf.urls import url from django.conf import settings +from django.http import Http404 + +from rest_framework import status +from rest_framework.response import Response from rest_framework.viewsets import ViewSet from wagtail.wagtailcore.models import Page @@ -19,29 +23,8 @@ from wagtail.wagtailcore.utils import resolve_model_string from wagtail.wagtailsearch.backends import get_search_backend from wagtail.utils.compat import get_related_model -from .utils import BadRequestError - - -class URLPath(object): - """ - This class represents a URL path that should be converted to a full URL. - - It is used when the domain that should be used is not known at the time - the URL was generated. It will get resolved to a full URL during - serialisation in api.py. - - One example use case is the documents endpoint adding download URLs into - the JSON. The endpoint does not know the domain name to use at the time so - returns one of these instead. - """ - def __init__(self, path): - self.path = path - - -class ObjectDetailURL(object): - def __init__(self, model, pk): - self.model = model - self.pk = pk +from .renderers import WagtailJSONRenderer +from .utils import BadRequestError, URLPath, ObjectDetailURL def get_api_data(obj, fields): @@ -96,6 +79,8 @@ def get_api_data(obj, fields): class BaseAPIEndpoint(ViewSet): + renderer_classes = [WagtailJSONRenderer] + known_query_parameters = frozenset([ 'limit', 'offset', @@ -104,6 +89,15 @@ class BaseAPIEndpoint(ViewSet): 'search', ]) + def handle_exception(self, exc): + if isinstance(exc, Http404): + data = {'message': str(exc)} + return Response(data, status=status.HTTP_404_NOT_FOUND) + elif isinstance(exc, BadRequestError): + data = {'message': str(exc)} + return Response(data, status=status.HTTP_400_BAD_REQUEST) + return super(BaseAPIEndpoint, self).handle_exception(exc) + def listing_view(self, request): return NotImplemented @@ -300,15 +294,28 @@ class BaseAPIEndpoint(ViewSet): return queryset[start:stop] - def get_urlpatterns(self): + @classmethod + def get_urlpatterns(cls): """ This returns a list of URL patterns for the endpoint """ return [ - url(r'^$', self.listing_view, name='listing'), - url(r'^(\d+)/$', self.detail_view, name='detail'), + url(r'^$', cls.as_view({'get': 'listing_view'}), name='listing'), + url(r'^(\d+)/$', cls.as_view({'get': 'detail_view'}), name='detail'), ] + def find_model_detail_view(self, model): + # TODO: Needs refactoring. This is currently duplicated, and also + # does a bit of a dance around instantiating these classes. + endpoints = { + 'pages': PagesAPIEndpoint(), + 'images': ImagesAPIEndpoint(), + 'documents': DocumentsAPIEndpoint(), + } + for endpoint_name, endpoint in endpoints.items(): + if endpoint.has_model(model): + return 'wagtailapi_v1:%s:detail' % endpoint_name + def has_model(self, model): return False @@ -443,7 +450,7 @@ class PagesAPIEndpoint(BaseAPIEndpoint): else: fields = {'title'} - return OrderedDict([ + data = OrderedDict([ ('meta', OrderedDict([ ('total_count', total_count), ])), @@ -452,10 +459,12 @@ class PagesAPIEndpoint(BaseAPIEndpoint): for page in queryset ]), ]) + return Response(data) def detail_view(self, request, pk): page = get_object_or_404(self.get_queryset(request), pk=pk).specific - return self.serialize_object(request, page, all_fields=True, show_details=True) + data = self.serialize_object(request, page, all_fields=True, show_details=True) + return Response(data) def has_model(self, model): return issubclass(model, Page) @@ -497,7 +506,7 @@ class ImagesAPIEndpoint(BaseAPIEndpoint): else: fields = {'title'} - return OrderedDict([ + data = OrderedDict([ ('meta', OrderedDict([ ('total_count', total_count), ])), @@ -506,10 +515,12 @@ class ImagesAPIEndpoint(BaseAPIEndpoint): for image in queryset ]), ]) + return Response(data) def detail_view(self, request, pk): image = get_object_or_404(self.get_queryset(request), pk=pk) - return self.serialize_object(request, image, all_fields=True) + data = self.serialize_object(request, image, all_fields=True) + return Response(data) def has_model(self, model): return model == self.model @@ -555,7 +566,7 @@ class DocumentsAPIEndpoint(BaseAPIEndpoint): else: fields = {'title'} - return OrderedDict([ + data = OrderedDict([ ('meta', OrderedDict([ ('total_count', total_count), ])), @@ -564,10 +575,12 @@ class DocumentsAPIEndpoint(BaseAPIEndpoint): for document in queryset ]), ]) + return Response(data) def detail_view(self, request, pk): document = get_object_or_404(Document, pk=pk) - return self.serialize_object(request, document, all_fields=True, show_details=True) + data = self.serialize_object(request, document, all_fields=True, show_details=True) + return Response(data) def has_model(self, model): return model == Document diff --git a/wagtail/contrib/wagtailapi/renderers.py b/wagtail/contrib/wagtailapi/renderers.py new file mode 100644 index 0000000000..c58ea32c63 --- /dev/null +++ b/wagtail/contrib/wagtailapi/renderers.py @@ -0,0 +1,49 @@ +import json + +from django.core.serializers.json import DjangoJSONEncoder +from django.core.urlresolvers import reverse + +from rest_framework import renderers + +from taggit.managers import _TaggableManager +from taggit.models import Tag + +from wagtail.wagtailcore.blocks import StreamValue + +from .utils import URLPath, ObjectDetailURL, get_base_url + + +def get_full_url(request, path): + base_url = get_base_url(request) or '' + return base_url + path + + +class WagtailJSONRenderer(renderers.BaseRenderer): + media_type = 'application/json' + charset = None + + def render(self, data, media_type=None, renderer_context=None): + endpoint = renderer_context['view'] + request = renderer_context['request'] + + class WagtailAPIJSONEncoder(DjangoJSONEncoder): + def default(self, o): + if isinstance(o, _TaggableManager): + return list(o.all()) + elif isinstance(o, Tag): + return o.name + elif isinstance(o, URLPath): + return get_full_url(request, o.path) + elif isinstance(o, ObjectDetailURL): + view = endpoint.find_model_detail_view(o.model) + + if view: + return get_full_url(request, reverse(view, args=(o.pk, ))) + else: + return None + elif isinstance(o, StreamValue): + return o.stream_block.get_prep_value(o) + else: + return super(WagtailAPIJSONEncoder, self).default(o) + + return json.dumps(data, indent=4, cls=WagtailAPIJSONEncoder) diff --git a/wagtail/contrib/wagtailapi/utils.py b/wagtail/contrib/wagtailapi/utils.py index 11af445d28..483e4f51a2 100644 --- a/wagtail/contrib/wagtailapi/utils.py +++ b/wagtail/contrib/wagtailapi/utils.py @@ -14,3 +14,25 @@ def get_base_url(request=None): base_url_parsed = urlparse(base_url) return base_url_parsed.scheme + '://' + base_url_parsed.netloc + + +class URLPath(object): + """ + This class represents a URL path that should be converted to a full URL. + + It is used when the domain that should be used is not known at the time + the URL was generated. It will get resolved to a full URL during + serialisation in api.py. + + One example use case is the documents endpoint adding download URLs into + the JSON. The endpoint does not know the domain name to use at the time so + returns one of these instead. + """ + def __init__(self, path): + self.path = path + + +class ObjectDetailURL(object): + def __init__(self, model, pk): + self.model = model + self.pk = pk From aee387e2c23967a78d4e2da4cc7785e78b394759 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 20 Jul 2015 13:58:03 +0100 Subject: [PATCH 03/22] Drop 'api' module. --- wagtail/contrib/wagtailapi/api.py | 21 --------------------- wagtail/contrib/wagtailapi/urls.py | 11 +++++++++-- 2 files changed, 9 insertions(+), 23 deletions(-) delete mode 100644 wagtail/contrib/wagtailapi/api.py diff --git a/wagtail/contrib/wagtailapi/api.py b/wagtail/contrib/wagtailapi/api.py deleted file mode 100644 index 33877429f6..0000000000 --- a/wagtail/contrib/wagtailapi/api.py +++ /dev/null @@ -1,21 +0,0 @@ -from django.conf.urls import url, include - -from .endpoints import PagesAPIEndpoint, ImagesAPIEndpoint, DocumentsAPIEndpoint - - -class API(object): - def __init__(self, endpoints): - self.endpoints = endpoints - - def get_urlpatterns(self): - return [ - url(r'^%s/' % name, include(endpoint.get_urlpatterns(), namespace=name)) - for name, endpoint in self.endpoints.items() - ] - - -v1 = API({ - 'pages': PagesAPIEndpoint, - 'images': ImagesAPIEndpoint, - 'documents': DocumentsAPIEndpoint, -}) diff --git a/wagtail/contrib/wagtailapi/urls.py b/wagtail/contrib/wagtailapi/urls.py index 1aa914e369..492772f28e 100644 --- a/wagtail/contrib/wagtailapi/urls.py +++ b/wagtail/contrib/wagtailapi/urls.py @@ -2,9 +2,16 @@ from __future__ import absolute_import from django.conf.urls import url, include -from . import api +from .endpoints import PagesAPIEndpoint, ImagesAPIEndpoint, DocumentsAPIEndpoint + + +v1 = [ + url(r'^pages/', include(PagesAPIEndpoint.get_urlpatterns(), namespace='pages')), + url(r'^images/', include(ImagesAPIEndpoint.get_urlpatterns(), namespace='images')), + url(r'^documents/', include(DocumentsAPIEndpoint.get_urlpatterns(), namespace='documents')) +] urlpatterns = [ - url(r'^v1/', include(api.v1.get_urlpatterns(), namespace='wagtailapi_v1')), + url(r'^v1/', include(v1, namespace='wagtailapi_v1')), ] From e1978f6606ab6ed57ad822996593d1bfa21c3dfe Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 20 Jul 2015 16:17:40 +0100 Subject: [PATCH 04/22] Refactor filters --- wagtail/contrib/wagtailapi/endpoints.py | 177 +++--------------------- wagtail/contrib/wagtailapi/filters.py | 148 ++++++++++++++++++++ 2 files changed, 168 insertions(+), 157 deletions(-) create mode 100644 wagtail/contrib/wagtailapi/filters.py diff --git a/wagtail/contrib/wagtailapi/endpoints.py b/wagtail/contrib/wagtailapi/endpoints.py index 4d6e15b725..594465eaf8 100644 --- a/wagtail/contrib/wagtailapi/endpoints.py +++ b/wagtail/contrib/wagtailapi/endpoints.py @@ -14,15 +14,18 @@ from django.http import Http404 from rest_framework import status from rest_framework.response import Response -from rest_framework.viewsets import ViewSet +from rest_framework.viewsets import GenericViewSet from wagtail.wagtailcore.models import Page from wagtail.wagtailimages.models import get_image_model from wagtail.wagtaildocs.models import Document from wagtail.wagtailcore.utils import resolve_model_string -from wagtail.wagtailsearch.backends import get_search_backend from wagtail.utils.compat import get_related_model +from .filters import ( + FieldsFilter, OrderingFilter, SearchFilter, + ChildOfFilter, DescendantOfFilter +) from .renderers import WagtailJSONRenderer from .utils import BadRequestError, URLPath, ObjectDetailURL @@ -78,8 +81,9 @@ def get_api_data(obj, fields): continue -class BaseAPIEndpoint(ViewSet): +class BaseAPIEndpoint(GenericViewSet): renderer_classes = [WagtailJSONRenderer] + filter_classes = [] known_query_parameters = frozenset([ 'limit', @@ -174,98 +178,6 @@ class BaseAPIEndpoint(ViewSet): if unknown_parameters: raise BadRequestError("query parameter is not an operation or a recognised field: %s" % ', '.join(sorted(unknown_parameters))) - def do_field_filtering(self, request, queryset): - """ - This performs field level filtering on the result set - Eg: ?title=James Joyce - """ - fields = set(self.get_api_fields(queryset.model)).union({'id'}) - - for field_name, value in request.GET.items(): - if field_name in fields: - field = getattr(queryset.model, field_name, None) - - if isinstance(field, _TaggableManager): - for tag in value.split(','): - queryset = queryset.filter(**{field_name + '__name': tag}) - - # Stick a message on the queryset to indicate that tag filtering has been performed - # This will let the do_search method know that it must raise an error as searching - # and tag filtering at the same time is not supported - queryset._filtered_by_tag = True - else: - queryset = queryset.filter(**{field_name: value}) - - return queryset - - def do_ordering(self, request, queryset): - """ - This applies ordering to the result set - Eg: ?order=title - - It also supports reverse ordering - Eg: ?order=-title - - And random ordering - Eg: ?order=random - """ - if 'order' in request.GET: - # Prevent ordering while searching - if 'search' in request.GET: - raise BadRequestError("ordering with a search query is not supported") - - order_by = request.GET['order'] - - # Random ordering - if order_by == 'random': - # Prevent ordering by random with offset - if 'offset' in request.GET: - raise BadRequestError("random ordering with offset is not supported") - - return queryset.order_by('?') - - # Check if reverse ordering is set - if order_by.startswith('-'): - reverse_order = True - order_by = order_by[1:] - else: - reverse_order = False - - # Add ordering - if order_by == 'id' or order_by in self.get_api_fields(queryset.model): - queryset = queryset.order_by(order_by) - else: - # Unknown field - raise BadRequestError("cannot order by '%s' (unknown field)" % order_by) - - # Reverse order - if reverse_order: - queryset = queryset.reverse() - - return queryset - - def do_search(self, request, queryset): - """ - This performs a full-text search on the result set - Eg: ?search=James Joyce - """ - search_enabled = getattr(settings, 'WAGTAILAPI_SEARCH_ENABLED', True) - - if 'search' in request.GET: - if not search_enabled: - raise BadRequestError("search is disabled") - - # Searching and filtering by tag at the same time is not supported - if getattr(queryset, '_filtered_by_tag', False): - raise BadRequestError("filtering by tag with a search query is not supported") - - search_query = request.GET['search'] - - sb = get_search_backend() - queryset = sb.search(search_query, queryset) - - return queryset - def do_pagination(self, request, queryset): """ This performs limit/offset based pagination on the result set @@ -326,6 +238,10 @@ class PagesAPIEndpoint(BaseAPIEndpoint): 'child_of', 'descendant_of', ]) + filter_backends = [ + FieldsFilter, ChildOfFilter, DescendantOfFilter, + OrderingFilter, SearchFilter + ] def get_queryset(self, request, model=Page): # Get live pages that are not in a private section @@ -385,42 +301,6 @@ class PagesAPIEndpoint(BaseAPIEndpoint): except LookupError: raise BadRequestError("type doesn't exist") - def do_child_of_filter(self, request, queryset): - if 'child_of' in request.GET: - try: - parent_page_id = int(request.GET['child_of']) - assert parent_page_id >= 0 - except (ValueError, AssertionError): - raise BadRequestError("child_of must be a positive integer") - - try: - parent_page = self.get_queryset(request).get(id=parent_page_id) - queryset = queryset.child_of(parent_page) - queryset._filtered_by_child_of = True - return queryset - except Page.DoesNotExist: - raise BadRequestError("parent page doesn't exist") - - return queryset - - def do_descendant_of_filter(self, request, queryset): - if 'descendant_of' in request.GET: - if getattr(queryset, '_filtered_by_child_of', False): - raise BadRequestError("filtering by descendant_of with child_of is not supported") - try: - ancestor_page_id = int(request.GET['descendant_of']) - assert ancestor_page_id >= 0 - except (ValueError, AssertionError): - raise BadRequestError("descendant_of must be a positive integer") - - try: - ancestor_page = self.get_queryset(request).get(id=ancestor_page_id) - return queryset.descendant_of(ancestor_page) - except Page.DoesNotExist: - raise BadRequestError("ancestor page doesn't exist") - - return queryset - def listing_view(self, request): # Get model and queryset model = self.get_model(request) @@ -429,16 +309,8 @@ class PagesAPIEndpoint(BaseAPIEndpoint): # Check query paramters self.check_query_parameters(request, queryset) - # Filtering - queryset = self.do_field_filtering(request, queryset) - queryset = self.do_child_of_filter(request, queryset) - queryset = self.do_descendant_of_filter(request, queryset) - - # Ordering - queryset = self.do_ordering(request, queryset) - - # Search - queryset = self.do_search(request, queryset) + # Filtering, Ancestor/Descendant, Ordering, Search. + queryset = self.filter_queryset(queryset) # Pagination total_count = queryset.count() @@ -472,6 +344,7 @@ class PagesAPIEndpoint(BaseAPIEndpoint): class ImagesAPIEndpoint(BaseAPIEndpoint): model = get_image_model() + filter_backends = [FieldsFilter, OrderingFilter, SearchFilter] def get_queryset(self, request): return self.model.objects.all().order_by('id') @@ -487,14 +360,8 @@ class ImagesAPIEndpoint(BaseAPIEndpoint): # Check query paramters self.check_query_parameters(request, queryset) - # Filtering - queryset = self.do_field_filtering(request, queryset) - - # Ordering - queryset = self.do_ordering(request, queryset) - - # Search - queryset = self.do_search(request, queryset) + # Filtering, Ordering, Search. + queryset = self.filter_queryset(queryset) # Pagination total_count = queryset.count() @@ -527,6 +394,8 @@ class ImagesAPIEndpoint(BaseAPIEndpoint): class DocumentsAPIEndpoint(BaseAPIEndpoint): + filter_backends = [FieldsFilter, OrderingFilter, SearchFilter] + def get_api_fields(self, model): api_fields = ['title', 'tags'] api_fields.extend(super(DocumentsAPIEndpoint, self).get_api_fields(model)) @@ -547,14 +416,8 @@ class DocumentsAPIEndpoint(BaseAPIEndpoint): # Check query paramters self.check_query_parameters(request, queryset) - # Filtering - queryset = self.do_field_filtering(request, queryset) - - # Ordering - queryset = self.do_ordering(request, queryset) - - # Search - queryset = self.do_search(request, queryset) + # Filtering, Ordering, Search. + queryset = self.filter_queryset(queryset) # Pagination total_count = queryset.count() diff --git a/wagtail/contrib/wagtailapi/filters.py b/wagtail/contrib/wagtailapi/filters.py new file mode 100644 index 0000000000..6337698647 --- /dev/null +++ b/wagtail/contrib/wagtailapi/filters.py @@ -0,0 +1,148 @@ +from django.conf import settings + +from rest_framework.filters import BaseFilterBackend + +from taggit.managers import _TaggableManager + +from wagtail.wagtailcore.models import Page +from wagtail.wagtailsearch.backends import get_search_backend + +from .utils import BadRequestError + + +class FieldsFilter(BaseFilterBackend): + def filter_queryset(self, request, queryset, view): + """ + This performs field level filtering on the result set + Eg: ?title=James Joyce + """ + fields = set(view.get_api_fields(queryset.model)).union({'id'}) + + for field_name, value in request.GET.items(): + if field_name in fields: + field = getattr(queryset.model, field_name, None) + + if isinstance(field, _TaggableManager): + for tag in value.split(','): + queryset = queryset.filter(**{field_name + '__name': tag}) + + # Stick a message on the queryset to indicate that tag filtering has been performed + # This will let the do_search method know that it must raise an error as searching + # and tag filtering at the same time is not supported + queryset._filtered_by_tag = True + else: + queryset = queryset.filter(**{field_name: value}) + + return queryset + + +class OrderingFilter(BaseFilterBackend): + def filter_queryset(self, request, queryset, view): + """ + This applies ordering to the result set + Eg: ?order=title + + It also supports reverse ordering + Eg: ?order=-title + + And random ordering + Eg: ?order=random + """ + if 'order' in request.GET: + # Prevent ordering while searching + if 'search' in request.GET: + raise BadRequestError("ordering with a search query is not supported") + + order_by = request.GET['order'] + + # Random ordering + if order_by == 'random': + # Prevent ordering by random with offset + if 'offset' in request.GET: + raise BadRequestError("random ordering with offset is not supported") + + return queryset.order_by('?') + + # Check if reverse ordering is set + if order_by.startswith('-'): + reverse_order = True + order_by = order_by[1:] + else: + reverse_order = False + + # Add ordering + if order_by == 'id' or order_by in view.get_api_fields(queryset.model): + queryset = queryset.order_by(order_by) + else: + # Unknown field + raise BadRequestError("cannot order by '%s' (unknown field)" % order_by) + + # Reverse order + if reverse_order: + queryset = queryset.reverse() + + return queryset + + +class SearchFilter(BaseFilterBackend): + def filter_queryset(self, request, queryset, view): + """ + This performs a full-text search on the result set + Eg: ?search=James Joyce + """ + search_enabled = getattr(settings, 'WAGTAILAPI_SEARCH_ENABLED', True) + + if 'search' in request.GET: + if not search_enabled: + raise BadRequestError("search is disabled") + + # Searching and filtering by tag at the same time is not supported + if getattr(queryset, '_filtered_by_tag', False): + raise BadRequestError("filtering by tag with a search query is not supported") + + search_query = request.GET['search'] + + sb = get_search_backend() + queryset = sb.search(search_query, queryset) + + return queryset + + +class ChildOfFilter(BaseFilterBackend): + def filter_queryset(self, request, queryset, view): + if 'child_of' in request.GET: + try: + parent_page_id = int(request.GET['child_of']) + assert parent_page_id >= 0 + except (ValueError, AssertionError): + raise BadRequestError("child_of must be a positive integer") + + try: + parent_page = view.get_queryset(request).get(id=parent_page_id) + queryset = queryset.child_of(parent_page) + queryset._filtered_by_child_of = True + return queryset + except Page.DoesNotExist: + raise BadRequestError("parent page doesn't exist") + + return queryset + + +class DescendantOfFilter(BaseFilterBackend): + def filter_queryset(self, request, queryset, view): + if 'descendant_of' in request.GET: + if getattr(queryset, '_filtered_by_child_of', False): + raise BadRequestError("filtering by descendant_of with child_of is not supported") + try: + ancestor_page_id = int(request.GET['descendant_of']) + assert ancestor_page_id >= 0 + except (ValueError, AssertionError): + raise BadRequestError("descendant_of must be a positive integer") + + try: + ancestor_page = view.get_queryset(request).get(id=ancestor_page_id) + return queryset.descendant_of(ancestor_page) + except Page.DoesNotExist: + raise BadRequestError("ancestor page doesn't exist") + + return queryset From 3122d19a660a80a315d4bf847070c9eb8a2f1309 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 20 Jul 2015 16:36:47 +0100 Subject: [PATCH 05/22] Refactor get_api_fields --- wagtail/contrib/wagtailapi/endpoints.py | 21 +++++---------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/wagtail/contrib/wagtailapi/endpoints.py b/wagtail/contrib/wagtailapi/endpoints.py index 594465eaf8..1117e3a25b 100644 --- a/wagtail/contrib/wagtailapi/endpoints.py +++ b/wagtail/contrib/wagtailapi/endpoints.py @@ -92,6 +92,7 @@ class BaseAPIEndpoint(GenericViewSet): 'order', 'search', ]) + extra_api_fields = [] def handle_exception(self, exc): if isinstance(exc, Http404): @@ -113,7 +114,7 @@ class BaseAPIEndpoint(GenericViewSet): This returns a list of field names that are allowed to be used in the API (excluding the id field). """ - api_fields = [] + api_fields = self.extra_api_fields[:] if hasattr(model, 'api_fields'): api_fields.extend(model.api_fields) @@ -238,6 +239,7 @@ class PagesAPIEndpoint(BaseAPIEndpoint): 'child_of', 'descendant_of', ]) + extra_api_fields = ['title'] filter_backends = [ FieldsFilter, ChildOfFilter, DescendantOfFilter, OrderingFilter, SearchFilter @@ -252,11 +254,6 @@ class PagesAPIEndpoint(BaseAPIEndpoint): return queryset - def get_api_fields(self, model): - api_fields = ['title'] - api_fields.extend(super(PagesAPIEndpoint, self).get_api_fields(model)) - return api_fields - def serialize_object_metadata(self, request, page, show_details=False): data = super(PagesAPIEndpoint, self).serialize_object_metadata(request, page, show_details=show_details) @@ -345,15 +342,11 @@ class PagesAPIEndpoint(BaseAPIEndpoint): class ImagesAPIEndpoint(BaseAPIEndpoint): model = get_image_model() filter_backends = [FieldsFilter, OrderingFilter, SearchFilter] + extra_api_fields = ['title', 'tags', 'width', 'height'] def get_queryset(self, request): return self.model.objects.all().order_by('id') - def get_api_fields(self, model): - api_fields = ['title', 'tags', 'width', 'height'] - api_fields.extend(super(ImagesAPIEndpoint, self).get_api_fields(model)) - return api_fields - def listing_view(self, request): queryset = self.get_queryset(request) @@ -395,11 +388,7 @@ class ImagesAPIEndpoint(BaseAPIEndpoint): class DocumentsAPIEndpoint(BaseAPIEndpoint): filter_backends = [FieldsFilter, OrderingFilter, SearchFilter] - - def get_api_fields(self, model): - api_fields = ['title', 'tags'] - api_fields.extend(super(DocumentsAPIEndpoint, self).get_api_fields(model)) - return api_fields + extra_api_fields = ['title', 'tags'] def serialize_object_metadata(self, request, document, show_details=False): data = super(DocumentsAPIEndpoint, self).serialize_object_metadata(request, document, show_details=show_details) From 76de8eab349722d53edf7a1fd42c1bf7db3e795f Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 20 Jul 2015 17:03:07 +0100 Subject: [PATCH 06/22] Refactor pagination --- wagtail/contrib/wagtailapi/endpoints.py | 87 +++++++----------------- wagtail/contrib/wagtailapi/pagination.py | 45 ++++++++++++ 2 files changed, 68 insertions(+), 64 deletions(-) create mode 100644 wagtail/contrib/wagtailapi/pagination.py diff --git a/wagtail/contrib/wagtailapi/endpoints.py b/wagtail/contrib/wagtailapi/endpoints.py index 1117e3a25b..105cf2c879 100644 --- a/wagtail/contrib/wagtailapi/endpoints.py +++ b/wagtail/contrib/wagtailapi/endpoints.py @@ -27,6 +27,7 @@ from .filters import ( ChildOfFilter, DescendantOfFilter ) from .renderers import WagtailJSONRenderer +from .pagination import WagtailPagination from .utils import BadRequestError, URLPath, ObjectDetailURL @@ -83,6 +84,7 @@ def get_api_data(obj, fields): class BaseAPIEndpoint(GenericViewSet): renderer_classes = [WagtailJSONRenderer] + pagination_class = WagtailPagination filter_classes = [] known_query_parameters = frozenset([ @@ -179,34 +181,6 @@ class BaseAPIEndpoint(GenericViewSet): if unknown_parameters: raise BadRequestError("query parameter is not an operation or a recognised field: %s" % ', '.join(sorted(unknown_parameters))) - def do_pagination(self, request, queryset): - """ - This performs limit/offset based pagination on the result set - Eg: ?limit=10&offset=20 -- Returns 10 items starting at item 20 - """ - limit_max = getattr(settings, 'WAGTAILAPI_LIMIT_MAX', 20) - - try: - offset = int(request.GET.get('offset', 0)) - assert offset >= 0 - except (ValueError, AssertionError): - raise BadRequestError("offset must be a positive integer") - - try: - limit = int(request.GET.get('limit', min(20, limit_max))) - - if limit > limit_max: - raise BadRequestError("limit cannot be higher than %d" % limit_max) - - assert limit >= 0 - except (ValueError, AssertionError): - raise BadRequestError("limit must be a positive integer") - - start = offset - stop = offset + limit - - return queryset[start:stop] - @classmethod def get_urlpatterns(cls): """ @@ -234,6 +208,7 @@ class BaseAPIEndpoint(GenericViewSet): class PagesAPIEndpoint(BaseAPIEndpoint): + name = 'pages' known_query_parameters = BaseAPIEndpoint.known_query_parameters.union([ 'type', 'child_of', @@ -310,8 +285,7 @@ class PagesAPIEndpoint(BaseAPIEndpoint): queryset = self.filter_queryset(queryset) # Pagination - total_count = queryset.count() - queryset = self.do_pagination(request, queryset) + queryset = self.paginate_queryset(queryset) # Get list of fields to show in results if 'fields' in request.GET: @@ -319,16 +293,11 @@ class PagesAPIEndpoint(BaseAPIEndpoint): else: fields = {'title'} - data = OrderedDict([ - ('meta', OrderedDict([ - ('total_count', total_count), - ])), - ('pages', [ - self.serialize_object(request, page, fields=fields) - for page in queryset - ]), - ]) - return Response(data) + data = [ + self.serialize_object(request, page, fields=fields) + for page in queryset + ] + return self.get_paginated_response(data) def detail_view(self, request, pk): page = get_object_or_404(self.get_queryset(request), pk=pk).specific @@ -340,6 +309,7 @@ class PagesAPIEndpoint(BaseAPIEndpoint): class ImagesAPIEndpoint(BaseAPIEndpoint): + name = 'images' model = get_image_model() filter_backends = [FieldsFilter, OrderingFilter, SearchFilter] extra_api_fields = ['title', 'tags', 'width', 'height'] @@ -357,8 +327,7 @@ class ImagesAPIEndpoint(BaseAPIEndpoint): queryset = self.filter_queryset(queryset) # Pagination - total_count = queryset.count() - queryset = self.do_pagination(request, queryset) + queryset = self.paginate_queryset(queryset) # Get list of fields to show in results if 'fields' in request.GET: @@ -366,16 +335,11 @@ class ImagesAPIEndpoint(BaseAPIEndpoint): else: fields = {'title'} - data = OrderedDict([ - ('meta', OrderedDict([ - ('total_count', total_count), - ])), - ('images', [ - self.serialize_object(request, image, fields=fields) - for image in queryset - ]), - ]) - return Response(data) + data = [ + self.serialize_object(request, image, fields=fields) + for image in queryset + ] + return self.get_paginated_response(data) def detail_view(self, request, pk): image = get_object_or_404(self.get_queryset(request), pk=pk) @@ -387,6 +351,7 @@ class ImagesAPIEndpoint(BaseAPIEndpoint): class DocumentsAPIEndpoint(BaseAPIEndpoint): + name = 'documents' filter_backends = [FieldsFilter, OrderingFilter, SearchFilter] extra_api_fields = ['title', 'tags'] @@ -409,8 +374,7 @@ class DocumentsAPIEndpoint(BaseAPIEndpoint): queryset = self.filter_queryset(queryset) # Pagination - total_count = queryset.count() - queryset = self.do_pagination(request, queryset) + queryset = self.paginate_queryset(queryset) # Get list of fields to show in results if 'fields' in request.GET: @@ -418,16 +382,11 @@ class DocumentsAPIEndpoint(BaseAPIEndpoint): else: fields = {'title'} - data = OrderedDict([ - ('meta', OrderedDict([ - ('total_count', total_count), - ])), - ('documents', [ - self.serialize_object(request, document, fields=fields) - for document in queryset - ]), - ]) - return Response(data) + data = [ + self.serialize_object(request, document, fields=fields) + for document in queryset + ] + return self.get_paginated_response(data) def detail_view(self, request, pk): document = get_object_or_404(Document, pk=pk) diff --git a/wagtail/contrib/wagtailapi/pagination.py b/wagtail/contrib/wagtailapi/pagination.py new file mode 100644 index 0000000000..6cb470e063 --- /dev/null +++ b/wagtail/contrib/wagtailapi/pagination.py @@ -0,0 +1,45 @@ +from collections import OrderedDict + +from django.conf import settings + +from rest_framework.pagination import BasePagination +from rest_framework.response import Response + +from .utils import BadRequestError + + +class WagtailPagination(BasePagination): + def paginate_queryset(self, queryset, request, view=None): + limit_max = getattr(settings, 'WAGTAILAPI_LIMIT_MAX', 20) + + try: + offset = int(request.GET.get('offset', 0)) + assert offset >= 0 + except (ValueError, AssertionError): + raise BadRequestError("offset must be a positive integer") + + try: + limit = int(request.GET.get('limit', min(20, limit_max))) + + if limit > limit_max: + raise BadRequestError("limit cannot be higher than %d" % limit_max) + + assert limit >= 0 + except (ValueError, AssertionError): + raise BadRequestError("limit must be a positive integer") + + start = offset + stop = offset + limit + + self.view = view + self.total_count = queryset.count() + return queryset[start:stop] + + def get_paginated_response(self, data): + data = OrderedDict([ + ('meta', OrderedDict([ + ('total_count', self.total_count), + ])), + (self.view.name, data), + ]) + return Response(data) From 067247d2a40c5891bdd5c144b5878c5436884786 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 20 Jul 2015 17:30:35 +0100 Subject: [PATCH 07/22] Refactor get_fields --- wagtail/contrib/wagtailapi/endpoints.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/wagtail/contrib/wagtailapi/endpoints.py b/wagtail/contrib/wagtailapi/endpoints.py index 105cf2c879..1a3a7e34d4 100644 --- a/wagtail/contrib/wagtailapi/endpoints.py +++ b/wagtail/contrib/wagtailapi/endpoints.py @@ -181,6 +181,15 @@ class BaseAPIEndpoint(GenericViewSet): if unknown_parameters: raise BadRequestError("query parameter is not an operation or a recognised field: %s" % ', '.join(sorted(unknown_parameters))) + def get_fields(self, request): + """ + Return the set of fields that should be returned in the output + representation for listing views. + """ + if 'fields' in request.GET: + return set(request.GET['fields'].split(',')) + return {'title'} + @classmethod def get_urlpatterns(cls): """ @@ -288,10 +297,7 @@ class PagesAPIEndpoint(BaseAPIEndpoint): queryset = self.paginate_queryset(queryset) # Get list of fields to show in results - if 'fields' in request.GET: - fields = set(request.GET['fields'].split(',')) - else: - fields = {'title'} + fields = self.get_fields(request) data = [ self.serialize_object(request, page, fields=fields) @@ -330,10 +336,7 @@ class ImagesAPIEndpoint(BaseAPIEndpoint): queryset = self.paginate_queryset(queryset) # Get list of fields to show in results - if 'fields' in request.GET: - fields = set(request.GET['fields'].split(',')) - else: - fields = {'title'} + fields = self.get_fields(request) data = [ self.serialize_object(request, image, fields=fields) @@ -377,10 +380,7 @@ class DocumentsAPIEndpoint(BaseAPIEndpoint): queryset = self.paginate_queryset(queryset) # Get list of fields to show in results - if 'fields' in request.GET: - fields = set(request.GET['fields'].split(',')) - else: - fields = {'title'} + fields = self.get_fields(request) data = [ self.serialize_object(request, document, fields=fields) From 7d01beffffbb64c3f921c4e2a9dba6dbf6926379 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 21 Jul 2015 11:35:46 +0100 Subject: [PATCH 08/22] Refactor to use serializers --- wagtail/contrib/wagtailapi/endpoints.py | 207 ++++------------------ wagtail/contrib/wagtailapi/serializers.py | 172 ++++++++++++++++++ 2 files changed, 206 insertions(+), 173 deletions(-) create mode 100644 wagtail/contrib/wagtailapi/serializers.py diff --git a/wagtail/contrib/wagtailapi/endpoints.py b/wagtail/contrib/wagtailapi/endpoints.py index 1a3a7e34d4..669569e03f 100644 --- a/wagtail/contrib/wagtailapi/endpoints.py +++ b/wagtail/contrib/wagtailapi/endpoints.py @@ -2,11 +2,7 @@ from __future__ import absolute_import from collections import OrderedDict -from modelcluster.models import get_all_child_relations -from taggit.managers import _TaggableManager - from django.db import models -from django.utils.encoding import force_text from django.shortcuts import get_object_or_404 from django.conf.urls import url from django.conf import settings @@ -20,7 +16,6 @@ from wagtail.wagtailcore.models import Page from wagtail.wagtailimages.models import get_image_model from wagtail.wagtaildocs.models import Document from wagtail.wagtailcore.utils import resolve_model_string -from wagtail.utils.compat import get_related_model from .filters import ( FieldsFilter, OrderingFilter, SearchFilter, @@ -28,63 +23,14 @@ from .filters import ( ) from .renderers import WagtailJSONRenderer from .pagination import WagtailPagination +from .serializers import WagtailSerializer, PageSerializer, DocumentSerializer from .utils import BadRequestError, URLPath, ObjectDetailURL -def get_api_data(obj, fields): - # Find any child relations (pages only) - child_relations = {} - if isinstance(obj, Page): - child_relations = { - child_relation.field.rel.related_name: get_related_model(child_relation) - for child_relation in get_all_child_relations(type(obj)) - } - - # Loop through fields - for field_name in fields: - # Check child relations - if field_name in child_relations and hasattr(child_relations[field_name], 'api_fields'): - yield field_name, [ - dict(get_api_data(child_object, child_relations[field_name].api_fields)) - for child_object in getattr(obj, field_name).all() - ] - continue - - # Check django fields - try: - field = obj._meta.get_field(field_name) - - if field.rel and isinstance(field.rel, models.ManyToOneRel): - # Foreign key - val = field._get_val_from_obj(obj) - - if val: - yield field_name, OrderedDict([ - ('id', field._get_val_from_obj(obj)), - ('meta', OrderedDict([ - ('type', field.rel.to._meta.app_label + '.' + field.rel.to.__name__), - ('detail_url', ObjectDetailURL(field.rel.to, val)), - ])), - ]) - else: - yield field_name, None - else: - yield field_name, field._get_val_from_obj(obj) - - continue - except models.fields.FieldDoesNotExist: - pass - - # Check attributes - if hasattr(obj, field_name): - value = getattr(obj, field_name) - yield field_name, force_text(value, strings_only=True) - continue - - class BaseAPIEndpoint(GenericViewSet): renderer_classes = [WagtailJSONRenderer] pagination_class = WagtailPagination + serializer_class = WagtailSerializer filter_classes = [] known_query_parameters = frozenset([ @@ -123,56 +69,10 @@ class BaseAPIEndpoint(GenericViewSet): return api_fields - def serialize_object_metadata(self, request, obj, show_details=False): - """ - This returns a JSON-serialisable dict to use for the "meta" - section of a particlular object. - """ - data = OrderedDict() - - # Add type - data['type'] = type(obj)._meta.app_label + '.' + type(obj).__name__ - data['detail_url'] = ObjectDetailURL(type(obj), obj.pk) - - return data - - def serialize_object(self, request, obj, fields=frozenset(), extra_data=(), all_fields=False, show_details=False): - """ - This converts an object into JSON-serialisable dict so it can - be used in the API. - """ - data = [ - ('id', obj.id), - ] - - # Add meta - metadata = self.serialize_object_metadata(request, obj, show_details=show_details) - if metadata: - data.append(('meta', metadata)) - - # Add extra data - data.extend(extra_data) - - # Add other fields - api_fields = self.get_api_fields(type(obj)) - api_fields = list(OrderedDict.fromkeys(api_fields)) # Removes any duplicates in case the user put "title" in api_fields - - if all_fields: - fields = api_fields - else: - unknown_fields = fields - set(api_fields) - - if unknown_fields: - raise BadRequestError("unknown fields: %s" % ', '.join(sorted(unknown_fields))) - - # Reorder fields so it matches the order of api_fields - fields = [field for field in api_fields if field in fields] - - data.extend(get_api_data(obj, fields)) - - return OrderedDict(data) - def check_query_parameters(self, request, queryset): + """ + Ensure that only valid query paramters are included in the URL. + """ query_parameters = set(request.GET.keys()) # All query paramters must be either a field or an operation @@ -190,6 +90,21 @@ class BaseAPIEndpoint(GenericViewSet): return set(request.GET['fields'].split(',')) return {'title'} + def get_serializer_context(self): + request = self.request + if self.action == 'listing_view': + return { + 'request': request, + 'view': self, + 'fields': self.get_fields(request) + } + return { + 'request': request, + 'view': self, + 'all_fields': True, + 'show_details': True + } + @classmethod def get_urlpatterns(cls): """ @@ -228,6 +143,7 @@ class PagesAPIEndpoint(BaseAPIEndpoint): FieldsFilter, ChildOfFilter, DescendantOfFilter, OrderingFilter, SearchFilter ] + serializer_class = PageSerializer def get_queryset(self, request, model=Page): # Get live pages that are not in a private section @@ -238,35 +154,6 @@ class PagesAPIEndpoint(BaseAPIEndpoint): return queryset - def serialize_object_metadata(self, request, page, show_details=False): - data = super(PagesAPIEndpoint, self).serialize_object_metadata(request, page, show_details=show_details) - - # Add type - data['type'] = page.specific_class._meta.app_label + '.' + page.specific_class.__name__ - - return data - - def serialize_object(self, request, page, fields=frozenset(), extra_data=(), all_fields=False, show_details=False): - # Add parent - if show_details: - parent = page.get_parent() - - # Make sure the parent is visible in the API - if self.get_queryset(request).filter(id=parent.id).exists(): - parent_class = parent.specific_class - - extra_data += ( - ('parent', OrderedDict([ - ('id', parent.id), - ('meta', OrderedDict([ - ('type', parent_class._meta.app_label + '.' + parent_class.__name__), - ('detail_url', ObjectDetailURL(parent_class, parent.id)), - ])), - ])), - ) - - return super(PagesAPIEndpoint, self).serialize_object(request, page, fields=fields, extra_data=extra_data, all_fields=all_fields, show_details=show_details) - def get_model(self, request): if 'type' not in request.GET: return Page @@ -296,19 +183,13 @@ class PagesAPIEndpoint(BaseAPIEndpoint): # Pagination queryset = self.paginate_queryset(queryset) - # Get list of fields to show in results - fields = self.get_fields(request) - - data = [ - self.serialize_object(request, page, fields=fields) - for page in queryset - ] - return self.get_paginated_response(data) + serializer = self.get_serializer(queryset, many=True) + return self.get_paginated_response(serializer.data) def detail_view(self, request, pk): page = get_object_or_404(self.get_queryset(request), pk=pk).specific - data = self.serialize_object(request, page, all_fields=True, show_details=True) - return Response(data) + serializer = self.get_serializer(page) + return Response(serializer.data) def has_model(self, model): return issubclass(model, Page) @@ -335,19 +216,13 @@ class ImagesAPIEndpoint(BaseAPIEndpoint): # Pagination queryset = self.paginate_queryset(queryset) - # Get list of fields to show in results - fields = self.get_fields(request) - - data = [ - self.serialize_object(request, image, fields=fields) - for image in queryset - ] - return self.get_paginated_response(data) + serializer = self.get_serializer(queryset, many=True) + return self.get_paginated_response(serializer.data) def detail_view(self, request, pk): image = get_object_or_404(self.get_queryset(request), pk=pk) - data = self.serialize_object(request, image, all_fields=True) - return Response(data) + serializer = self.get_serializer(image) + return Response(serializer.data) def has_model(self, model): return model == self.model @@ -357,15 +232,7 @@ class DocumentsAPIEndpoint(BaseAPIEndpoint): name = 'documents' filter_backends = [FieldsFilter, OrderingFilter, SearchFilter] extra_api_fields = ['title', 'tags'] - - def serialize_object_metadata(self, request, document, show_details=False): - data = super(DocumentsAPIEndpoint, self).serialize_object_metadata(request, document, show_details=show_details) - - # Download URL - if show_details: - data['download_url'] = URLPath(document.url) - - return data + serializer_class = DocumentSerializer def listing_view(self, request): queryset = Document.objects.all().order_by('id') @@ -379,19 +246,13 @@ class DocumentsAPIEndpoint(BaseAPIEndpoint): # Pagination queryset = self.paginate_queryset(queryset) - # Get list of fields to show in results - fields = self.get_fields(request) - - data = [ - self.serialize_object(request, document, fields=fields) - for document in queryset - ] - return self.get_paginated_response(data) + serializer = self.get_serializer(queryset, many=True) + return self.get_paginated_response(serializer.data) def detail_view(self, request, pk): document = get_object_or_404(Document, pk=pk) - data = self.serialize_object(request, document, all_fields=True, show_details=True) - return Response(data) + serializer = self.get_serializer(document) + return Response(serializer.data) def has_model(self, model): return model == Document diff --git a/wagtail/contrib/wagtailapi/serializers.py b/wagtail/contrib/wagtailapi/serializers.py new file mode 100644 index 0000000000..3b47ade507 --- /dev/null +++ b/wagtail/contrib/wagtailapi/serializers.py @@ -0,0 +1,172 @@ +from __future__ import absolute_import + +from collections import OrderedDict + +from django.db import models +from django.utils.encoding import force_text + +from modelcluster.models import get_all_child_relations + +from rest_framework.serializers import BaseSerializer + +from wagtail.utils.compat import get_related_model +from wagtail.wagtailcore.models import Page + +from .utils import ObjectDetailURL, URLPath, BadRequestError + + +def get_api_data(obj, fields): + # Find any child relations (pages only) + child_relations = {} + if isinstance(obj, Page): + child_relations = { + child_relation.field.rel.related_name: get_related_model(child_relation) + for child_relation in get_all_child_relations(type(obj)) + } + + # Loop through fields + for field_name in fields: + # Check child relations + if field_name in child_relations and hasattr(child_relations[field_name], 'api_fields'): + yield field_name, [ + dict(get_api_data(child_object, child_relations[field_name].api_fields)) + for child_object in getattr(obj, field_name).all() + ] + continue + + # Check django fields + try: + field = obj._meta.get_field(field_name) + + if field.rel and isinstance(field.rel, models.ManyToOneRel): + # Foreign key + val = field._get_val_from_obj(obj) + + if val: + yield field_name, OrderedDict([ + ('id', field._get_val_from_obj(obj)), + ('meta', OrderedDict([ + ('type', field.rel.to._meta.app_label + '.' + field.rel.to.__name__), + ('detail_url', ObjectDetailURL(field.rel.to, val)), + ])), + ]) + else: + yield field_name, None + else: + yield field_name, field._get_val_from_obj(obj) + + continue + except models.fields.FieldDoesNotExist: + pass + + # Check attributes + if hasattr(obj, field_name): + value = getattr(obj, field_name) + yield field_name, force_text(value, strings_only=True) + continue + + +class WagtailSerializer(BaseSerializer): + def to_representation(self, instance): + request = self.context['request'] + fields = self.context.get('fields', frozenset()) + all_fields = self.context.get('all_fields', False) + show_details = self.context.get('show_details', False) + return self.serialize_object( + request, + instance, + fields=fields, + all_fields=all_fields, + show_details=show_details + ) + + def serialize_object_metadata(self, request, obj, show_details=False): + """ + This returns a JSON-serialisable dict to use for the "meta" + section of a particlular object. + """ + data = OrderedDict() + + # Add type + data['type'] = type(obj)._meta.app_label + '.' + type(obj).__name__ + data['detail_url'] = ObjectDetailURL(type(obj), obj.pk) + + return data + + def serialize_object(self, request, obj, fields=frozenset(), extra_data=(), all_fields=False, show_details=False): + """ + This converts an object into JSON-serialisable dict so it can + be used in the API. + """ + data = [ + ('id', obj.id), + ] + + # Add meta + metadata = self.serialize_object_metadata(request, obj, show_details=show_details) + if metadata: + data.append(('meta', metadata)) + + # Add extra data + data.extend(extra_data) + + # Add other fields + api_fields = self.context['view'].get_api_fields(type(obj)) + api_fields = list(OrderedDict.fromkeys(api_fields)) # Removes any duplicates in case the user put "title" in api_fields + + if all_fields: + fields = api_fields + else: + unknown_fields = fields - set(api_fields) + + if unknown_fields: + raise BadRequestError("unknown fields: %s" % ', '.join(sorted(unknown_fields))) + + # Reorder fields so it matches the order of api_fields + fields = [field for field in api_fields if field in fields] + + data.extend(get_api_data(obj, fields)) + + return OrderedDict(data) + + +class PageSerializer(WagtailSerializer): + def serialize_object_metadata(self, request, page, show_details=False): + data = super(PageSerializer, self).serialize_object_metadata(request, page, show_details=show_details) + + # Add type + data['type'] = page.specific_class._meta.app_label + '.' + page.specific_class.__name__ + + return data + + def serialize_object(self, request, page, fields=frozenset(), extra_data=(), all_fields=False, show_details=False): + # Add parent + if show_details: + parent = page.get_parent() + + # Make sure the parent is visible in the API + if self.context['view'].get_queryset(request).filter(id=parent.id).exists(): + parent_class = parent.specific_class + + extra_data += ( + ('parent', OrderedDict([ + ('id', parent.id), + ('meta', OrderedDict([ + ('type', parent_class._meta.app_label + '.' + parent_class.__name__), + ('detail_url', ObjectDetailURL(parent_class, parent.id)), + ])), + ])), + ) + + return super(PageSerializer, self).serialize_object(request, page, fields=fields, extra_data=extra_data, all_fields=all_fields, show_details=show_details) + + +class DocumentSerializer(WagtailSerializer): + def serialize_object_metadata(self, request, document, show_details=False): + data = super(DocumentSerializer, self).serialize_object_metadata(request, document, show_details=show_details) + + # Download URL + if show_details: + data['download_url'] = URLPath(document.url) + + return data From 86e1a60ad2bf2ef4c14b503f2f9e2138c9dc4f8d Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 21 Jul 2015 11:43:19 +0100 Subject: [PATCH 09/22] Refactor find_model_detail_view --- wagtail/contrib/wagtailapi/endpoints.py | 28 +++++++++---------------- wagtail/contrib/wagtailapi/renderers.py | 10 ++++++++- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/wagtail/contrib/wagtailapi/endpoints.py b/wagtail/contrib/wagtailapi/endpoints.py index 669569e03f..469691df62 100644 --- a/wagtail/contrib/wagtailapi/endpoints.py +++ b/wagtail/contrib/wagtailapi/endpoints.py @@ -115,20 +115,9 @@ class BaseAPIEndpoint(GenericViewSet): url(r'^(\d+)/$', cls.as_view({'get': 'detail_view'}), name='detail'), ] - def find_model_detail_view(self, model): - # TODO: Needs refactoring. This is currently duplicated, and also - # does a bit of a dance around instantiating these classes. - endpoints = { - 'pages': PagesAPIEndpoint(), - 'images': ImagesAPIEndpoint(), - 'documents': DocumentsAPIEndpoint(), - } - for endpoint_name, endpoint in endpoints.items(): - if endpoint.has_model(model): - return 'wagtailapi_v1:%s:detail' % endpoint_name - - def has_model(self, model): - return False + @classmethod + def has_model(cls, model): + return NotImplemented class PagesAPIEndpoint(BaseAPIEndpoint): @@ -191,7 +180,8 @@ class PagesAPIEndpoint(BaseAPIEndpoint): serializer = self.get_serializer(page) return Response(serializer.data) - def has_model(self, model): + @classmethod + def has_model(cls, model): return issubclass(model, Page) @@ -224,8 +214,9 @@ class ImagesAPIEndpoint(BaseAPIEndpoint): serializer = self.get_serializer(image) return Response(serializer.data) - def has_model(self, model): - return model == self.model + @classmethod + def has_model(cls, model): + return model == cls.model class DocumentsAPIEndpoint(BaseAPIEndpoint): @@ -254,5 +245,6 @@ class DocumentsAPIEndpoint(BaseAPIEndpoint): serializer = self.get_serializer(document) return Response(serializer.data) - def has_model(self, model): + @classmethod + def has_model(cls, model): return model == Document diff --git a/wagtail/contrib/wagtailapi/renderers.py b/wagtail/contrib/wagtailapi/renderers.py index c58ea32c63..25031a93dd 100644 --- a/wagtail/contrib/wagtailapi/renderers.py +++ b/wagtail/contrib/wagtailapi/renderers.py @@ -18,6 +18,14 @@ def get_full_url(request, path): return base_url + path +def find_model_detail_view(model): + from .endpoints import PagesAPIEndpoint, ImagesAPIEndpoint, DocumentsAPIEndpoint + + for endpoint in [PagesAPIEndpoint, ImagesAPIEndpoint, DocumentsAPIEndpoint]: + if endpoint.has_model(model): + return 'wagtailapi_v1:%s:detail' % endpoint.name + + class WagtailJSONRenderer(renderers.BaseRenderer): media_type = 'application/json' charset = None @@ -35,7 +43,7 @@ class WagtailJSONRenderer(renderers.BaseRenderer): elif isinstance(o, URLPath): return get_full_url(request, o.path) elif isinstance(o, ObjectDetailURL): - view = endpoint.find_model_detail_view(o.model) + view = find_model_detail_view(o.model) if view: return get_full_url(request, reverse(view, args=(o.pk, ))) From 6ed50e18c488e283110d3fc9b8c456f720cb7bff Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 21 Jul 2015 12:11:26 +0100 Subject: [PATCH 10/22] Refactor filters to not callback into views --- wagtail/contrib/wagtailapi/endpoints.py | 3 +++ wagtail/contrib/wagtailapi/filters.py | 14 ++++++++++++-- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/wagtail/contrib/wagtailapi/endpoints.py b/wagtail/contrib/wagtailapi/endpoints.py index 469691df62..3164ec62f6 100644 --- a/wagtail/contrib/wagtailapi/endpoints.py +++ b/wagtail/contrib/wagtailapi/endpoints.py @@ -91,6 +91,9 @@ class BaseAPIEndpoint(GenericViewSet): return {'title'} def get_serializer_context(self): + """ + The serialization context differs between listing and detail views. + """ request = self.request if self.action == 'listing_view': return { diff --git a/wagtail/contrib/wagtailapi/filters.py b/wagtail/contrib/wagtailapi/filters.py index 6337698647..d6c52cfa02 100644 --- a/wagtail/contrib/wagtailapi/filters.py +++ b/wagtail/contrib/wagtailapi/filters.py @@ -117,8 +117,13 @@ class ChildOfFilter(BaseFilterBackend): except (ValueError, AssertionError): raise BadRequestError("child_of must be a positive integer") + # Get live pages that are not in a private section + pages = Page.objects.public().live() + # Filter by site + pages = pages.descendant_of(request.site.root_page, inclusive=True) + try: - parent_page = view.get_queryset(request).get(id=parent_page_id) + parent_page = pages.get(id=parent_page_id) queryset = queryset.child_of(parent_page) queryset._filtered_by_child_of = True return queryset @@ -139,8 +144,13 @@ class DescendantOfFilter(BaseFilterBackend): except (ValueError, AssertionError): raise BadRequestError("descendant_of must be a positive integer") + # Get live pages that are not in a private section + pages = Page.objects.public().live() + # Filter by site + pages = pages.descendant_of(request.site.root_page, inclusive=True) + try: - ancestor_page = view.get_queryset(request).get(id=ancestor_page_id) + ancestor_page = pages.get(id=ancestor_page_id) return queryset.descendant_of(ancestor_page) except Page.DoesNotExist: raise BadRequestError("ancestor page doesn't exist") From 67214c002e8c9fb5cc760644004512a539cfdbf7 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 21 Jul 2015 12:44:11 +0100 Subject: [PATCH 11/22] Refactor away .get_model --- wagtail/contrib/wagtailapi/endpoints.py | 31 +++++++++++-------------- 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/wagtail/contrib/wagtailapi/endpoints.py b/wagtail/contrib/wagtailapi/endpoints.py index 3164ec62f6..0652cb3a94 100644 --- a/wagtail/contrib/wagtailapi/endpoints.py +++ b/wagtail/contrib/wagtailapi/endpoints.py @@ -137,7 +137,18 @@ class PagesAPIEndpoint(BaseAPIEndpoint): ] serializer_class = PageSerializer - def get_queryset(self, request, model=Page): + def get_queryset(self, request): + if 'type' not in request.GET: + model = Page + else: + model_name = request.GET['type'] + try: + model = resolve_model_string(model_name) + except LookupError: + raise BadRequestError("type doesn't exist") + if not issubclass(model, Page): + raise BadRequestError("type doesn't exist") + # Get live pages that are not in a private section queryset = model.objects.public().live() @@ -146,25 +157,9 @@ class PagesAPIEndpoint(BaseAPIEndpoint): return queryset - def get_model(self, request): - if 'type' not in request.GET: - return Page - - model_name = request.GET['type'] - try: - model = resolve_model_string(model_name) - - if not issubclass(model, Page): - raise BadRequestError("type doesn't exist") - - return model - except LookupError: - raise BadRequestError("type doesn't exist") - def listing_view(self, request): # Get model and queryset - model = self.get_model(request) - queryset = self.get_queryset(request, model=model) + queryset = self.get_queryset(request) # Check query paramters self.check_query_parameters(request, queryset) From 2ce5db302a91fb2df4e7a580c1e4765716c9a419 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 21 Jul 2015 12:54:53 +0100 Subject: [PATCH 12/22] Use standard REST framework get_object, get_queryset signatures --- wagtail/contrib/wagtailapi/endpoints.py | 35 ++++++++++++----------- wagtail/contrib/wagtailapi/serializers.py | 7 ++++- 2 files changed, 24 insertions(+), 18 deletions(-) diff --git a/wagtail/contrib/wagtailapi/endpoints.py b/wagtail/contrib/wagtailapi/endpoints.py index 0652cb3a94..ca388e9959 100644 --- a/wagtail/contrib/wagtailapi/endpoints.py +++ b/wagtail/contrib/wagtailapi/endpoints.py @@ -69,11 +69,11 @@ class BaseAPIEndpoint(GenericViewSet): return api_fields - def check_query_parameters(self, request, queryset): + def check_query_parameters(self, queryset): """ Ensure that only valid query paramters are included in the URL. """ - query_parameters = set(request.GET.keys()) + query_parameters = set(self.request.GET.keys()) # All query paramters must be either a field or an operation allowed_query_parameters = set(self.get_api_fields(queryset.model)).union(self.known_query_parameters).union({'id'}) @@ -115,7 +115,7 @@ class BaseAPIEndpoint(GenericViewSet): """ return [ url(r'^$', cls.as_view({'get': 'listing_view'}), name='listing'), - url(r'^(\d+)/$', cls.as_view({'get': 'detail_view'}), name='detail'), + url(r'^(?P\d+)/$', cls.as_view({'get': 'detail_view'}), name='detail'), ] @classmethod @@ -137,7 +137,9 @@ class PagesAPIEndpoint(BaseAPIEndpoint): ] serializer_class = PageSerializer - def get_queryset(self, request): + def get_queryset(self): + request = self.request + if 'type' not in request.GET: model = Page else: @@ -159,10 +161,10 @@ class PagesAPIEndpoint(BaseAPIEndpoint): def listing_view(self, request): # Get model and queryset - queryset = self.get_queryset(request) + queryset = self.get_queryset() # Check query paramters - self.check_query_parameters(request, queryset) + self.check_query_parameters(queryset) # Filtering, Ancestor/Descendant, Ordering, Search. queryset = self.filter_queryset(queryset) @@ -174,7 +176,7 @@ class PagesAPIEndpoint(BaseAPIEndpoint): return self.get_paginated_response(serializer.data) def detail_view(self, request, pk): - page = get_object_or_404(self.get_queryset(request), pk=pk).specific + page = self.get_object().specific serializer = self.get_serializer(page) return Response(serializer.data) @@ -188,15 +190,13 @@ class ImagesAPIEndpoint(BaseAPIEndpoint): model = get_image_model() filter_backends = [FieldsFilter, OrderingFilter, SearchFilter] extra_api_fields = ['title', 'tags', 'width', 'height'] - - def get_queryset(self, request): - return self.model.objects.all().order_by('id') + queryset = get_image_model().objects.all().order_by('id') def listing_view(self, request): - queryset = self.get_queryset(request) + queryset = self.get_queryset() # Check query paramters - self.check_query_parameters(request, queryset) + self.check_query_parameters(queryset) # Filtering, Ordering, Search. queryset = self.filter_queryset(queryset) @@ -208,13 +208,13 @@ class ImagesAPIEndpoint(BaseAPIEndpoint): return self.get_paginated_response(serializer.data) def detail_view(self, request, pk): - image = get_object_or_404(self.get_queryset(request), pk=pk) + image = self.get_object() serializer = self.get_serializer(image) return Response(serializer.data) @classmethod def has_model(cls, model): - return model == cls.model + return model == get_image_model() class DocumentsAPIEndpoint(BaseAPIEndpoint): @@ -222,12 +222,13 @@ class DocumentsAPIEndpoint(BaseAPIEndpoint): filter_backends = [FieldsFilter, OrderingFilter, SearchFilter] extra_api_fields = ['title', 'tags'] serializer_class = DocumentSerializer + queryset = Document.objects.all().order_by('id') def listing_view(self, request): - queryset = Document.objects.all().order_by('id') + queryset = self.get_queryset() # Check query paramters - self.check_query_parameters(request, queryset) + self.check_query_parameters(queryset) # Filtering, Ordering, Search. queryset = self.filter_queryset(queryset) @@ -239,7 +240,7 @@ class DocumentsAPIEndpoint(BaseAPIEndpoint): return self.get_paginated_response(serializer.data) def detail_view(self, request, pk): - document = get_object_or_404(Document, pk=pk) + document = self.get_object() serializer = self.get_serializer(document) return Response(serializer.data) diff --git a/wagtail/contrib/wagtailapi/serializers.py b/wagtail/contrib/wagtailapi/serializers.py index 3b47ade507..62a5918611 100644 --- a/wagtail/contrib/wagtailapi/serializers.py +++ b/wagtail/contrib/wagtailapi/serializers.py @@ -145,7 +145,12 @@ class PageSerializer(WagtailSerializer): parent = page.get_parent() # Make sure the parent is visible in the API - if self.context['view'].get_queryset(request).filter(id=parent.id).exists(): + # Get live pages that are not in a private section + pages = Page.objects.public().live() + # Filter by site + pages = pages.descendant_of(request.site.root_page, inclusive=True) + + if pages.filter(id=parent.id).exists(): parent_class = parent.specific_class extra_data += ( From 82e7b79bb204383aacc44e811309d84b3c8d8103 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 21 Jul 2015 12:59:56 +0100 Subject: [PATCH 13/22] Override get_object in PagesAPIEndpoint --- wagtail/contrib/wagtailapi/endpoints.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/wagtail/contrib/wagtailapi/endpoints.py b/wagtail/contrib/wagtailapi/endpoints.py index ca388e9959..687ad5230f 100644 --- a/wagtail/contrib/wagtailapi/endpoints.py +++ b/wagtail/contrib/wagtailapi/endpoints.py @@ -140,6 +140,7 @@ class PagesAPIEndpoint(BaseAPIEndpoint): def get_queryset(self): request = self.request + # Allow pages to be filtered to a specific type if 'type' not in request.GET: model = Page else: @@ -159,6 +160,10 @@ class PagesAPIEndpoint(BaseAPIEndpoint): return queryset + def get_object(self): + base = super(PagesAPIEndpoint, self).get_object() + return base.specific + def listing_view(self, request): # Get model and queryset queryset = self.get_queryset() @@ -176,7 +181,7 @@ class PagesAPIEndpoint(BaseAPIEndpoint): return self.get_paginated_response(serializer.data) def detail_view(self, request, pk): - page = self.get_object().specific + page = self.get_object() serializer = self.get_serializer(page) return Response(serializer.data) From 8996584e38d0c6b64ecf19a70a6f571c5b7899f4 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 21 Jul 2015 13:07:14 +0100 Subject: [PATCH 14/22] Refactor away common methods for API endpoints --- wagtail/contrib/wagtailapi/endpoints.py | 72 ++++--------------------- 1 file changed, 9 insertions(+), 63 deletions(-) diff --git a/wagtail/contrib/wagtailapi/endpoints.py b/wagtail/contrib/wagtailapi/endpoints.py index 687ad5230f..833463ddb0 100644 --- a/wagtail/contrib/wagtailapi/endpoints.py +++ b/wagtail/contrib/wagtailapi/endpoints.py @@ -52,10 +52,17 @@ class BaseAPIEndpoint(GenericViewSet): return super(BaseAPIEndpoint, self).handle_exception(exc) def listing_view(self, request): - return NotImplemented + queryset = self.get_queryset() + self.check_query_parameters(queryset) + queryset = self.filter_queryset(queryset) + queryset = self.paginate_queryset(queryset) + serializer = self.get_serializer(queryset, many=True) + return self.get_paginated_response(serializer.data) def detail_view(self, request, pk): - return NotImplemented + instance = self.get_object() + serializer = self.get_serializer(instance) + return Response(serializer.data) def get_api_fields(self, model): """ @@ -164,27 +171,6 @@ class PagesAPIEndpoint(BaseAPIEndpoint): base = super(PagesAPIEndpoint, self).get_object() return base.specific - def listing_view(self, request): - # Get model and queryset - queryset = self.get_queryset() - - # Check query paramters - self.check_query_parameters(queryset) - - # Filtering, Ancestor/Descendant, Ordering, Search. - queryset = self.filter_queryset(queryset) - - # Pagination - queryset = self.paginate_queryset(queryset) - - serializer = self.get_serializer(queryset, many=True) - return self.get_paginated_response(serializer.data) - - def detail_view(self, request, pk): - page = self.get_object() - serializer = self.get_serializer(page) - return Response(serializer.data) - @classmethod def has_model(cls, model): return issubclass(model, Page) @@ -197,26 +183,6 @@ class ImagesAPIEndpoint(BaseAPIEndpoint): extra_api_fields = ['title', 'tags', 'width', 'height'] queryset = get_image_model().objects.all().order_by('id') - def listing_view(self, request): - queryset = self.get_queryset() - - # Check query paramters - self.check_query_parameters(queryset) - - # Filtering, Ordering, Search. - queryset = self.filter_queryset(queryset) - - # Pagination - queryset = self.paginate_queryset(queryset) - - serializer = self.get_serializer(queryset, many=True) - return self.get_paginated_response(serializer.data) - - def detail_view(self, request, pk): - image = self.get_object() - serializer = self.get_serializer(image) - return Response(serializer.data) - @classmethod def has_model(cls, model): return model == get_image_model() @@ -229,26 +195,6 @@ class DocumentsAPIEndpoint(BaseAPIEndpoint): serializer_class = DocumentSerializer queryset = Document.objects.all().order_by('id') - def listing_view(self, request): - queryset = self.get_queryset() - - # Check query paramters - self.check_query_parameters(queryset) - - # Filtering, Ordering, Search. - queryset = self.filter_queryset(queryset) - - # Pagination - queryset = self.paginate_queryset(queryset) - - serializer = self.get_serializer(queryset, many=True) - return self.get_paginated_response(serializer.data) - - def detail_view(self, request, pk): - document = self.get_object() - serializer = self.get_serializer(document) - return Response(serializer.data) - @classmethod def has_model(cls, model): return model == Document From f892c093086fed3bd161b29a57a6127babec729f Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 21 Jul 2015 13:14:02 +0100 Subject: [PATCH 15/22] More consistent reordering fields and methods --- wagtail/contrib/wagtailapi/endpoints.py | 44 +++++++++++++------------ 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/wagtail/contrib/wagtailapi/endpoints.py b/wagtail/contrib/wagtailapi/endpoints.py index 833463ddb0..28291fac4b 100644 --- a/wagtail/contrib/wagtailapi/endpoints.py +++ b/wagtail/contrib/wagtailapi/endpoints.py @@ -42,15 +42,6 @@ class BaseAPIEndpoint(GenericViewSet): ]) extra_api_fields = [] - def handle_exception(self, exc): - if isinstance(exc, Http404): - data = {'message': str(exc)} - return Response(data, status=status.HTTP_404_NOT_FOUND) - elif isinstance(exc, BadRequestError): - data = {'message': str(exc)} - return Response(data, status=status.HTTP_400_BAD_REQUEST) - return super(BaseAPIEndpoint, self).handle_exception(exc) - def listing_view(self, request): queryset = self.get_queryset() self.check_query_parameters(queryset) @@ -64,6 +55,15 @@ class BaseAPIEndpoint(GenericViewSet): serializer = self.get_serializer(instance) return Response(serializer.data) + def handle_exception(self, exc): + if isinstance(exc, Http404): + data = {'message': str(exc)} + return Response(data, status=status.HTTP_404_NOT_FOUND) + elif isinstance(exc, BadRequestError): + data = {'message': str(exc)} + return Response(data, status=status.HTTP_400_BAD_REQUEST) + return super(BaseAPIEndpoint, self).handle_exception(exc) + def get_api_fields(self, model): """ This returns a list of field names that are allowed to @@ -131,18 +131,21 @@ class BaseAPIEndpoint(GenericViewSet): class PagesAPIEndpoint(BaseAPIEndpoint): - name = 'pages' + serializer_class = PageSerializer + filter_backends = [ + FieldsFilter, + ChildOfFilter, + DescendantOfFilter, + OrderingFilter, + SearchFilter + ] known_query_parameters = BaseAPIEndpoint.known_query_parameters.union([ 'type', 'child_of', 'descendant_of', ]) extra_api_fields = ['title'] - filter_backends = [ - FieldsFilter, ChildOfFilter, DescendantOfFilter, - OrderingFilter, SearchFilter - ] - serializer_class = PageSerializer + name = 'pages' def get_queryset(self): request = self.request @@ -177,11 +180,10 @@ class PagesAPIEndpoint(BaseAPIEndpoint): class ImagesAPIEndpoint(BaseAPIEndpoint): - name = 'images' - model = get_image_model() + queryset = get_image_model().objects.all().order_by('id') filter_backends = [FieldsFilter, OrderingFilter, SearchFilter] extra_api_fields = ['title', 'tags', 'width', 'height'] - queryset = get_image_model().objects.all().order_by('id') + name = 'images' @classmethod def has_model(cls, model): @@ -189,11 +191,11 @@ class ImagesAPIEndpoint(BaseAPIEndpoint): class DocumentsAPIEndpoint(BaseAPIEndpoint): - name = 'documents' + queryset = Document.objects.all().order_by('id') + serializer_class = DocumentSerializer filter_backends = [FieldsFilter, OrderingFilter, SearchFilter] extra_api_fields = ['title', 'tags'] - serializer_class = DocumentSerializer - queryset = Document.objects.all().order_by('id') + name = 'documents' @classmethod def has_model(cls, model): From 297f64509ad6519fb2eb736f10b544869e8734af Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 21 Jul 2015 13:29:25 +0100 Subject: [PATCH 16/22] Refactor away .get_fields --- wagtail/contrib/wagtailapi/endpoints.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/wagtail/contrib/wagtailapi/endpoints.py b/wagtail/contrib/wagtailapi/endpoints.py index 28291fac4b..f061d1026f 100644 --- a/wagtail/contrib/wagtailapi/endpoints.py +++ b/wagtail/contrib/wagtailapi/endpoints.py @@ -32,6 +32,7 @@ class BaseAPIEndpoint(GenericViewSet): pagination_class = WagtailPagination serializer_class = WagtailSerializer filter_classes = [] + queryset = None # Set on subclasses or implement `get_queryset()`. known_query_parameters = frozenset([ 'limit', @@ -41,6 +42,7 @@ class BaseAPIEndpoint(GenericViewSet): 'search', ]) extra_api_fields = [] + name = None # Set on subclass. def listing_view(self, request): queryset = self.get_queryset() @@ -88,26 +90,24 @@ class BaseAPIEndpoint(GenericViewSet): if unknown_parameters: raise BadRequestError("query parameter is not an operation or a recognised field: %s" % ', '.join(sorted(unknown_parameters))) - def get_fields(self, request): - """ - Return the set of fields that should be returned in the output - representation for listing views. - """ - if 'fields' in request.GET: - return set(request.GET['fields'].split(',')) - return {'title'} - def get_serializer_context(self): """ The serialization context differs between listing and detail views. """ request = self.request if self.action == 'listing_view': + + if 'fields' in request.GET: + fields = set(request.GET['fields'].split(',')) + else: + fields = {'title'} + return { 'request': request, 'view': self, - 'fields': self.get_fields(request) + 'fields': fields } + return { 'request': request, 'view': self, From b5a46f3deea244bcaee3785c4c001a3616bc9c77 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 21 Jul 2015 15:03:46 +0100 Subject: [PATCH 17/22] Added REST framework to tox --- tox.ini | 1 + 1 file changed, 1 insertion(+) diff --git a/tox.ini b/tox.ini index 0fb33a6983..20c41ea85c 100644 --- a/tox.ini +++ b/tox.ini @@ -22,6 +22,7 @@ deps = django-taggit==0.13.0 django-treebeard==3.0 django-sendfile==0.3.6 + djangorestframework==3.1.3 Pillow>=2.3.0 beautifulsoup4>=4.3.2 html5lib==0.999 From 54d14b64bb41be6f35df3dce7a8bb8a1d99ee83e Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 21 Jul 2015 15:08:10 +0100 Subject: [PATCH 18/22] PEP8 fixes --- wagtail/contrib/wagtailapi/endpoints.py | 7 +------ wagtail/contrib/wagtailapi/renderers.py | 1 - 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/wagtail/contrib/wagtailapi/endpoints.py b/wagtail/contrib/wagtailapi/endpoints.py index f061d1026f..1b03c39fe2 100644 --- a/wagtail/contrib/wagtailapi/endpoints.py +++ b/wagtail/contrib/wagtailapi/endpoints.py @@ -1,11 +1,6 @@ from __future__ import absolute_import -from collections import OrderedDict - -from django.db import models -from django.shortcuts import get_object_or_404 from django.conf.urls import url -from django.conf import settings from django.http import Http404 from rest_framework import status @@ -24,7 +19,7 @@ from .filters import ( from .renderers import WagtailJSONRenderer from .pagination import WagtailPagination from .serializers import WagtailSerializer, PageSerializer, DocumentSerializer -from .utils import BadRequestError, URLPath, ObjectDetailURL +from .utils import BadRequestError class BaseAPIEndpoint(GenericViewSet): diff --git a/wagtail/contrib/wagtailapi/renderers.py b/wagtail/contrib/wagtailapi/renderers.py index 25031a93dd..3f2babf882 100644 --- a/wagtail/contrib/wagtailapi/renderers.py +++ b/wagtail/contrib/wagtailapi/renderers.py @@ -31,7 +31,6 @@ class WagtailJSONRenderer(renderers.BaseRenderer): charset = None def render(self, data, media_type=None, renderer_context=None): - endpoint = renderer_context['view'] request = renderer_context['request'] class WagtailAPIJSONEncoder(DjangoJSONEncoder): From 0d0ff6c89a570b4d41d4bca0a61f83a936c6c1c5 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 21 Jul 2015 17:26:49 +0100 Subject: [PATCH 19/22] PEP8 fix --- wagtail/contrib/wagtailapi/endpoints.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wagtail/contrib/wagtailapi/endpoints.py b/wagtail/contrib/wagtailapi/endpoints.py index 1b03c39fe2..1f064f5dd0 100644 --- a/wagtail/contrib/wagtailapi/endpoints.py +++ b/wagtail/contrib/wagtailapi/endpoints.py @@ -175,7 +175,7 @@ class PagesAPIEndpoint(BaseAPIEndpoint): class ImagesAPIEndpoint(BaseAPIEndpoint): - queryset = get_image_model().objects.all().order_by('id') + queryset = get_image_model().objects.all().order_by('id') filter_backends = [FieldsFilter, OrderingFilter, SearchFilter] extra_api_fields = ['title', 'tags', 'width', 'height'] name = 'images' From 40db88f4e149e710e5961dc95995b83fc14c0758 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 21 Jul 2015 17:31:40 +0100 Subject: [PATCH 20/22] Resolve py3 compat issue - always return bytes from WagtailJSONRenderer --- wagtail/contrib/wagtailapi/renderers.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/wagtail/contrib/wagtailapi/renderers.py b/wagtail/contrib/wagtailapi/renderers.py index 3f2babf882..640535de71 100644 --- a/wagtail/contrib/wagtailapi/renderers.py +++ b/wagtail/contrib/wagtailapi/renderers.py @@ -2,6 +2,7 @@ import json from django.core.serializers.json import DjangoJSONEncoder from django.core.urlresolvers import reverse +from django.utils.six import text_type from rest_framework import renderers @@ -53,4 +54,9 @@ class WagtailJSONRenderer(renderers.BaseRenderer): else: return super(WagtailAPIJSONEncoder, self).default(o) - return json.dumps(data, indent=4, cls=WagtailAPIJSONEncoder) + ret = json.dumps(data, indent=4, cls=WagtailAPIJSONEncoder) + + # Deal with inconsistent py2/py3 behavior, and always return bytes. + if isinstance(ret, text_type): + return bytes(ret.encode('utf-8')) + return ret From 1905585e6280d012964cce9c18929b9cabdacb8f Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 22 Jul 2015 09:00:40 +0100 Subject: [PATCH 21/22] pages_for_site --- wagtail/contrib/wagtailapi/filters.py | 18 ++++----------- wagtail/contrib/wagtailapi/serializers.py | 11 +++------ wagtail/contrib/wagtailapi/utils.py | 28 +++++++++++++++-------- 3 files changed, 26 insertions(+), 31 deletions(-) diff --git a/wagtail/contrib/wagtailapi/filters.py b/wagtail/contrib/wagtailapi/filters.py index d6c52cfa02..7808065738 100644 --- a/wagtail/contrib/wagtailapi/filters.py +++ b/wagtail/contrib/wagtailapi/filters.py @@ -7,7 +7,7 @@ from taggit.managers import _TaggableManager from wagtail.wagtailcore.models import Page from wagtail.wagtailsearch.backends import get_search_backend -from .utils import BadRequestError +from .utils import BadRequestError, pages_for_site class FieldsFilter(BaseFilterBackend): @@ -117,13 +117,9 @@ class ChildOfFilter(BaseFilterBackend): except (ValueError, AssertionError): raise BadRequestError("child_of must be a positive integer") - # Get live pages that are not in a private section - pages = Page.objects.public().live() - # Filter by site - pages = pages.descendant_of(request.site.root_page, inclusive=True) - + site_pages = pages_for_site(request.site) try: - parent_page = pages.get(id=parent_page_id) + parent_page = site_pages.get(id=parent_page_id) queryset = queryset.child_of(parent_page) queryset._filtered_by_child_of = True return queryset @@ -144,13 +140,9 @@ class DescendantOfFilter(BaseFilterBackend): except (ValueError, AssertionError): raise BadRequestError("descendant_of must be a positive integer") - # Get live pages that are not in a private section - pages = Page.objects.public().live() - # Filter by site - pages = pages.descendant_of(request.site.root_page, inclusive=True) - + site_pages = pages_for_site(request.site) try: - ancestor_page = pages.get(id=ancestor_page_id) + ancestor_page = site_pages.get(id=ancestor_page_id) return queryset.descendant_of(ancestor_page) except Page.DoesNotExist: raise BadRequestError("ancestor page doesn't exist") diff --git a/wagtail/contrib/wagtailapi/serializers.py b/wagtail/contrib/wagtailapi/serializers.py index 62a5918611..a28f6f2f50 100644 --- a/wagtail/contrib/wagtailapi/serializers.py +++ b/wagtail/contrib/wagtailapi/serializers.py @@ -12,7 +12,7 @@ from rest_framework.serializers import BaseSerializer from wagtail.utils.compat import get_related_model from wagtail.wagtailcore.models import Page -from .utils import ObjectDetailURL, URLPath, BadRequestError +from .utils import ObjectDetailURL, URLPath, BadRequestError, pages_for_site def get_api_data(obj, fields): @@ -144,13 +144,8 @@ class PageSerializer(WagtailSerializer): if show_details: parent = page.get_parent() - # Make sure the parent is visible in the API - # Get live pages that are not in a private section - pages = Page.objects.public().live() - # Filter by site - pages = pages.descendant_of(request.site.root_page, inclusive=True) - - if pages.filter(id=parent.id).exists(): + site_pages = pages_for_site(request.site) + if site_pages.filter(id=parent.id).exists(): parent_class = parent.specific_class extra_data += ( diff --git a/wagtail/contrib/wagtailapi/utils.py b/wagtail/contrib/wagtailapi/utils.py index 483e4f51a2..03c38ee16f 100644 --- a/wagtail/contrib/wagtailapi/utils.py +++ b/wagtail/contrib/wagtailapi/utils.py @@ -1,21 +1,13 @@ from django.conf import settings from django.utils.six.moves.urllib.parse import urlparse +from wagtail.wagtailcore.models import Page + class BadRequestError(Exception): pass -def get_base_url(request=None): - base_url = getattr(settings, 'WAGTAILAPI_BASE_URL', request.site.root_url if request else None) - - if base_url: - # We only want the scheme and netloc - base_url_parsed = urlparse(base_url) - - return base_url_parsed.scheme + '://' + base_url_parsed.netloc - - class URLPath(object): """ This class represents a URL path that should be converted to a full URL. @@ -36,3 +28,19 @@ class ObjectDetailURL(object): def __init__(self, model, pk): self.model = model self.pk = pk + + +def get_base_url(request=None): + base_url = getattr(settings, 'WAGTAILAPI_BASE_URL', request.site.root_url if request else None) + + if base_url: + # We only want the scheme and netloc + base_url_parsed = urlparse(base_url) + + return base_url_parsed.scheme + '://' + base_url_parsed.netloc + + +def pages_for_site(site): + pages = Page.objects.public().live() + pages = pages.descendant_of(site.root_page, inclusive=True) + return pages From 5cfeaa437d4dc71cca6e533344e48cdc6bbc08f9 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 23 Jul 2015 09:39:12 +0100 Subject: [PATCH 22/22] Rejig find_model_detail_view for easier customization of available endpoints. --- wagtail/contrib/wagtailapi/endpoints.py | 9 +++++++++ wagtail/contrib/wagtailapi/renderers.py | 13 ++++++------- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/wagtail/contrib/wagtailapi/endpoints.py b/wagtail/contrib/wagtailapi/endpoints.py index 1f064f5dd0..a2cd7931ba 100644 --- a/wagtail/contrib/wagtailapi/endpoints.py +++ b/wagtail/contrib/wagtailapi/endpoints.py @@ -110,6 +110,15 @@ class BaseAPIEndpoint(GenericViewSet): 'show_details': True } + def get_renderer_context(self): + context = super(BaseAPIEndpoint, self).get_renderer_context() + context['endpoints'] = [ + PagesAPIEndpoint, + ImagesAPIEndpoint, + DocumentsAPIEndpoint + ] + return context + @classmethod def get_urlpatterns(cls): """ diff --git a/wagtail/contrib/wagtailapi/renderers.py b/wagtail/contrib/wagtailapi/renderers.py index 640535de71..d7f9a5e901 100644 --- a/wagtail/contrib/wagtailapi/renderers.py +++ b/wagtail/contrib/wagtailapi/renderers.py @@ -19,10 +19,8 @@ def get_full_url(request, path): return base_url + path -def find_model_detail_view(model): - from .endpoints import PagesAPIEndpoint, ImagesAPIEndpoint, DocumentsAPIEndpoint - - for endpoint in [PagesAPIEndpoint, ImagesAPIEndpoint, DocumentsAPIEndpoint]: +def find_model_detail_view(model, endpoints): + for endpoint in endpoints: if endpoint.has_model(model): return 'wagtailapi_v1:%s:detail' % endpoint.name @@ -33,6 +31,7 @@ class WagtailJSONRenderer(renderers.BaseRenderer): def render(self, data, media_type=None, renderer_context=None): request = renderer_context['request'] + endpoints = renderer_context['endpoints'] class WagtailAPIJSONEncoder(DjangoJSONEncoder): def default(self, o): @@ -43,10 +42,10 @@ class WagtailJSONRenderer(renderers.BaseRenderer): elif isinstance(o, URLPath): return get_full_url(request, o.path) elif isinstance(o, ObjectDetailURL): - view = find_model_detail_view(o.model) + detail_view = find_model_detail_view(o.model, endpoints) - if view: - return get_full_url(request, reverse(view, args=(o.pk, ))) + if detail_view: + return get_full_url(request, reverse(detail_view, args=(o.pk, ))) else: return None elif isinstance(o, StreamValue):