kopia lustrzana https://github.com/wagtail/wagtail
Merge branch 'feature/search-tests' into feature/search-backends
commit
03e70df24f
|
@ -4,7 +4,9 @@
|
|||
|
||||
from importlib import import_module
|
||||
from django.utils import six
|
||||
import sys
|
||||
from django.conf import settings
|
||||
from base import InvalidSearchBackendError
|
||||
|
||||
|
||||
# Pinched from django 1.7 source code.
|
||||
|
@ -45,8 +47,13 @@ def get_search_backend(backend='default', **kwargs):
|
|||
# Try to get the WAGTAILSEARCH_BACKENDS entry for the given backend name first
|
||||
conf = WAGTAILSEARCH_BACKENDS[backend]
|
||||
except KeyError:
|
||||
raise InvalidSearchBackendError("Could not find backend '%s': %s" % (
|
||||
backend, e))
|
||||
try:
|
||||
# Trying to import the given backend, in case it's a dotted path
|
||||
import_string(backend)
|
||||
except ImportError as e:
|
||||
raise InvalidSearchBackendError("Could not find backend '%s': %s" % (
|
||||
backend, e))
|
||||
params = kwargs
|
||||
else:
|
||||
# Backend is a conf entry
|
||||
params = conf.copy()
|
||||
|
|
|
@ -47,6 +47,13 @@ class DBSearch(BaseSearch):
|
|||
for term in terms:
|
||||
term_query = None
|
||||
for field_name in fields:
|
||||
# Check if the field exists (this will filter out indexed callables)
|
||||
try:
|
||||
model._meta.get_field_by_name(field_name)
|
||||
except:
|
||||
continue
|
||||
|
||||
# Filter on this field
|
||||
field_filter = {'%s__icontains' % field_name: term}
|
||||
if term_query is None:
|
||||
term_query = models.Q(**field_filter)
|
||||
|
|
|
@ -178,9 +178,11 @@ class ElasticSearch(BaseSearch):
|
|||
type_set[obj_type].append(obj.indexed_build_document())
|
||||
|
||||
# Loop through each type and bulk add them
|
||||
results = []
|
||||
for type_name, type_objects in type_set.items():
|
||||
print type_name, len(type_objects)
|
||||
results.append((type_name, len(type_objects)))
|
||||
self.es.bulk_index(self.es_index, type_name, type_objects)
|
||||
return results
|
||||
|
||||
def delete(self, obj):
|
||||
# Object must be a decendant of Indexed and be a django model
|
||||
|
|
|
@ -6,11 +6,11 @@ from wagtail.wagtailsearch import models
|
|||
class Command(NoArgsCommand):
|
||||
def handle_noargs(self, **options):
|
||||
# Clean daily hits
|
||||
print "Cleaning daily hits records... ",
|
||||
self.stdout.write("Cleaning daily hits records... ")
|
||||
models.QueryDailyHits.garbage_collect()
|
||||
print "Done"
|
||||
self.stdout.write("Done")
|
||||
|
||||
# Clean queries
|
||||
print "Cleaning query records... ",
|
||||
self.stdout.write("Cleaning query records... ")
|
||||
models.Query.garbage_collect()
|
||||
print "Done"
|
||||
self.stdout.write("Done")
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
from django.core.management.base import NoArgsCommand
|
||||
from django.core.management.base import BaseCommand
|
||||
from django.db import models
|
||||
|
||||
from wagtail.wagtailsearch import Indexed, get_search_backend
|
||||
|
||||
|
||||
class Command(NoArgsCommand):
|
||||
def handle_noargs(self, **options):
|
||||
class Command(BaseCommand):
|
||||
def handle(self, backend='default', **options):
|
||||
# Print info
|
||||
print "Getting object list"
|
||||
self.stdout.write("Getting object list")
|
||||
|
||||
# Get list of indexed models
|
||||
indexed_models = [model for model in models.get_models() if issubclass(model, Indexed)]
|
||||
|
@ -46,21 +46,24 @@ class Command(NoArgsCommand):
|
|||
object_set[key] = obj
|
||||
|
||||
# Search backend
|
||||
s = get_search_backend()
|
||||
s = get_search_backend(backend=backend)
|
||||
|
||||
# Reset the index
|
||||
print "Reseting index"
|
||||
self.stdout.write("Reseting index")
|
||||
s.reset_index()
|
||||
|
||||
# Add types
|
||||
print "Adding types"
|
||||
self.stdout.write("Adding types")
|
||||
for model in indexed_models:
|
||||
s.add_type(model)
|
||||
|
||||
# Add objects to index
|
||||
print "Adding objects"
|
||||
s.add_bulk(object_set.values())
|
||||
self.stdout.write("Adding objects")
|
||||
results = s.add_bulk(object_set.values())
|
||||
if results:
|
||||
for result in results:
|
||||
self.stdout.write(result[0] + ' ' + str(result[1]))
|
||||
|
||||
# Refresh index
|
||||
print "Refreshing index"
|
||||
self.stdout.write("Refreshing index")
|
||||
s.refresh_index()
|
||||
|
|
|
@ -16,11 +16,16 @@ class Query(models.Model):
|
|||
|
||||
super(Query, self).save(*args, **kwargs)
|
||||
|
||||
def add_hit(self):
|
||||
daily_hits, created = QueryDailyHits.objects.get_or_create(query=self, date=timezone.now().date())
|
||||
def add_hit(self, date=None):
|
||||
if date is None:
|
||||
date = timezone.now().date()
|
||||
daily_hits, created = QueryDailyHits.objects.get_or_create(query=self, date=date)
|
||||
daily_hits.hits = models.F('hits') + 1
|
||||
daily_hits.save()
|
||||
|
||||
def __unicode__(self):
|
||||
return self.query_string
|
||||
|
||||
@property
|
||||
def hits(self):
|
||||
return self.daily_hits.aggregate(models.Sum('hits'))['hits__sum']
|
||||
|
@ -38,6 +43,7 @@ class Query(models.Model):
|
|||
|
||||
@classmethod
|
||||
def get_most_popular(cls, date_since=None):
|
||||
# TODO: Implement date_since
|
||||
return cls.objects.filter(daily_hits__isnull=False).annotate(_hits=models.Sum('daily_hits__hits')).distinct().order_by('-_hits')
|
||||
|
||||
@staticmethod
|
||||
|
@ -49,7 +55,7 @@ class Query(models.Model):
|
|||
query_string = ''.join([c for c in query_string if c not in string.punctuation])
|
||||
|
||||
# Remove double spaces
|
||||
' '.join(query_string.split())
|
||||
query_string = ' '.join(query_string.split())
|
||||
|
||||
return query_string
|
||||
|
||||
|
@ -90,10 +96,18 @@ class SearchTest(models.Model, Indexed):
|
|||
title = models.CharField(max_length=255)
|
||||
content = models.TextField()
|
||||
|
||||
indexed_fields = ("title", "content")
|
||||
indexed_fields = ("title", "content", "callable_indexed_field")
|
||||
|
||||
title_search = Searcher(["title"])
|
||||
|
||||
def object_indexed(self):
|
||||
if self.title == "Don't index me!":
|
||||
return False
|
||||
return True
|
||||
|
||||
def callable_indexed_field(self):
|
||||
return "Callable"
|
||||
|
||||
|
||||
class SearchTestChild(SearchTest):
|
||||
extra_content = models.TextField()
|
||||
|
|
|
@ -8,7 +8,17 @@ class Searcher(object):
|
|||
|
||||
def __get__(self, instance, cls):
|
||||
def dosearch(query_string, **kwargs):
|
||||
# Get backend
|
||||
if 'backend' in kwargs:
|
||||
backend = kwargs['backend']
|
||||
del kwargs['backend']
|
||||
else:
|
||||
backend = 'default'
|
||||
|
||||
# Build search kwargs
|
||||
search_kwargs = dict(model=cls, fields=self.fields, filters=self.filters)
|
||||
search_kwargs.update(kwargs)
|
||||
return get_search_backend().search(query_string, **search_kwargs)
|
||||
|
||||
# Run search
|
||||
return get_search_backend(backend=backend).search(query_string, **search_kwargs)
|
||||
return dosearch
|
||||
|
|
|
@ -1,13 +1,56 @@
|
|||
from django.test import TestCase
|
||||
from django.utils import timezone
|
||||
from django.core import management
|
||||
from django.conf import settings
|
||||
|
||||
import models
|
||||
from wagtail.wagtailsearch.backends import get_search_backend()
|
||||
import datetime
|
||||
import unittest
|
||||
from StringIO import StringIO
|
||||
|
||||
from wagtail.wagtailcore import models as core_models
|
||||
from wagtail.wagtailsearch import models
|
||||
from wagtail.wagtailsearch.backends import get_search_backend
|
||||
|
||||
from wagtail.wagtailsearch.backends.base import InvalidSearchBackendError
|
||||
from wagtail.wagtailsearch.backends.db import DBSearch
|
||||
from wagtail.wagtailsearch.backends.elasticsearch import ElasticSearch
|
||||
|
||||
|
||||
def find_backend(cls):
|
||||
if not hasattr(settings, 'WAGTAILSEARCH_BACKENDS') and cls == DBSearch:
|
||||
return 'default'
|
||||
|
||||
for backend in settings.WAGTAILSEARCH_BACKENDS.keys():
|
||||
if isinstance(get_search_backend(backend), cls):
|
||||
return backend
|
||||
|
||||
|
||||
class TestSearch(TestCase):
|
||||
def test_search(self):
|
||||
# Create search backend and reset the index
|
||||
s = get_search_backend()
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(TestSearch, self).__init__(*args, **kwargs)
|
||||
|
||||
self.backends_tested = []
|
||||
|
||||
def test_backend_loader(self):
|
||||
# Test DB backend import
|
||||
db = get_search_backend(backend='wagtail.wagtailsearch.backends.db.DBSearch')
|
||||
self.assertIsInstance(db, DBSearch)
|
||||
|
||||
# Test Elastic search backend import
|
||||
elasticsearch = get_search_backend(backend='wagtail.wagtailsearch.backends.elasticsearch.ElasticSearch')
|
||||
self.assertIsInstance(elasticsearch, ElasticSearch)
|
||||
|
||||
# Test loading a non existant backend
|
||||
self.assertRaises(InvalidSearchBackendError, get_search_backend, backend='wagtail.wagtailsearch.backends.doesntexist.DoesntExist')
|
||||
|
||||
def test_search(self, backend='default'):
|
||||
# Don't test the same backend more than once!
|
||||
if backend in self.backends_tested:
|
||||
return
|
||||
self.backends_tested.append(backend)
|
||||
|
||||
# Get search backend and reset the index
|
||||
s = get_search_backend(backend=backend)
|
||||
s.reset_index()
|
||||
|
||||
# Create a couple of objects and add them to the index
|
||||
|
@ -33,12 +76,29 @@ class TestSearch(TestCase):
|
|||
results = s.search("Hello", models.SearchTest)
|
||||
self.assertEqual(len(results), 3)
|
||||
|
||||
# Ordinary search on "World"
|
||||
# Retrieve single result
|
||||
self.assertIsInstance(results[0], models.SearchTest)
|
||||
|
||||
# Retrieve results through iteration
|
||||
iterations = 0
|
||||
for result in results:
|
||||
self.assertIsInstance(result, models.SearchTest)
|
||||
iterations += 1
|
||||
self.assertEqual(iterations, 3)
|
||||
|
||||
# Retrieve results through slice
|
||||
iterations = 0
|
||||
for result in results[:]:
|
||||
self.assertIsInstance(result, models.SearchTest)
|
||||
iterations += 1
|
||||
self.assertEqual(iterations, 3)
|
||||
|
||||
# Ordinary search on "World"
|
||||
results = s.search("World", models.SearchTest)
|
||||
self.assertEqual(len(results), 1)
|
||||
|
||||
# Searcher search
|
||||
results = models.SearchTest.title_search("Hello")
|
||||
results = models.SearchTest.title_search("Hello", backend=backend)
|
||||
self.assertEqual(len(results), 3)
|
||||
|
||||
# Ordinary search on child
|
||||
|
@ -46,5 +106,178 @@ class TestSearch(TestCase):
|
|||
self.assertEqual(len(results), 1)
|
||||
|
||||
# Searcher search on child
|
||||
results = models.SearchTestChild.title_search("Hello")
|
||||
results = models.SearchTestChild.title_search("Hello", backend=backend)
|
||||
self.assertEqual(len(results), 1)
|
||||
|
||||
# Reset the index, this should clear out the index (but doesn't have to!)
|
||||
s.reset_index()
|
||||
|
||||
# Run update_index command
|
||||
management.call_command('update_index', backend, interactive=False, stdout=StringIO())
|
||||
|
||||
# Should have results again now
|
||||
results = s.search("Hello", models.SearchTest)
|
||||
self.assertEqual(len(results), 3)
|
||||
|
||||
def test_db_backend(self):
|
||||
self.test_search(backend='wagtail.wagtailsearch.backends.db.DBSearch')
|
||||
|
||||
def test_elastic_search_backend(self):
|
||||
backend = find_backend(ElasticSearch)
|
||||
|
||||
if backend is not None:
|
||||
self.test_search(backend)
|
||||
else:
|
||||
print "WARNING: Cannot find an ElasticSearch search backend in configuration. Not testing."
|
||||
|
||||
def test_query_hit_counter(self):
|
||||
# Add 10 hits to hello query
|
||||
for i in range(10):
|
||||
models.Query.get("Hello").add_hit()
|
||||
|
||||
# Check that each hit was registered
|
||||
self.assertEqual(models.Query.get("Hello").hits, 10)
|
||||
|
||||
def test_query_string_normalisation(self):
|
||||
# Get a query
|
||||
query = models.Query.get("Hello World!")
|
||||
|
||||
# Check queries that should be the same
|
||||
self.assertEqual(query, models.Query.get("Hello World"))
|
||||
self.assertEqual(query, models.Query.get("Hello World!!"))
|
||||
self.assertEqual(query, models.Query.get("hello world"))
|
||||
self.assertEqual(query, models.Query.get("Hello' world"))
|
||||
|
||||
# Check queries that should be different
|
||||
self.assertNotEqual(query, models.Query.get("HelloWorld"))
|
||||
self.assertNotEqual(query, models.Query.get("Hello orld!!"))
|
||||
self.assertNotEqual(query, models.Query.get("Hello"))
|
||||
|
||||
def test_query_popularity(self):
|
||||
# Add 3 hits to unpopular query
|
||||
for i in range(3):
|
||||
models.Query.get("unpopular query").add_hit()
|
||||
|
||||
# Add 10 hits to popular query
|
||||
for i in range(10):
|
||||
models.Query.get("popular query").add_hit()
|
||||
|
||||
# Get most popular queries
|
||||
popular_queries = models.Query.get_most_popular()
|
||||
|
||||
# Check list
|
||||
self.assertEqual(popular_queries.count(), 2)
|
||||
self.assertEqual(popular_queries[0], models.Query.get("popular query"))
|
||||
self.assertEqual(popular_queries[1], models.Query.get("unpopular query"))
|
||||
|
||||
# Add 5 hits to little popular query
|
||||
for i in range(5):
|
||||
models.Query.get("little popular query").add_hit()
|
||||
|
||||
# Check list again, little popular query should be in the middle
|
||||
self.assertEqual(popular_queries.count(), 3)
|
||||
self.assertEqual(popular_queries[0], models.Query.get("popular query"))
|
||||
self.assertEqual(popular_queries[1], models.Query.get("little popular query"))
|
||||
self.assertEqual(popular_queries[2], models.Query.get("unpopular query"))
|
||||
|
||||
# Unpopular query goes viral!
|
||||
for i in range(20):
|
||||
models.Query.get("unpopular query").add_hit()
|
||||
|
||||
# Unpopular query should be most popular now
|
||||
self.assertEqual(popular_queries.count(), 3)
|
||||
self.assertEqual(popular_queries[0], models.Query.get("unpopular query"))
|
||||
self.assertEqual(popular_queries[1], models.Query.get("popular query"))
|
||||
self.assertEqual(popular_queries[2], models.Query.get("little popular query"))
|
||||
|
||||
@unittest.expectedFailure # Time based popularity isn't implemented yet
|
||||
def test_query_popularity_over_time(self):
|
||||
today = timezone.now().date()
|
||||
two_days_ago = today - datetime.timedelta(days=2)
|
||||
a_week_ago = today - datetime.timedelta(days=7)
|
||||
a_month_ago = today - datetime.timedelta(days=30)
|
||||
|
||||
# Add 10 hits to a query that was very popular query a month ago
|
||||
for i in range(10):
|
||||
models.Query.get("old popular query").add_hit(date=a_month_ago)
|
||||
|
||||
# Add 5 hits to a query that is was popular 2 days ago
|
||||
for i in range(5):
|
||||
models.Query.get("new popular query").add_hit(date=two_days_ago)
|
||||
|
||||
# Get most popular queries
|
||||
popular_queries = models.Query.get_most_popular()
|
||||
|
||||
# Old popular query should be most popular
|
||||
self.assertEqual(popular_queries.count(), 2)
|
||||
self.assertEqual(popular_queries[0], models.Query.get("old popular query"))
|
||||
self.assertEqual(popular_queries[1], models.Query.get("new popular query"))
|
||||
|
||||
# Get most popular queries for past week
|
||||
past_week_popular_queries = models.Query.get_most_popular(date_since=a_week_ago)
|
||||
|
||||
# Only new popular query should be in this list
|
||||
self.assertEqual(past_week_popular_queries.count(), 1)
|
||||
self.assertEqual(past_week_popular_queries[0], models.Query.get("new popular query"))
|
||||
|
||||
# Old popular query gets a couple more hits!
|
||||
for i in range(2):
|
||||
models.Query.get("old popular query").add_hit()
|
||||
|
||||
# Old popular query should now be in the most popular queries
|
||||
self.assertEqual(past_week_popular_queries.count(), 2)
|
||||
self.assertEqual(past_week_popular_queries[0], models.Query.get("new popular query"))
|
||||
self.assertEqual(past_week_popular_queries[1], models.Query.get("old popular query"))
|
||||
|
||||
def test_editors_picks(self):
|
||||
# Get root page
|
||||
root = core_models.Page.objects.first()
|
||||
|
||||
# Create an editors pick to the root page
|
||||
models.EditorsPick.objects.create(
|
||||
query=models.Query.get("root page"),
|
||||
page=root,
|
||||
sort_order=0,
|
||||
description="First editors pick",
|
||||
)
|
||||
|
||||
# Get editors pick
|
||||
self.assertEqual(models.Query.get("root page").editors_picks.count(), 1)
|
||||
self.assertEqual(models.Query.get("root page").editors_picks.first().page, root)
|
||||
|
||||
# Create a couple more editors picks to test the ordering
|
||||
models.EditorsPick.objects.create(
|
||||
query=models.Query.get("root page"),
|
||||
page=root,
|
||||
sort_order=2,
|
||||
description="Last editors pick",
|
||||
)
|
||||
models.EditorsPick.objects.create(
|
||||
query=models.Query.get("root page"),
|
||||
page=root,
|
||||
sort_order=1,
|
||||
description="Middle editors pick",
|
||||
)
|
||||
|
||||
# Check
|
||||
self.assertEqual(models.Query.get("root page").editors_picks.count(), 3)
|
||||
self.assertEqual(models.Query.get("root page").editors_picks.first().description, "First editors pick")
|
||||
self.assertEqual(models.Query.get("root page").editors_picks.last().description, "Last editors pick")
|
||||
|
||||
# Add editors pick with different terms
|
||||
models.EditorsPick.objects.create(
|
||||
query=models.Query.get("root page 2"),
|
||||
page=root,
|
||||
sort_order=0,
|
||||
description="Other terms",
|
||||
)
|
||||
|
||||
# Check
|
||||
self.assertEqual(models.Query.get("root page 2").editors_picks.count(), 1)
|
||||
self.assertEqual(models.Query.get("root page").editors_picks.count(), 3)
|
||||
|
||||
def test_garbage_collect(self):
|
||||
pass
|
||||
|
||||
def test_suggestions(self):
|
||||
pass
|
||||
|
|
Ładowanie…
Reference in New Issue