From 0fcdd08bf00f57c7ddf753103f590d9c1aab6674 Mon Sep 17 00:00:00 2001 From: Sage Abdullah Date: Fri, 6 Oct 2023 19:32:39 +0100 Subject: [PATCH] Fix PageQuerySet.prefetch_workflow_states when used with .specific() --- .../admin/tests/pages/test_explorer_view.py | 42 ++++++++++++- wagtail/models/__init__.py | 14 +++++ wagtail/query.py | 16 ++++- wagtail/tests/test_page_queryset.py | 60 ++++++++++++++++++- 4 files changed, 129 insertions(+), 3 deletions(-) diff --git a/wagtail/admin/tests/pages/test_explorer_view.py b/wagtail/admin/tests/pages/test_explorer_view.py index 949fd87e08..592f44f5da 100644 --- a/wagtail/admin/tests/pages/test_explorer_view.py +++ b/wagtail/admin/tests/pages/test_explorer_view.py @@ -5,7 +5,7 @@ from django.test import TestCase, override_settings from django.urls import reverse from wagtail import hooks -from wagtail.models import GroupPagePermission, Locale, Page +from wagtail.models import GroupPagePermission, Locale, Page, Workflow from wagtail.test.testapp.models import SimplePage, SingleEventPage, StandardIndex from wagtail.test.utils import WagtailTestUtils from wagtail.test.utils.timestamps import local_datetime @@ -769,3 +769,43 @@ class TestLocaleSelector(WagtailTestUtils, TestCase): allow_extra_attrs=True, count=0, ) + + +class TestInWorkflowStatus(WagtailTestUtils, TestCase): + fixtures = ["test.json"] + + @classmethod + def setUpTestData(cls): + cls.event_index = Page.objects.get(url_path="/home/events/") + cls.christmas = Page.objects.get(url_path="/home/events/christmas/").specific + cls.saint_patrick = Page.objects.get( + url_path="/home/events/saint-patrick/" + ).specific + cls.christmas.save_revision() + cls.saint_patrick.save_revision() + cls.url = reverse("wagtailadmin_explore", args=[cls.event_index.pk]) + + def setUp(self): + self.user = self.login() + + def test_in_workflow_status(self): + workflow = Workflow.objects.first() + workflow.start(self.christmas, self.user) + workflow.start(self.saint_patrick, self.user) + + # Warm up cache + self.client.get(self.url) + + with self.assertNumQueries(50): + response = self.client.get(self.url) + + self.assertEqual(response.status_code, 200) + soup = self.get_soup(response.content) + + for page in [self.christmas, self.saint_patrick]: + status = soup.select_one(f'a.w-status[href="{page.url}"]') + self.assertIsNotNone(status) + self.assertEqual( + status.text.strip(), "Current page status: live + in moderation" + ) + self.assertEqual(page.status_string, "live + in moderation") diff --git a/wagtail/models/__init__.py b/wagtail/models/__init__.py index 65009ce95e..be80dea6ce 100644 --- a/wagtail/models/__init__.py +++ b/wagtail/models/__init__.py @@ -1195,6 +1195,20 @@ class Page(AbstractPage, index.Indexed, ClusterableModel, metaclass=PageBase): for_concrete_model=False, ) + # When using a specific queryset, accessing the _workflow_states GenericRelation + # will yield no results. This is because the _workflow_states GenericRelation + # uses the base_content_type as the content_type_field, which is not the same + # as the content type of the specific queryset. To work around this, we define + # a second GenericRelation that uses the specific content_type to be used + # when working with specific querysets. + _specific_workflow_states = GenericRelation( + "wagtailcore.WorkflowState", + content_type_field="content_type", + object_id_field="object_id", + related_query_name="page", + for_concrete_model=False, + ) + # If non-null, this page is an alias of the linked page # This means the page is kept in sync with the live version # of the linked pages and is not editable by users. diff --git a/wagtail/query.py b/wagtail/query.py index 9d4255ee77..0f3b48c56e 100644 --- a/wagtail/query.py +++ b/wagtail/query.py @@ -167,6 +167,16 @@ class SpecificQuerySetMixin: clone._iterable_class = SpecificIterable return clone + @property + def is_specific(self): + """ + Returns True if this queryset is already specific, False otherwise. + """ + return issubclass( + self._iterable_class, + (SpecificIterable, DeferredSpecificIterable), + ) + class PageQuerySet(SearchableQuerySetMixin, SpecificQuerySetMixin, TreeQuerySet): def live_q(self): @@ -455,9 +465,13 @@ class PageQuerySet(SearchableQuerySetMixin, SpecificQuerySetMixin, TreeQuerySet) "current_task_state__task" ) + relation = "_workflow_states" + if self.is_specific: + relation = "_specific_workflow_states" + return self.prefetch_related( Prefetch( - "_workflow_states", + relation, queryset=workflow_states, to_attr="_current_workflow_states", ) diff --git a/wagtail/tests/test_page_queryset.py b/wagtail/tests/test_page_queryset.py index 34ceafae81..6071419f06 100644 --- a/wagtail/tests/test_page_queryset.py +++ b/wagtail/tests/test_page_queryset.py @@ -1,12 +1,13 @@ from io import StringIO from unittest import mock +from django.contrib.auth import get_user_model from django.contrib.contenttypes.models import ContentType from django.core import management from django.db.models import Count, Q from django.test import TestCase, TransactionTestCase -from wagtail.models import Locale, Page, PageViewRestriction, Site +from wagtail.models import Locale, Page, PageViewRestriction, Site, Workflow from wagtail.search.query import MATCH_ALL from wagtail.signals import page_unpublished from wagtail.test.testapp.models import ( @@ -589,6 +590,63 @@ class TestPageQuerySet(TestCase): else: self.assertIn(page, translations) + def test_prefetch_workflow_states(self): + home = Page.objects.get(url_path="/home/") + event_index = Page.objects.get(url_path="/home/events/") + user = get_user_model().objects.first() + workflow = Workflow.objects.first() + + test_pages = [home.specific, event_index.specific] + workflow_states = {} + current_tasks = {} + + for page in test_pages: + page.save_revision() + approved_workflow_state = workflow.start(page, user) + task_state = approved_workflow_state.current_task_state + task_state.task.on_action(task_state, user=None, action_name="approve") + + workflow_state = workflow.start(page, user) + + # Refresh so that the current_task_state.task is not the specific instance + workflow_state.refresh_from_db() + + workflow_states[page.pk] = workflow_state + current_tasks[page.pk] = workflow_state.current_task_state.task + + query = Page.objects.filter(pk__in=(home.pk, event_index.pk)) + queries = [["base", query, 2], ["specific", query.specific(), 4]] + + for case, query, num_queries in queries: + with self.subTest(case=case): + with self.assertNumQueries(num_queries): + queried_pages = { + page.pk: page for page in query.prefetch_workflow_states() + } + + for test_page in test_pages: + page = queried_pages[test_page.pk] + with self.assertNumQueries(0): + self.assertEqual( + page._current_workflow_states, + [workflow_states[page.pk]], + ) + + with self.assertNumQueries(0): + self.assertEqual( + page._current_workflow_states[0].current_task_state.task, + current_tasks[page.pk], + ) + + with self.assertNumQueries(0): + self.assertTrue(page.workflow_in_progress) + + with self.assertNumQueries(0): + self.assertTrue( + page.current_workflow_state, + workflow_states[page.pk], + ) + class TestPageQueryInSite(TestCase): fixtures = ["test.json"]