diff --git a/wagtail/search/backends/elasticsearch7.py b/wagtail/search/backends/elasticsearch7.py index af425f5c3a..ebf346a325 100644 --- a/wagtail/search/backends/elasticsearch7.py +++ b/wagtail/search/backends/elasticsearch7.py @@ -4,6 +4,7 @@ from copy import deepcopy from urllib.parse import urlparse from django.db import DEFAULT_DB_ALIAS, models +from django.db.models import Subquery from django.db.models.sql import Query from django.db.models.sql.constants import MULTI, SINGLE from django.utils.crypto import get_random_string @@ -505,9 +506,10 @@ class Elasticsearch7SearchQueryCompiler(BaseSearchQueryCompiler): } } else: - if isinstance(value, Query): + if isinstance(value, (Query, Subquery)): db_alias = self.queryset._db or DEFAULT_DB_ALIAS - value = value.get_compiler(db_alias).execute_sql(result_type=SINGLE) + query = value.query if isinstance(value, Subquery) else value + value = query.get_compiler(db_alias).execute_sql(result_type=SINGLE) # The result is either a tuple with one element or None if value: value = value[0] diff --git a/wagtail/search/tests/test_backends.py b/wagtail/search/tests/test_backends.py index 1af759eb71..ecf7d3ae6a 100644 --- a/wagtail/search/tests/test_backends.py +++ b/wagtail/search/tests/test_backends.py @@ -7,6 +7,7 @@ from unittest import mock from django.conf import settings from django.core import management from django.db import connection +from django.db.models import Subquery from django.test import TestCase from django.test.utils import override_settings from taggit.models import Tag @@ -307,16 +308,22 @@ class BackendTests(WagtailTestUtils): .order_by("novel_id") .values_list("pk", flat=True)[:1] ) + cases = { + "implicit": protagonist, + "explicit": Subquery(protagonist), + } - results = self.backend.search( - MATCH_ALL, - models.Novel.objects.filter(protagonist_id=protagonist), - ) + for case, subquery in cases.items(): + with self.subTest(case=case): + results = self.backend.search( + MATCH_ALL, + models.Novel.objects.filter(protagonist_id=subquery), + ) - self.assertUnsortedListEqual( - [r.title for r in results], - ["The Fellowship of the Ring"], - ) + self.assertUnsortedListEqual( + [r.title for r in results], + ["The Fellowship of the Ring"], + ) def test_filter_lt(self): results = self.backend.search(