From 1130209823bf351d7049cbb7fcba1f0dd66025f0 Mon Sep 17 00:00:00 2001 From: Bertrand Bordage Date: Sat, 14 Jul 2018 14:54:23 +0200 Subject: [PATCH] Builds a single tsquery to handle complex negations. --- wagtail/contrib/postgres_search/backend.py | 86 +++++++++++++--------- wagtail/contrib/postgres_search/models.py | 19 ++++- 2 files changed, 68 insertions(+), 37 deletions(-) diff --git a/wagtail/contrib/postgres_search/backend.py b/wagtail/contrib/postgres_search/backend.py index e2b1dcc225..e3d05d275c 100644 --- a/wagtail/contrib/postgres_search/backend.py +++ b/wagtail/contrib/postgres_search/backend.py @@ -1,6 +1,5 @@ from collections import OrderedDict -from django.contrib.postgres.search import SearchQuery as PostgresSearchQuery from django.contrib.postgres.search import SearchRank, SearchVector from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections, transaction from django.db.models import Count, F, Manager, Q, TextField, Value @@ -12,9 +11,9 @@ from wagtail.search.backends.base import ( BaseSearchBackend, BaseSearchQueryCompiler, BaseSearchResults, FilterFieldError) from wagtail.search.index import RelatedFields, SearchField, get_indexed_models from wagtail.search.query import And, Boost, MatchAll, Not, Or, PlainText -from wagtail.search.utils import ADD, AND, MUL, OR +from wagtail.search.utils import ADD, MUL, OR -from .models import SearchAutocomplete as PostgresSearchAutocomplete +from .models import RawSearchQuery as PostgresRawSearchQuery from .models import IndexEntry from .utils import ( get_content_type_pk, get_descendants_content_types_pks, get_postgresql_connections, @@ -189,11 +188,13 @@ class Index: class PostgresSearchQueryCompiler(BaseSearchQueryCompiler): DEFAULT_OPERATOR = 'and' - OPERATORS = { - 'and': AND, - 'or': OR, + TSQUERY_AND = ' & ' + TSQUERY_OR = ' | ' + TSQUERY_OPERATORS = { + 'and': TSQUERY_AND, + 'or': TSQUERY_OR, } - query_class = PostgresSearchQuery + TSQUERY_WORD_FORMAT = "'%s'" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -226,50 +227,69 @@ class PostgresSearchQueryCompiler(BaseSearchQueryCompiler): and field.field_name == field_lookup: return self.get_search_field(sub_field_name, field.fields) - def build_database_query(self, query, config=None): - if isinstance(query, PlainText): - operator = self.OPERATORS[query.operator] + def prepare_word(self, word): + return unidecode(word) - return operator([ - self.query_class(unidecode(term), config=config) - for term in query.query_string.split() - ]) + def build_tsquery_content(self, query, group=False): + if isinstance(query, PlainText): + query_formats = [] + query_params = [] + for word in query.query_string.split(): + query_formats.append(self.TSQUERY_WORD_FORMAT) + query_params.append(self.prepare_word(word)) + operator = self.TSQUERY_OPERATORS[query.operator] + query_format = operator.join(query_formats) + if group and len(query_formats) > 1: + query_format = '(%s)' % query_format + return query_format, query_params if isinstance(query, Boost): - return self.build_database_query(query.subquery, config=config) + return self.build_tsquery_content(query.subquery) if isinstance(query, Not): - return ~self.build_database_query(query.subquery, config=config) - if isinstance(query, And): - return AND(self.build_database_query(subquery, config=config) - for subquery in query.subqueries) - if isinstance(query, Or): - return OR(self.build_database_query(subquery, config=config) - for subquery in query.subqueries) + query_format, query_params = \ + self.build_tsquery_content(query.subquery, group=True) + return '!' + query_format, query_params + if isinstance(query, (And, Or)): + query_formats = [] + query_params = [] + for subquery in query.subqueries: + subquery_format, subquery_params = \ + self.build_tsquery_content(subquery, group=True) + query_formats.append(subquery_format) + query_params.extend(subquery_params) + operator = (self.TSQUERY_AND if isinstance(query, And) + else self.TSQUERY_OR) + return operator.join(query_formats), query_params raise NotImplementedError( '`%s` is not supported by the PostgreSQL search backend.' % query.__class__.__name__) - def build_database_rank(self, vector, query, config=None, boost=1.0): + def build_tsquery(self, query, config=None): + query_format, query_params = self.build_tsquery_content(query) + return PostgresRawSearchQuery(query_format, query_params, + config=config) + + def build_tsrank(self, vector, query, config=None, boost=1.0): if isinstance(query, (PlainText, Not)): rank_expression = SearchRank( vector, - self.build_database_query(query, config=config), + self.build_tsquery(query, config=config), weights=self.sql_weights) if boost != 1.0: rank_expression *= boost return rank_expression if isinstance(query, Boost): boost *= query.boost - return self.build_database_rank(vector, query.subquery, - config=config, boost=boost) + return self.build_tsrank(vector, query.subquery, + config=config, boost=boost) if isinstance(query, And): return MUL( - 1 + self.build_database_rank(vector, subquery, - config=config, boost=boost) + 1 + self.build_tsrank(vector, subquery, + config=config, boost=boost) for subquery in query.subqueries) - 1 if isinstance(query, Or): return ADD( - self.build_database_rank(vector, subquery, - config=config, boost=boost) + self.build_tsrank(vector, subquery, + config=config, boost=boost) for subquery in query.subqueries) / (len(query.subqueries) or 1) raise NotImplementedError( '`%s` is not supported by the PostgreSQL search backend.' @@ -294,9 +314,9 @@ class PostgresSearchQueryCompiler(BaseSearchQueryCompiler): if isinstance(self.query, MatchAll): return self.queryset[start:stop] - search_query = self.build_database_query(self.query, config=config) + search_query = self.build_tsquery(self.query, config=config) vector = self.get_search_vector(search_query) - rank_expression = self.build_database_rank(vector, self.query, config=config) + rank_expression = self.build_tsrank(vector, self.query, config=config) queryset = self.queryset.annotate( _vector_=vector).filter(_vector_=search_query) if self.order_by_relevance: @@ -328,7 +348,7 @@ class PostgresSearchQueryCompiler(BaseSearchQueryCompiler): class PostgresAutocompleteQueryCompiler(PostgresSearchQueryCompiler): - query_class = PostgresSearchAutocomplete + TSQUERY_WORD_FORMAT = "'%s':*" def get_index_vector(self, search_query): return F('index_entries__autocomplete') diff --git a/wagtail/contrib/postgres_search/models.py b/wagtail/contrib/postgres_search/models.py index 1477f49459..2b0d44476c 100644 --- a/wagtail/contrib/postgres_search/models.py +++ b/wagtail/contrib/postgres_search/models.py @@ -12,19 +12,30 @@ from wagtail.search.index import class_is_indexed from .utils import get_descendants_content_types_pks -class SearchAutocomplete(SearchQuery): +class RawSearchQuery(SearchQuery): + def __init__(self, format, *args, **kwargs): + self.format = format + super().__init__(*args, **kwargs) + def as_sql(self, compiler, connection): - params = [self.value.replace("'", "''")] + params = [v.replace("'", "''") for v in self.value] if self.config: config_sql, config_params = compiler.compile(self.config) - template = "to_tsquery({}::regconfig, ''%s':*')".format(config_sql) + template = "to_tsquery(%s::regconfig, '%s')" % (config_sql, self.format) params = config_params + params else: - template = "to_tsquery(''%s':*')" + template = "to_tsquery('%s')" % self.format if self.invert: template = '!!({})'.format(template) return template, params + def __invert__(self): + extra = { + 'invert': not self.invert, + 'config': self.config, + } + return type(self)(self.format, self.value, **extra) + class TextIDGenericRelation(GenericRelation): auto_created = True