diff --git a/wagtail/wagtailsearch/backends/__init__.py b/wagtail/wagtailsearch/backends/__init__.py index ee037497b5..b1b991066b 100644 --- a/wagtail/wagtailsearch/backends/__init__.py +++ b/wagtail/wagtailsearch/backends/__init__.py @@ -4,7 +4,9 @@ from importlib import import_module from django.utils import six +import sys from django.conf import settings +from base import InvalidSearchBackendError # Pinched from django 1.7 source code. @@ -45,8 +47,13 @@ def get_search_backend(backend='default', **kwargs): # Try to get the WAGTAILSEARCH_BACKENDS entry for the given backend name first conf = WAGTAILSEARCH_BACKENDS[backend] except KeyError: - raise InvalidSearchBackendError("Could not find backend '%s': %s" % ( - backend, e)) + try: + # Trying to import the given backend, in case it's a dotted path + import_string(backend) + except ImportError as e: + raise InvalidSearchBackendError("Could not find backend '%s': %s" % ( + backend, e)) + params = kwargs else: # Backend is a conf entry params = conf.copy() diff --git a/wagtail/wagtailsearch/backends/db.py b/wagtail/wagtailsearch/backends/db.py index ad16b082bf..419d8a99be 100644 --- a/wagtail/wagtailsearch/backends/db.py +++ b/wagtail/wagtailsearch/backends/db.py @@ -47,6 +47,13 @@ class DBSearch(BaseSearch): for term in terms: term_query = None for field_name in fields: + # Check if the field exists (this will filter out indexed callables) + try: + model._meta.get_field_by_name(field_name) + except: + continue + + # Filter on this field field_filter = {'%s__icontains' % field_name: term} if term_query is None: term_query = models.Q(**field_filter) diff --git a/wagtail/wagtailsearch/backends/elasticsearch.py b/wagtail/wagtailsearch/backends/elasticsearch.py index 56fa2b04dd..00c1bda0e9 100644 --- a/wagtail/wagtailsearch/backends/elasticsearch.py +++ b/wagtail/wagtailsearch/backends/elasticsearch.py @@ -178,9 +178,11 @@ class ElasticSearch(BaseSearch): type_set[obj_type].append(obj.indexed_build_document()) # Loop through each type and bulk add them + results = [] for type_name, type_objects in type_set.items(): - print type_name, len(type_objects) + results.append((type_name, len(type_objects))) self.es.bulk_index(self.es_index, type_name, type_objects) + return results def delete(self, obj): # Object must be a decendant of Indexed and be a django model diff --git a/wagtail/wagtailsearch/management/commands/search_garbage_collect.py b/wagtail/wagtailsearch/management/commands/search_garbage_collect.py index ae37534b4b..08d8d8c1b5 100644 --- a/wagtail/wagtailsearch/management/commands/search_garbage_collect.py +++ b/wagtail/wagtailsearch/management/commands/search_garbage_collect.py @@ -6,11 +6,11 @@ from wagtail.wagtailsearch import models class Command(NoArgsCommand): def handle_noargs(self, **options): # Clean daily hits - print "Cleaning daily hits records... ", + self.stdout.write("Cleaning daily hits records... ") models.QueryDailyHits.garbage_collect() - print "Done" + self.stdout.write("Done") # Clean queries - print "Cleaning query records... ", + self.stdout.write("Cleaning query records... ") models.Query.garbage_collect() - print "Done" + self.stdout.write("Done") diff --git a/wagtail/wagtailsearch/management/commands/update_index.py b/wagtail/wagtailsearch/management/commands/update_index.py index 7e901bf813..522e206236 100644 --- a/wagtail/wagtailsearch/management/commands/update_index.py +++ b/wagtail/wagtailsearch/management/commands/update_index.py @@ -1,13 +1,13 @@ -from django.core.management.base import NoArgsCommand +from django.core.management.base import BaseCommand from django.db import models from wagtail.wagtailsearch import Indexed, get_search_backend -class Command(NoArgsCommand): - def handle_noargs(self, **options): +class Command(BaseCommand): + def handle(self, backend='default', **options): # Print info - print "Getting object list" + self.stdout.write("Getting object list") # Get list of indexed models indexed_models = [model for model in models.get_models() if issubclass(model, Indexed)] @@ -46,21 +46,24 @@ class Command(NoArgsCommand): object_set[key] = obj # Search backend - s = get_search_backend() + s = get_search_backend(backend=backend) # Reset the index - print "Reseting index" + self.stdout.write("Reseting index") s.reset_index() # Add types - print "Adding types" + self.stdout.write("Adding types") for model in indexed_models: s.add_type(model) # Add objects to index - print "Adding objects" - s.add_bulk(object_set.values()) + self.stdout.write("Adding objects") + results = s.add_bulk(object_set.values()) + if results: + for result in results: + self.stdout.write(result[0] + ' ' + str(result[1])) # Refresh index - print "Refreshing index" + self.stdout.write("Refreshing index") s.refresh_index() diff --git a/wagtail/wagtailsearch/models.py b/wagtail/wagtailsearch/models.py index 79aa38f511..8527cfe579 100644 --- a/wagtail/wagtailsearch/models.py +++ b/wagtail/wagtailsearch/models.py @@ -16,11 +16,16 @@ class Query(models.Model): super(Query, self).save(*args, **kwargs) - def add_hit(self): - daily_hits, created = QueryDailyHits.objects.get_or_create(query=self, date=timezone.now().date()) + def add_hit(self, date=None): + if date is None: + date = timezone.now().date() + daily_hits, created = QueryDailyHits.objects.get_or_create(query=self, date=date) daily_hits.hits = models.F('hits') + 1 daily_hits.save() + def __unicode__(self): + return self.query_string + @property def hits(self): return self.daily_hits.aggregate(models.Sum('hits'))['hits__sum'] @@ -38,6 +43,7 @@ class Query(models.Model): @classmethod def get_most_popular(cls, date_since=None): + # TODO: Implement date_since return cls.objects.filter(daily_hits__isnull=False).annotate(_hits=models.Sum('daily_hits__hits')).distinct().order_by('-_hits') @staticmethod @@ -49,7 +55,7 @@ class Query(models.Model): query_string = ''.join([c for c in query_string if c not in string.punctuation]) # Remove double spaces - ' '.join(query_string.split()) + query_string = ' '.join(query_string.split()) return query_string @@ -90,10 +96,18 @@ class SearchTest(models.Model, Indexed): title = models.CharField(max_length=255) content = models.TextField() - indexed_fields = ("title", "content") + indexed_fields = ("title", "content", "callable_indexed_field") title_search = Searcher(["title"]) + def object_indexed(self): + if self.title == "Don't index me!": + return False + return True + + def callable_indexed_field(self): + return "Callable" + class SearchTestChild(SearchTest): extra_content = models.TextField() diff --git a/wagtail/wagtailsearch/searcher.py b/wagtail/wagtailsearch/searcher.py index 24510c257b..0861390fcd 100644 --- a/wagtail/wagtailsearch/searcher.py +++ b/wagtail/wagtailsearch/searcher.py @@ -8,7 +8,17 @@ class Searcher(object): def __get__(self, instance, cls): def dosearch(query_string, **kwargs): + # Get backend + if 'backend' in kwargs: + backend = kwargs['backend'] + del kwargs['backend'] + else: + backend = 'default' + + # Build search kwargs search_kwargs = dict(model=cls, fields=self.fields, filters=self.filters) search_kwargs.update(kwargs) - return get_search_backend().search(query_string, **search_kwargs) + + # Run search + return get_search_backend(backend=backend).search(query_string, **search_kwargs) return dosearch diff --git a/wagtail/wagtailsearch/tests.py b/wagtail/wagtailsearch/tests.py index ca6e71994a..6849d3907f 100644 --- a/wagtail/wagtailsearch/tests.py +++ b/wagtail/wagtailsearch/tests.py @@ -1,13 +1,56 @@ from django.test import TestCase +from django.utils import timezone +from django.core import management +from django.conf import settings -import models -from wagtail.wagtailsearch.backends import get_search_backend() +import datetime +import unittest +from StringIO import StringIO + +from wagtail.wagtailcore import models as core_models +from wagtail.wagtailsearch import models +from wagtail.wagtailsearch.backends import get_search_backend + +from wagtail.wagtailsearch.backends.base import InvalidSearchBackendError +from wagtail.wagtailsearch.backends.db import DBSearch +from wagtail.wagtailsearch.backends.elasticsearch import ElasticSearch + + +def find_backend(cls): + if not hasattr(settings, 'WAGTAILSEARCH_BACKENDS') and cls == DBSearch: + return 'default' + + for backend in settings.WAGTAILSEARCH_BACKENDS.keys(): + if isinstance(get_search_backend(backend), cls): + return backend class TestSearch(TestCase): - def test_search(self): - # Create search backend and reset the index - s = get_search_backend() + def __init__(self, *args, **kwargs): + super(TestSearch, self).__init__(*args, **kwargs) + + self.backends_tested = [] + + def test_backend_loader(self): + # Test DB backend import + db = get_search_backend(backend='wagtail.wagtailsearch.backends.db.DBSearch') + self.assertIsInstance(db, DBSearch) + + # Test Elastic search backend import + elasticsearch = get_search_backend(backend='wagtail.wagtailsearch.backends.elasticsearch.ElasticSearch') + self.assertIsInstance(elasticsearch, ElasticSearch) + + # Test loading a non existant backend + self.assertRaises(InvalidSearchBackendError, get_search_backend, backend='wagtail.wagtailsearch.backends.doesntexist.DoesntExist') + + def test_search(self, backend='default'): + # Don't test the same backend more than once! + if backend in self.backends_tested: + return + self.backends_tested.append(backend) + + # Get search backend and reset the index + s = get_search_backend(backend=backend) s.reset_index() # Create a couple of objects and add them to the index @@ -33,12 +76,29 @@ class TestSearch(TestCase): results = s.search("Hello", models.SearchTest) self.assertEqual(len(results), 3) - # Ordinary search on "World" + # Retrieve single result + self.assertIsInstance(results[0], models.SearchTest) + + # Retrieve results through iteration + iterations = 0 + for result in results: + self.assertIsInstance(result, models.SearchTest) + iterations += 1 + self.assertEqual(iterations, 3) + + # Retrieve results through slice + iterations = 0 + for result in results[:]: + self.assertIsInstance(result, models.SearchTest) + iterations += 1 + self.assertEqual(iterations, 3) + + # Ordinary search on "World" results = s.search("World", models.SearchTest) self.assertEqual(len(results), 1) # Searcher search - results = models.SearchTest.title_search("Hello") + results = models.SearchTest.title_search("Hello", backend=backend) self.assertEqual(len(results), 3) # Ordinary search on child @@ -46,5 +106,178 @@ class TestSearch(TestCase): self.assertEqual(len(results), 1) # Searcher search on child - results = models.SearchTestChild.title_search("Hello") + results = models.SearchTestChild.title_search("Hello", backend=backend) self.assertEqual(len(results), 1) + + # Reset the index, this should clear out the index (but doesn't have to!) + s.reset_index() + + # Run update_index command + management.call_command('update_index', backend, interactive=False, stdout=StringIO()) + + # Should have results again now + results = s.search("Hello", models.SearchTest) + self.assertEqual(len(results), 3) + + def test_db_backend(self): + self.test_search(backend='wagtail.wagtailsearch.backends.db.DBSearch') + + def test_elastic_search_backend(self): + backend = find_backend(ElasticSearch) + + if backend is not None: + self.test_search(backend) + else: + print "WARNING: Cannot find an ElasticSearch search backend in configuration. Not testing." + + def test_query_hit_counter(self): + # Add 10 hits to hello query + for i in range(10): + models.Query.get("Hello").add_hit() + + # Check that each hit was registered + self.assertEqual(models.Query.get("Hello").hits, 10) + + def test_query_string_normalisation(self): + # Get a query + query = models.Query.get("Hello World!") + + # Check queries that should be the same + self.assertEqual(query, models.Query.get("Hello World")) + self.assertEqual(query, models.Query.get("Hello World!!")) + self.assertEqual(query, models.Query.get("hello world")) + self.assertEqual(query, models.Query.get("Hello' world")) + + # Check queries that should be different + self.assertNotEqual(query, models.Query.get("HelloWorld")) + self.assertNotEqual(query, models.Query.get("Hello orld!!")) + self.assertNotEqual(query, models.Query.get("Hello")) + + def test_query_popularity(self): + # Add 3 hits to unpopular query + for i in range(3): + models.Query.get("unpopular query").add_hit() + + # Add 10 hits to popular query + for i in range(10): + models.Query.get("popular query").add_hit() + + # Get most popular queries + popular_queries = models.Query.get_most_popular() + + # Check list + self.assertEqual(popular_queries.count(), 2) + self.assertEqual(popular_queries[0], models.Query.get("popular query")) + self.assertEqual(popular_queries[1], models.Query.get("unpopular query")) + + # Add 5 hits to little popular query + for i in range(5): + models.Query.get("little popular query").add_hit() + + # Check list again, little popular query should be in the middle + self.assertEqual(popular_queries.count(), 3) + self.assertEqual(popular_queries[0], models.Query.get("popular query")) + self.assertEqual(popular_queries[1], models.Query.get("little popular query")) + self.assertEqual(popular_queries[2], models.Query.get("unpopular query")) + + # Unpopular query goes viral! + for i in range(20): + models.Query.get("unpopular query").add_hit() + + # Unpopular query should be most popular now + self.assertEqual(popular_queries.count(), 3) + self.assertEqual(popular_queries[0], models.Query.get("unpopular query")) + self.assertEqual(popular_queries[1], models.Query.get("popular query")) + self.assertEqual(popular_queries[2], models.Query.get("little popular query")) + + @unittest.expectedFailure # Time based popularity isn't implemented yet + def test_query_popularity_over_time(self): + today = timezone.now().date() + two_days_ago = today - datetime.timedelta(days=2) + a_week_ago = today - datetime.timedelta(days=7) + a_month_ago = today - datetime.timedelta(days=30) + + # Add 10 hits to a query that was very popular query a month ago + for i in range(10): + models.Query.get("old popular query").add_hit(date=a_month_ago) + + # Add 5 hits to a query that is was popular 2 days ago + for i in range(5): + models.Query.get("new popular query").add_hit(date=two_days_ago) + + # Get most popular queries + popular_queries = models.Query.get_most_popular() + + # Old popular query should be most popular + self.assertEqual(popular_queries.count(), 2) + self.assertEqual(popular_queries[0], models.Query.get("old popular query")) + self.assertEqual(popular_queries[1], models.Query.get("new popular query")) + + # Get most popular queries for past week + past_week_popular_queries = models.Query.get_most_popular(date_since=a_week_ago) + + # Only new popular query should be in this list + self.assertEqual(past_week_popular_queries.count(), 1) + self.assertEqual(past_week_popular_queries[0], models.Query.get("new popular query")) + + # Old popular query gets a couple more hits! + for i in range(2): + models.Query.get("old popular query").add_hit() + + # Old popular query should now be in the most popular queries + self.assertEqual(past_week_popular_queries.count(), 2) + self.assertEqual(past_week_popular_queries[0], models.Query.get("new popular query")) + self.assertEqual(past_week_popular_queries[1], models.Query.get("old popular query")) + + def test_editors_picks(self): + # Get root page + root = core_models.Page.objects.first() + + # Create an editors pick to the root page + models.EditorsPick.objects.create( + query=models.Query.get("root page"), + page=root, + sort_order=0, + description="First editors pick", + ) + + # Get editors pick + self.assertEqual(models.Query.get("root page").editors_picks.count(), 1) + self.assertEqual(models.Query.get("root page").editors_picks.first().page, root) + + # Create a couple more editors picks to test the ordering + models.EditorsPick.objects.create( + query=models.Query.get("root page"), + page=root, + sort_order=2, + description="Last editors pick", + ) + models.EditorsPick.objects.create( + query=models.Query.get("root page"), + page=root, + sort_order=1, + description="Middle editors pick", + ) + + # Check + self.assertEqual(models.Query.get("root page").editors_picks.count(), 3) + self.assertEqual(models.Query.get("root page").editors_picks.first().description, "First editors pick") + self.assertEqual(models.Query.get("root page").editors_picks.last().description, "Last editors pick") + + # Add editors pick with different terms + models.EditorsPick.objects.create( + query=models.Query.get("root page 2"), + page=root, + sort_order=0, + description="Other terms", + ) + + # Check + self.assertEqual(models.Query.get("root page 2").editors_picks.count(), 1) + self.assertEqual(models.Query.get("root page").editors_picks.count(), 3) + + def test_garbage_collect(self): + pass + + def test_suggestions(self): + pass