diff --git a/wagtail/tests/models.py b/wagtail/tests/models.py index e2a695f034..41a0068a1c 100644 --- a/wagtail/tests/models.py +++ b/wagtail/tests/models.py @@ -429,14 +429,24 @@ class SearchTest(models.Model, index.Indexed): def get_indexed_objects(cls): indexed_objects = super(SearchTest, cls).get_indexed_objects() - # Exclude SearchTests that have a SearchTestChild to prevent duplicates + # Exclude SearchTests that have a SearchTestChild to stop update_index creating duplicates if cls is SearchTest: indexed_objects = indexed_objects.exclude( id__in=SearchTestChild.objects.all().values_list('searchtest_ptr_id', flat=True) ) + # Exclude SearchTests that have the title "Don't index me!" + indexed_objects = indexed_objects.exclude(title="Don't index me!") + return indexed_objects + def get_indexed_instance(self): + # Check if there is a SearchTestChild that descends from this + child = SearchTestChild.objects.filter(searchtest_ptr_id=self.id).first() + + # Return the child if there is one, otherwise return self + return child or self + class SearchTestChild(SearchTest): subtitle = models.CharField(max_length=255, null=True, blank=True) extra_content = models.TextField() diff --git a/wagtail/wagtailcore/models.py b/wagtail/wagtailcore/models.py index c789d872ae..cf759b05a5 100644 --- a/wagtail/wagtailcore/models.py +++ b/wagtail/wagtailcore/models.py @@ -564,6 +564,9 @@ class Page(six.with_metaclass(PageBase, MP_Node, ClusterableModel, index.Indexed content_type = ContentType.objects.get_for_model(cls) return super(Page, cls).get_indexed_objects().filter(content_type=content_type) + def get_indexed_instance(self): + return self.specific + @classmethod def search(cls, query_string, show_unpublished=False, search_title_only=False, extra_filters={}, prefetch_related=[], path=None): # Filters diff --git a/wagtail/wagtailsearch/index.py b/wagtail/wagtailsearch/index.py index 82790c5804..70a4ef7644 100644 --- a/wagtail/wagtailsearch/index.py +++ b/wagtail/wagtailsearch/index.py @@ -51,6 +51,13 @@ class Indexed(object): def get_indexed_objects(cls): return cls.objects.all() + def get_indexed_instance(self): + """ + If the indexed model uses multi table inheritance, override this method + to return the instance in its most specific class so it reindexes properly. + """ + return self + search_fields = () diff --git a/wagtail/wagtailsearch/signal_handlers.py b/wagtail/wagtailsearch/signal_handlers.py index a828c64c26..e9ce54587d 100644 --- a/wagtail/wagtailsearch/signal_handlers.py +++ b/wagtail/wagtailsearch/signal_handlers.py @@ -5,21 +5,30 @@ from wagtail.wagtailsearch.index import Indexed from wagtail.wagtailsearch.backends import get_search_backends -def post_save_signal_handler(instance, **kwargs): - if not type(instance).get_indexed_objects().filter(id=instance.id).exists(): +def get_indexed_instance(instance): + indexed_instance = instance.get_indexed_instance() + + # Make sure that the instance is in it's classes indexed objects + if not type(indexed_instance).get_indexed_objects().filter(id=indexed_instance.id).exists(): return + return indexed_instance - for backend in get_search_backends(): - backend.add(instance) + +def post_save_signal_handler(instance, **kwargs): + indexed_instance = get_indexed_instance(instance) + + if indexed_instance: + for backend in get_search_backends(): + backend.add(indexed_instance) def post_delete_signal_handler(instance, **kwargs): - if not type(instance).get_indexed_objects().filter(id=instance.id).exists(): - return + indexed_instance = get_indexed_instance(instance) - for backend in get_search_backends(): - backend.delete(instance) + if indexed_instance: + for backend in get_search_backends(): + backend.delete(indexed_instance) def register_signal_handlers(): diff --git a/wagtail/wagtailsearch/tests/test_signal_handlers.py b/wagtail/wagtailsearch/tests/test_signal_handlers.py new file mode 100644 index 0000000000..ffdd143cf9 --- /dev/null +++ b/wagtail/wagtailsearch/tests/test_signal_handlers.py @@ -0,0 +1,40 @@ +from django.test import TestCase + +from wagtail.wagtailsearch import signal_handlers +from wagtail.tests import models + + +class TestGetIndexedInstance(TestCase): + def test_gets_instance(self): + obj = models.SearchTest( + title="Hello", + live=True, + ) + obj.save() + + # Should just return the object + indexed_instance = signal_handlers.get_indexed_instance(obj) + self.assertEqual(indexed_instance, obj) + + def test_gets_specific_class(self): + obj = models.SearchTestChild( + title="Hello", + live=True, + ) + obj.save() + + # Running the command with the parent class should find the specific class again + indexed_instance = signal_handlers.get_indexed_instance(obj.searchtest_ptr) + self.assertEqual(indexed_instance, obj) + + def test_blocks_not_in_indexed_objects(self): + obj = models.SearchTestChild( + title="Don't index me!", + live=True, + ) + obj.save() + + # We've told it not to index anything with the title "Don't index me" + # get_indexed_instance should return None + indexed_instance = signal_handlers.get_indexed_instance(obj.searchtest_ptr) + self.assertEqual(indexed_instance, None)