Implement `locale` filter in API

pull/6414/head
Karl Hobley 2020-09-17 16:50:45 +01:00 zatwierdzone przez Karl Hobley
rodzic eee9fbdaa9
commit cd44515c33
3 zmienionych plików z 40 dodań i 4 usunięć

Wyświetl plik

@ -1,9 +1,10 @@
from django.conf import settings from django.conf import settings
from django.db import models from django.db import models
from django.shortcuts import get_object_or_404
from rest_framework.filters import BaseFilterBackend from rest_framework.filters import BaseFilterBackend
from taggit.managers import TaggableManager from taggit.managers import TaggableManager
from wagtail.core.models import Page from wagtail.core.models import Locale, Page
from wagtail.search.backends import get_search_backend from wagtail.search.backends import get_search_backend
from wagtail.search.backends.base import FilterFieldError, OrderByFieldError from wagtail.search.backends.base import FilterFieldError, OrderByFieldError
@ -18,6 +19,10 @@ class FieldsFilter(BaseFilterBackend):
""" """
fields = set(view.get_available_fields(queryset.model, db_fields_only=True)) fields = set(view.get_available_fields(queryset.model, db_fields_only=True))
# Locale is a database field, but we provide a separate filter for it
if 'locale' in fields:
fields.remove('locale')
for field_name, value in request.GET.items(): for field_name, value in request.GET.items():
if field_name in fields: if field_name in fields:
try: try:
@ -180,3 +185,16 @@ class DescendantOfFilter(BaseFilterBackend):
queryset = queryset.descendant_of(parent_page) queryset = queryset.descendant_of(parent_page)
return queryset return queryset
class LocaleFilter(BaseFilterBackend):
"""
Implements the ?locale filter which limits the set of pages to a
particular locale.
"""
def filter_queryset(self, request, queryset, view):
if 'locale' in request.GET:
locale = get_object_or_404(Locale, language_code=request.GET['locale'])
queryset = queryset.filter(locale=locale)
return queryset

Wyświetl plik

@ -8,7 +8,7 @@ from django.test.utils import override_settings
from django.urls import reverse from django.urls import reverse
from wagtail.api.v2 import signal_handlers from wagtail.api.v2 import signal_handlers
from wagtail.core.models import Page, Site from wagtail.core.models import Locale, Page, Site
from wagtail.tests.demosite import models from wagtail.tests.demosite import models
from wagtail.tests.testapp.models import StreamPage from wagtail.tests.testapp.models import StreamPage
@ -145,6 +145,21 @@ class TestPageListing(TestCase):
self.assertEqual(response.status_code, 400) self.assertEqual(response.status_code, 400)
self.assertEqual(content, {'message': "type doesn't exist"}) self.assertEqual(content, {'message': "type doesn't exist"})
# LOCALE FILTER
@override_settings(WAGTAIL_I18N_ENABLED=True)
def test_locale_filter(self):
french = Locale.objects.create(language_code='fr')
homepage = Page.objects.get(depth=2)
french_homepage = homepage.copy_for_translation(french)
french_homepage.get_latest_revision().publish()
response = self.get_response(locale='fr')
content = json.loads(response.content.decode('UTF-8'))
self.assertEqual(len(content['items']), 1)
self.assertEqual(content['items'][0]['id'], french_homepage.id)
# FIELDS # FIELDS
def test_fields_default(self): def test_fields_default(self):

Wyświetl plik

@ -14,7 +14,8 @@ from rest_framework.viewsets import GenericViewSet
from wagtail.api import APIField from wagtail.api import APIField
from wagtail.core.models import Page, Site from wagtail.core.models import Page, Site
from .filters import ChildOfFilter, DescendantOfFilter, FieldsFilter, OrderingFilter, SearchFilter from .filters import (
ChildOfFilter, DescendantOfFilter, FieldsFilter, LocaleFilter, OrderingFilter, SearchFilter)
from .pagination import WagtailPagination from .pagination import WagtailPagination
from .serializers import BaseSerializer, PageSerializer, get_serializer_class from .serializers import BaseSerializer, PageSerializer, get_serializer_class
from .utils import ( from .utils import (
@ -367,12 +368,14 @@ class PagesAPIViewSet(BaseAPIViewSet):
ChildOfFilter, ChildOfFilter,
DescendantOfFilter, DescendantOfFilter,
OrderingFilter, OrderingFilter,
SearchFilter SearchFilter,
LocaleFilter,
] ]
known_query_parameters = BaseAPIViewSet.known_query_parameters.union([ known_query_parameters = BaseAPIViewSet.known_query_parameters.union([
'type', 'type',
'child_of', 'child_of',
'descendant_of', 'descendant_of',
'locale',
]) ])
body_fields = BaseAPIViewSet.body_fields + [ body_fields = BaseAPIViewSet.body_fields + [
'title', 'title',