diff --git a/wagtail/api/v2/filters.py b/wagtail/api/v2/filters.py index 5322804dc4..faf1b4aea4 100644 --- a/wagtail/api/v2/filters.py +++ b/wagtail/api/v2/filters.py @@ -1,9 +1,10 @@ from django.conf import settings from django.db import models +from django.shortcuts import get_object_or_404 from rest_framework.filters import BaseFilterBackend from taggit.managers import TaggableManager -from wagtail.core.models import Page +from wagtail.core.models import Locale, Page from wagtail.search.backends import get_search_backend from wagtail.search.backends.base import FilterFieldError, OrderByFieldError @@ -18,6 +19,10 @@ class FieldsFilter(BaseFilterBackend): """ fields = set(view.get_available_fields(queryset.model, db_fields_only=True)) + # Locale is a database field, but we provide a separate filter for it + if 'locale' in fields: + fields.remove('locale') + for field_name, value in request.GET.items(): if field_name in fields: try: @@ -180,3 +185,16 @@ class DescendantOfFilter(BaseFilterBackend): queryset = queryset.descendant_of(parent_page) return queryset + + +class LocaleFilter(BaseFilterBackend): + """ + Implements the ?locale filter which limits the set of pages to a + particular locale. + """ + def filter_queryset(self, request, queryset, view): + if 'locale' in request.GET: + locale = get_object_or_404(Locale, language_code=request.GET['locale']) + queryset = queryset.filter(locale=locale) + + return queryset diff --git a/wagtail/api/v2/tests/test_pages.py b/wagtail/api/v2/tests/test_pages.py index 1c9b900e42..ec938ede98 100644 --- a/wagtail/api/v2/tests/test_pages.py +++ b/wagtail/api/v2/tests/test_pages.py @@ -8,7 +8,7 @@ from django.test.utils import override_settings from django.urls import reverse from wagtail.api.v2 import signal_handlers -from wagtail.core.models import Page, Site +from wagtail.core.models import Locale, Page, Site from wagtail.tests.demosite import models from wagtail.tests.testapp.models import StreamPage @@ -145,6 +145,21 @@ class TestPageListing(TestCase): self.assertEqual(response.status_code, 400) self.assertEqual(content, {'message': "type doesn't exist"}) + # LOCALE FILTER + + @override_settings(WAGTAIL_I18N_ENABLED=True) + def test_locale_filter(self): + french = Locale.objects.create(language_code='fr') + homepage = Page.objects.get(depth=2) + french_homepage = homepage.copy_for_translation(french) + french_homepage.get_latest_revision().publish() + + response = self.get_response(locale='fr') + content = json.loads(response.content.decode('UTF-8')) + + self.assertEqual(len(content['items']), 1) + self.assertEqual(content['items'][0]['id'], french_homepage.id) + # FIELDS def test_fields_default(self): diff --git a/wagtail/api/v2/views.py b/wagtail/api/v2/views.py index f5f830d71d..b877d838d0 100644 --- a/wagtail/api/v2/views.py +++ b/wagtail/api/v2/views.py @@ -14,7 +14,8 @@ from rest_framework.viewsets import GenericViewSet from wagtail.api import APIField from wagtail.core.models import Page, Site -from .filters import ChildOfFilter, DescendantOfFilter, FieldsFilter, OrderingFilter, SearchFilter +from .filters import ( + ChildOfFilter, DescendantOfFilter, FieldsFilter, LocaleFilter, OrderingFilter, SearchFilter) from .pagination import WagtailPagination from .serializers import BaseSerializer, PageSerializer, get_serializer_class from .utils import ( @@ -367,12 +368,14 @@ class PagesAPIViewSet(BaseAPIViewSet): ChildOfFilter, DescendantOfFilter, OrderingFilter, - SearchFilter + SearchFilter, + LocaleFilter, ] known_query_parameters = BaseAPIViewSet.known_query_parameters.union([ 'type', 'child_of', 'descendant_of', + 'locale', ]) body_fields = BaseAPIViewSet.body_fields + [ 'title',