diff --git a/wagtail/admin/filters.py b/wagtail/admin/filters.py
index 0f16c19a88..7ac6da2c42 100644
--- a/wagtail/admin/filters.py
+++ b/wagtail/admin/filters.py
@@ -1,4 +1,5 @@
import django_filters
+from django.db import models
from django.utils.translation import gettext_lazy as _
from django_filters.widgets import SuffixedMultiWidget
@@ -38,6 +39,16 @@ class FilteredModelChoiceIterator(django_filters.fields.ModelChoiceIterator):
class FilteredModelChoiceField(django_filters.fields.ModelChoiceField):
+ """
+ A ModelChoiceField that uses FilteredSelect to dynamically show/hide options based on another
+ ModelChoiceField of related objects; an option will be shown whenever the selected related
+ object is present in the result of filter_accessor for that option.
+
+ filter_field - the HTML `id` of the related ModelChoiceField
+ filter_accessor - either the name of a relation, property or method on the model instance which
+ returns a queryset of related objects, or a function which accepts the model instance and
+ returns such a queryset.
+ """
widget = FilteredSelect
iterator = FilteredModelChoiceIterator
@@ -48,13 +59,19 @@ class FilteredModelChoiceField(django_filters.fields.ModelChoiceField):
self.widget.filter_field = filter_field
def get_filter_value(self, obj):
- # filter_accessor identifies a property or method on the instances being listed here,
- # which gives us a queryset of related objects. Turn this queryset into a list of IDs
- # that will become the 'data-filter-value' used to filter this listing
- queryset = getattr(obj, self.filter_accessor)
- if callable(queryset):
- queryset = queryset()
+ # Use filter_accessor to obtain a queryset of related objects
+ if callable(self.filter_accessor):
+ queryset = self.filter_accessor(obj)
+ else:
+ # treat filter_accessor as a method/property name of obj
+ queryset = getattr(obj, self.filter_accessor)
+ if isinstance(queryset, models.Manager):
+ queryset = queryset.all()
+ elif callable(queryset):
+ queryset = queryset()
+ # Turn this queryset into a list of IDs that will become the 'data-filter-value' used to
+ # filter this listing
return queryset.values_list('pk', flat=True)
diff --git a/wagtail/admin/tests/test_filters.py b/wagtail/admin/tests/test_filters.py
new file mode 100644
index 0000000000..90e3f9d8ca
--- /dev/null
+++ b/wagtail/admin/tests/test_filters.py
@@ -0,0 +1,75 @@
+from django import forms
+from django.contrib.auth import get_user_model
+from django.contrib.auth.models import Group
+from django.test import TestCase
+
+from wagtail.admin.filters import FilteredModelChoiceField
+
+
+User = get_user_model()
+
+
+class TestFilteredModelChoiceField(TestCase):
+ def setUp(self):
+ self.musicians = Group.objects.create(name="Musicians")
+ self.actors = Group.objects.create(name="Actors")
+
+ self.david = User.objects.create_user(
+ 'david', 'david@example.com', 'kn1ghtr1der', first_name="David", last_name="Hasselhoff"
+ )
+ self.david.groups.set([self.musicians, self.actors])
+
+ self.kevin = User.objects.create_user(
+ 'kevin', 'kevin@example.com', '6degrees', first_name="Kevin", last_name="Bacon"
+ )
+ self.kevin.groups.set([self.actors])
+
+ self.morten = User.objects.create_user(
+ 'morten', 'morten@example.com', 't4ke0nm3', first_name="Morten", last_name="Harket"
+ )
+ self.morten.groups.set([self.musicians])
+
+ def test_with_relation(self):
+
+ class UserForm(forms.Form):
+ users = FilteredModelChoiceField(
+ queryset=User.objects.order_by('username'), filter_field='id_group', filter_accessor='groups'
+ )
+
+ form = UserForm()
+ html = str(form['users'])
+ expected_html = """
+
+ """ % {
+ 'david': self.david.pk, 'kevin': self.kevin.pk, 'morten': self.morten.pk,
+ 'musicians': self.musicians.pk, 'actors': self.actors.pk,
+ }
+ self.assertHTMLEqual(html, expected_html)
+
+ def test_with_callable(self):
+
+ class UserForm(forms.Form):
+ users = FilteredModelChoiceField(
+ queryset=User.objects.order_by('username'), filter_field='id_group',
+ filter_accessor=lambda user: user.groups.all()
+ )
+
+ form = UserForm()
+ html = str(form['users'])
+ expected_html = """
+
+ """ % {
+ 'david': self.david.pk, 'kevin': self.kevin.pk, 'morten': self.morten.pk,
+ 'musicians': self.musicians.pk, 'actors': self.actors.pk,
+ }
+ self.assertHTMLEqual(html, expected_html)