From 1b691a905f5baa0b6fceba8e313ca4ec78ad7ecf Mon Sep 17 00:00:00 2001 From: Sage Abdullah Date: Tue, 27 Aug 2024 17:18:54 +0100 Subject: [PATCH] Fix RevisionQuerySet.for_instance() when used with a non-specific instance If the method was used with a base Page instance, it wouldn't return any revisions because we would be filtering on the content_type FK using the current model (the base Page model) instead of the specific model. Filter on base_content_type and make use of RevisionMixin.get_base_content_type() instead, which will resolve to the base Page model (and the correct the most basic non-abstract model for non-Page models with MTI). Use the old logic if the instance's model does not use RevisionMixin for some reason. This logic is similar to WorkflowStateQuerySet.for_instance() and TaskQuerySet.for_instance(). --- wagtail/models/__init__.py | 20 ++++++++++++++------ wagtail/tests/test_revision_model.py | 27 +++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 6 deletions(-) diff --git a/wagtail/models/__init__.py b/wagtail/models/__init__.py index 93a7388717..7d451d58ce 100644 --- a/wagtail/models/__init__.py +++ b/wagtail/models/__init__.py @@ -2792,12 +2792,20 @@ class RevisionQuerySet(models.QuerySet): return self.exclude(self.page_revisions_q()) def for_instance(self, instance): - return self.filter( - content_type=ContentType.objects.get_for_model( - instance, for_concrete_model=False - ), - object_id=str(instance.pk), - ) + try: + # Use RevisionMixin.get_base_content_type() if available + return self.filter( + base_content_type=instance.get_base_content_type(), + object_id=str(instance.pk), + ) + except AttributeError: + # Fallback to ContentType for the model + return self.filter( + content_type=ContentType.objects.get_for_model( + instance, for_concrete_model=False + ), + object_id=str(instance.pk), + ) class RevisionsManager(models.Manager.from_queryset(RevisionQuerySet)): diff --git a/wagtail/tests/test_revision_model.py b/wagtail/tests/test_revision_model.py index 74b964a95f..36185d1b50 100644 --- a/wagtail/tests/test_revision_model.py +++ b/wagtail/tests/test_revision_model.py @@ -70,6 +70,9 @@ class TestRevisableModel(TestCase): self.assertEqual(self.instance.get_base_content_type(), self.content_type) self.assertEqual(self.instance.get_content_type(), self.content_type) + # The for_instance() method should return the revision + self.assertEqual(Revision.objects.for_instance(self.instance).first(), revision) + def test_content_type_with_inheritance(self): instance = RevisableGrandChildModel.objects.create(text="test") instance.text = "test updated" @@ -87,6 +90,18 @@ class TestRevisableModel(TestCase): self.assertEqual(instance.get_base_content_type(), base_content_type) self.assertEqual(instance.get_content_type(), content_type) + # The for_instance() method should return the revision, + # whether we're using the specific instance + self.assertIsInstance(instance, RevisableModel) + self.assertIsInstance(instance, RevisableGrandChildModel) + self.assertEqual(Revision.objects.for_instance(instance).first(), revision) + + # or the base instance + base_instance = RevisableModel.objects.get(pk=instance.pk) + self.assertIsInstance(base_instance, RevisableModel) + self.assertNotIsInstance(base_instance, RevisableGrandChildModel) + self.assertEqual(Revision.objects.for_instance(base_instance).first(), revision) + def test_content_type_for_page_model(self): hello_page = self.create_page() hello_page.content = "Updated world" @@ -104,6 +119,18 @@ class TestRevisableModel(TestCase): self.assertEqual(hello_page.get_base_content_type(), base_content_type) self.assertEqual(hello_page.get_content_type(), content_type) + # The for_instance() method should return the revision, + # whether we're using the specific instance + self.assertIsInstance(hello_page, SimplePage) + self.assertIsInstance(hello_page, Page) + self.assertEqual(Revision.objects.for_instance(hello_page).first(), revision) + + # or the base instance + base_instance = Page.objects.get(pk=hello_page.pk) + self.assertIsInstance(base_instance, Page) + self.assertNotIsInstance(base_instance, SimplePage) + self.assertEqual(Revision.objects.for_instance(base_instance).first(), revision) + def test_as_object(self): self.instance.text = "updated" self.instance.save_revision()