diff --git a/wagtail/api/v2/endpoints.py b/wagtail/api/v2/endpoints.py index 490410f1b5..b4d24e2492 100644 --- a/wagtail/api/v2/endpoints.py +++ b/wagtail/api/v2/endpoints.py @@ -15,7 +15,6 @@ from rest_framework.renderers import JSONRenderer, BrowsableAPIRenderer from wagtail.wagtailcore.models import Page from wagtail.wagtailimages.models import get_image_model from wagtail.wagtaildocs.models import Document -from wagtail.wagtailcore.utils import resolve_model_string from .filters import ( FieldsFilter, OrderingFilter, SearchFilter, @@ -23,7 +22,7 @@ from .filters import ( ) from .pagination import WagtailPagination from .serializers import BaseSerializer, PageSerializer, DocumentSerializer, ImageSerializer, get_serializer_class -from .utils import BadRequestError +from .utils import BadRequestError, page_models_from_string, filter_page_type class BaseAPIEndpoint(GenericViewSet): @@ -204,19 +203,24 @@ class PagesAPIEndpoint(BaseAPIEndpoint): request = self.request # Allow pages to be filtered to a specific type - if 'type' not in request.GET: - model = Page + try: + models = page_models_from_string(request.GET.get('type', 'wagtailcore.Page')) + except (LookupError, ValueError): + raise BadRequestError("type doesn't exist") + + if not models: + models = [Page] + + if len(models) == 1: + queryset = models[0].objects.all() else: - model_name = request.GET['type'] - try: - model = resolve_model_string(model_name) - except LookupError: - raise BadRequestError("type doesn't exist") - if not issubclass(model, Page): - raise BadRequestError("type doesn't exist") + queryset = Page.objects.all() + + # Filter pages by specified models + queryset = filter_page_type(queryset, models) # Get live pages that are not in a private section - queryset = model.objects.public().live() + queryset = queryset.public().live() # Filter by site queryset = queryset.descendant_of(request.site.root_page, inclusive=True) diff --git a/wagtail/api/v2/tests/test_pages.py b/wagtail/api/v2/tests/test_pages.py index 83a4f24a70..f9c04b21af 100644 --- a/wagtail/api/v2/tests/test_pages.py +++ b/wagtail/api/v2/tests/test_pages.py @@ -88,6 +88,9 @@ class TestPageListing(TestCase): for page in content['results']: self.assertEqual(page['meta']['type'], 'demosite.BlogEntryPage') + # All fields in specific type available + self.assertEqual(set(page.keys()), {'id', 'meta', 'title', 'related_links', 'date', 'body', 'tags', 'feed_image', 'carousel_items'}) + def test_type_filter_total_count(self): response = self.get_response(type='demosite.BlogEntryPage') content = json.loads(response.content.decode('UTF-8')) @@ -95,6 +98,27 @@ class TestPageListing(TestCase): # Total count must be reduced as this filters the results self.assertEqual(content['total_count'], 3) + def test_type_filter_multiple(self): + response = self.get_response(type='demosite.BlogEntryPage,demosite.EventPage') + content = json.loads(response.content.decode('UTF-8')) + + blog_page_seen = False + event_page_seen = False + + for page in content['results']: + self.assertIn(page['meta']['type'], ['demosite.BlogEntryPage', 'demosite.EventPage']) + + if page['meta']['type'] == 'demosite.BlogEntryPage': + blog_page_seen = True + elif page['meta']['type'] == 'demosite.EventPage': + event_page_seen = True + + # Only generic fields available + self.assertEqual(set(page.keys()), {'id', 'meta', 'title'}) + + self.assertTrue(blog_page_seen, "No blog pages were found in the results") + self.assertTrue(event_page_seen, "No event pages were found in the results") + def test_non_existant_type_gives_error(self): response = self.get_response(type='demosite.IDontExist') content = json.loads(response.content.decode('UTF-8')) diff --git a/wagtail/api/v2/utils.py b/wagtail/api/v2/utils.py index 629b1dc643..3d2fc3d93e 100644 --- a/wagtail/api/v2/utils.py +++ b/wagtail/api/v2/utils.py @@ -2,6 +2,7 @@ from django.conf import settings from django.utils.six.moves.urllib.parse import urlparse from wagtail.wagtailcore.models import Page +from wagtail.wagtailcore.utils import resolve_model_string class BadRequestError(Exception): @@ -27,3 +28,26 @@ def pages_for_site(site): pages = Page.objects.public().live() pages = pages.descendant_of(site.root_page, inclusive=True) return pages + + +def page_models_from_string(string): + page_models = [] + + for sub_string in string.split(','): + page_model = resolve_model_string(sub_string) + + if not issubclass(page_model, Page): + raise ValueError("Model is not a page") + + page_models.append(page_model) + + return tuple(page_models) + + +def filter_page_type(queryset, page_models): + qs = queryset.none() + + for model in page_models: + qs |= queryset.type(model) + + return qs