Builds a single tsquery to handle complex negations.

pull/4689/head
Bertrand Bordage 2018-07-14 14:54:23 +02:00
rodzic f0d25f7443
commit 1130209823
2 zmienionych plików z 68 dodań i 37 usunięć

Wyświetl plik

@@ -1,6 +1,5 @@
from collections import OrderedDict
from django.contrib.postgres.search import SearchQuery as PostgresSearchQuery
from django.contrib.postgres.search import SearchRank, SearchVector
from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections, transaction
from django.db.models import Count, F, Manager, Q, TextField, Value
@@ -12,9 +11,9 @@ from wagtail.search.backends.base import (
BaseSearchBackend, BaseSearchQueryCompiler, BaseSearchResults, FilterFieldError)
from wagtail.search.index import RelatedFields, SearchField, get_indexed_models
from wagtail.search.query import And, Boost, MatchAll, Not, Or, PlainText
from wagtail.search.utils import ADD, AND, MUL, OR
from wagtail.search.utils import ADD, MUL, OR
from .models import SearchAutocomplete as PostgresSearchAutocomplete
from .models import RawSearchQuery as PostgresRawSearchQuery
from .models import IndexEntry
from .utils import (
get_content_type_pk, get_descendants_content_types_pks, get_postgresql_connections,
@@ -189,11 +188,13 @@ class Index:
class PostgresSearchQueryCompiler(BaseSearchQueryCompiler):
# Class-level constants of the PostgreSQL query compiler.
# NOTE: this is a diff hunk whose +/- markers were lost in extraction, so
# pre-change and post-change lines are interleaved below.
DEFAULT_OPERATOR = 'and'
# NOTE(review): the `OPERATORS` mapping onto the AND/OR helpers looks like the
# removed side of the diff; the lines of this commit that use the TSQUERY_*
# names never read it — confirm against the real file.
OPERATORS = {
'and': AND,
'or': OR,
# Infix operators of the tsquery text syntax, padded with spaces so they can
# be used directly as `str.join` separators between operands.
TSQUERY_AND = ' & '
TSQUERY_OR = ' | '
# Maps the operator name carried by a plain-text query (`query.operator`) to
# the separator used to join that query's words.
TSQUERY_OPERATORS = {
'and': TSQUERY_AND,
'or': TSQUERY_OR,
}
# NOTE(review): `query_class` is only read by removed-looking lines in this
# hunk; presumably it is replaced by TSQUERY_WORD_FORMAT — confirm.
query_class = PostgresSearchQuery
# Format fragment emitted once per word of a plain-text query; its `%s` is the
# placeholder for the corresponding entry in the params list.
TSQUERY_WORD_FORMAT = "'%s'"
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -226,50 +227,69 @@ class PostgresSearchQueryCompiler(BaseSearchQueryCompiler):
and field.field_name == field_lookup:
return self.get_search_field(sub_field_name, field.fields)
# NOTE: diff hunk with lost +/- markers and lost indentation. Lines that use
# `build_database_query`, `self.OPERATORS`, `self.query_class`, `AND(...)` or
# `OR(...)` appear to be the removed implementation; lines that use
# `prepare_word`, `build_tsquery_content` or `TSQUERY_*` appear to be the
# added one. The final `raise` is shared by both versions.
def build_database_query(self, query, config=None):
if isinstance(query, PlainText):
operator = self.OPERATORS[query.operator]
# prepare_word: hook applied to every word before it becomes a tsquery
# parameter; here it passes the word through `unidecode`.
def prepare_word(self, word):
return unidecode(word)
return operator([
self.query_class(unidecode(term), config=config)
for term in query.query_string.split()
])
# build_tsquery_content: recursively converts a query tree into a
# `(format, params)` pair, where `format` is tsquery text containing one
# TSQUERY_WORD_FORMAT placeholder per word and `params` lists the prepared
# words in the same order.
# `group=True` asks for the result to be parenthesised when it consists of
# more than one operand, so it can be embedded in a larger expression.
# Raises NotImplementedError for query types not handled below.
def build_tsquery_content(self, query, group=False):
if isinstance(query, PlainText):
query_formats = []
query_params = []
# One placeholder and one prepared word per whitespace-separated token.
for word in query.query_string.split():
query_formats.append(self.TSQUERY_WORD_FORMAT)
query_params.append(self.prepare_word(word))
operator = self.TSQUERY_OPERATORS[query.operator]
query_format = operator.join(query_formats)
# A single word needs no parentheses even when grouping is requested.
if group and len(query_formats) > 1:
query_format = '(%s)' % query_format
return query_format, query_params
# Boost does not affect matching, only the subquery is rendered here.
# NOTE(review): `group` is not forwarded, so a multi-word subquery wrapped in
# Boost loses its parentheses when nested inside And/Or/Not — confirm whether
# that is intended.
if isinstance(query, Boost):
return self.build_database_query(query.subquery, config=config)
return self.build_tsquery_content(query.subquery)
if isinstance(query, Not):
return ~self.build_database_query(query.subquery, config=config)
if isinstance(query, And):
return AND(self.build_database_query(subquery, config=config)
for subquery in query.subqueries)
if isinstance(query, Or):
return OR(self.build_database_query(subquery, config=config)
for subquery in query.subqueries)
# Negation: render the subquery as a group and prefix it with `!`.
query_format, query_params = \
self.build_tsquery_content(query.subquery, group=True)
return '!' + query_format, query_params
# And/Or: render every subquery as a group and join them with the matching
# tsquery operator, concatenating the params in subquery order.
# NOTE(review): this branch never reads `group`, so an And/Or nested inside
# another And/Or/Not is returned without parentheses; with tsquery precedence
# (`!` binds tighter than `&`, which binds tighter than `|`) that looks like
# it can change the meaning of the query — verify.
if isinstance(query, (And, Or)):
query_formats = []
query_params = []
for subquery in query.subqueries:
subquery_format, subquery_params = \
self.build_tsquery_content(subquery, group=True)
query_formats.append(subquery_format)
query_params.extend(subquery_params)
operator = (self.TSQUERY_AND if isinstance(query, And)
else self.TSQUERY_OR)
return operator.join(query_formats), query_params
# Any other query type is rejected explicitly.
raise NotImplementedError(
'`%s` is not supported by the PostgreSQL search backend.'
% query.__class__.__name__)
def build_database_rank(self, vector, query, config=None, boost=1.0):
# build_tsquery: renders the whole query tree once via
# `build_tsquery_content` and wraps the resulting format string and word list
# in a single PostgresRawSearchQuery expression, passing `config` through.
def build_tsquery(self, query, config=None):
# Top-level call: `group` keeps its default of False.
query_format, query_params = self.build_tsquery_content(query)
return PostgresRawSearchQuery(query_format, query_params,
config=config)
def build_tsrank(self, vector, query, config=None, boost=1.0):
if isinstance(query, (PlainText, Not)):
rank_expression = SearchRank(
vector,
self.build_database_query(query, config=config),
self.build_tsquery(query, config=config),
weights=self.sql_weights)
if boost != 1.0:
rank_expression *= boost
return rank_expression
if isinstance(query, Boost):
boost *= query.boost
return self.build_database_rank(vector, query.subquery,
config=config, boost=boost)
return self.build_tsrank(vector, query.subquery,
config=config, boost=boost)
if isinstance(query, And):
return MUL(
1 + self.build_database_rank(vector, subquery,
config=config, boost=boost)
1 + self.build_tsrank(vector, subquery,
config=config, boost=boost)
for subquery in query.subqueries) - 1
if isinstance(query, Or):
return ADD(
self.build_database_rank(vector, subquery,
config=config, boost=boost)
self.build_tsrank(vector, subquery,
config=config, boost=boost)
for subquery in query.subqueries) / (len(query.subqueries) or 1)
raise NotImplementedError(
'`%s` is not supported by the PostgreSQL search backend.'
@@ -294,9 +314,9 @@ class PostgresSearchQueryCompiler(BaseSearchQueryCompiler):
if isinstance(self.query, MatchAll):
return self.queryset[start:stop]
search_query = self.build_database_query(self.query, config=config)
search_query = self.build_tsquery(self.query, config=config)
vector = self.get_search_vector(search_query)
rank_expression = self.build_database_rank(vector, self.query, config=config)
rank_expression = self.build_tsrank(vector, self.query, config=config)
queryset = self.queryset.annotate(
_vector_=vector).filter(_vector_=search_query)
if self.order_by_relevance:
@@ -328,7 +348,7 @@ class PostgresSearchQueryCompiler(BaseSearchQueryCompiler):
class PostgresAutocompleteQueryCompiler(PostgresSearchQueryCompiler):
query_class = PostgresSearchAutocomplete
TSQUERY_WORD_FORMAT = "'%s':*"
def get_index_vector(self, search_query):
return F('index_entries__autocomplete')

Wyświetl plik

@@ -12,19 +12,30 @@ from wagtail.search.index import class_is_indexed
from .utils import get_descendants_content_types_pks
# NOTE: diff hunk with lost +/- markers and lost indentation. Where two
# near-identical lines follow each other (the two `class` headers, the two
# `params = ...` lines, each pair of `template = ...` lines), the first appears
# to be the pre-change line and the second its replacement.
class SearchAutocomplete(SearchQuery):
# SearchQuery variant whose tsquery text is supplied by the caller as a format
# string (`format`) together with the words to substitute into it, and which
# renders itself as a `to_tsquery(...)` call.
class RawSearchQuery(SearchQuery):
# `format` is kept on the instance; every other argument is handed to
# SearchQuery.__init__ unchanged.
def __init__(self, format, *args, **kwargs):
self.format = format
super().__init__(*args, **kwargs)
# Returns the `(sql_template, params)` pair for this expression.
def as_sql(self, compiler, connection):
params = [self.value.replace("'", "''")]
# `self.value` is iterated as a sequence of words; every single quote in a
# word is doubled, presumably so it cannot terminate the quoted lexeme the
# word is substituted into.
params = [v.replace("'", "''") for v in self.value]
if self.config:
# The compiled config expression comes first in the SQL, so its params are
# placed ahead of the word params.
config_sql, config_params = compiler.compile(self.config)
template = "to_tsquery({}::regconfig, ''%s':*')".format(config_sql)
template = "to_tsquery(%s::regconfig, '%s')" % (config_sql, self.format)
params = config_params + params
else:
template = "to_tsquery(''%s':*')"
template = "to_tsquery('%s')" % self.format
# An inverted query negates the whole to_tsquery(...) call with `!!`.
if self.invert:
template = '!!({})'.format(template)
return template, params
# `~query`: builds a new instance of the same class with `invert` flipped and
# the same format, value and config.
def __invert__(self):
extra = {
'invert': not self.invert,
'config': self.config,
}
return type(self)(self.format, self.value, **extra)
class TextIDGenericRelation(GenericRelation):
auto_created = True