Refactor get_base_queryset to no longer take models as a parameter

pull/5677/head
Karl Hobley 2019-11-28 10:24:25 +00:00
rodzic 6af0b20152
commit f0de04f0bd
2 zmienionych plików z 12 dodań i 28 usunięć

Wyświetl plik

@ -2,7 +2,6 @@ from collections import OrderedDict
from rest_framework.authentication import SessionAuthentication
from wagtail.api.v2.utils import filter_page_type
from wagtail.api.v2.views import PagesAPIViewSet
from wagtail.core.models import Page
@ -54,25 +53,14 @@ class PagesAdminAPIViewSet(PagesAPIViewSet):
"""
return Page.get_first_root_node()
def get_base_queryset(self, models=None):
def get_base_queryset(self):
"""
Returns a queryset containing all pages that can be seen by this user.
This is used as the base for get_queryset and is also used to find the
parent pages when using the child_of and descendant_of filters as well.
"""
if models is None:
models = [Page]
if len(models) == 1:
queryset = models[0].objects.all()
else:
queryset = Page.objects.all()
# Filter pages by specified models
queryset = filter_page_type(queryset, models)
return queryset
return Page.objects.all()
def get_queryset(self):
queryset = super().get_queryset()

Wyświetl plik

@ -405,26 +405,15 @@ class PagesAPIViewSet(BaseAPIViewSet):
"""
return self.request.site.root_page
def get_base_queryset(self, models=None):
def get_base_queryset(self):
"""
Returns a queryset containing all pages that can be seen by this user.
This is used as the base for get_queryset and is also used to find the
parent pages when using the child_of and descendant_of filters as well.
"""
if models is None:
models = [Page]
if len(models) == 1:
queryset = models[0].objects.all()
else:
queryset = Page.objects.all()
# Filter pages by specified models
queryset = filter_page_type(queryset, models)
# Get live pages that are not in a private section
queryset = queryset.public().live()
queryset = Page.objects.all().public().live()
# Filter by site
if self.request.site:
@ -444,7 +433,14 @@ class PagesAPIViewSet(BaseAPIViewSet):
except (LookupError, ValueError):
raise BadRequestError("type doesn't exist")
return self.get_base_queryset(models)
if not models:
return self.get_base_queryset()
elif len(models) == 1:
return models[0].objects.filter(id__in=self.get_base_queryset().values_list('id', flat=True))
else: # len(models) > 1
return filter_page_type(self.get_base_queryset(), models)
def get_object(self):
base = super().get_object()