diff --git a/wagtail/wagtailsearch/backends/elasticsearch.py b/wagtail/wagtailsearch/backends/elasticsearch.py index e9db7fbf14..29141542a5 100644 --- a/wagtail/wagtailsearch/backends/elasticsearch.py +++ b/wagtail/wagtailsearch/backends/elasticsearch.py @@ -118,6 +118,8 @@ class ElasticSearchResults(object): self.query = query self.start = 0 self.stop = None + self._results_cache = None + self._count_cache = None def _set_limits(self, start=None, stop=None): if stop is not None: @@ -173,7 +175,7 @@ class ElasticSearchResults(object): # Return results in order given by ElasticSearch return [results[str(pk)] for pk in pks if results[str(pk)]] - def count(self): + def _do_count(self): # Get query query = self.query.to_es() @@ -200,6 +202,19 @@ class ElasticSearchResults(object): return max(hit_count, 0) + def results(self): + if self._results_cache is None: + self._results_cache = self._do_search() + return self._results_cache + + def count(self): + if self._count_cache is None: + if self._results_cache is not None: + self._count_cache = len(self._results_cache) + else: + self._count_cache = self._do_count() + return self._count_cache + def __getitem__(self, key): new = self._clone() @@ -209,17 +224,24 @@ class ElasticSearchResults(object): stop = int(key.stop) if key.stop else None new._set_limits(start, stop) + # Copy results cache + if self._results_cache is not None: + new._results_cache = self._results_cache[key] + return new else: + if self._results_cache is not None: + return self._results_cache[key] + new.start = key new.stop = key + 1 return list(new)[0] def __iter__(self): - return iter(self._do_search()) + return iter(self.results()) def __len__(self): - return len(self._do_search()) + return len(self.results()) def __repr__(self): data = list(self[:21])