diff --git a/wagtail/tests/utils.py b/wagtail/tests/utils.py index da5bfcb3a3..986afd54b3 100644 --- a/wagtail/tests/utils.py +++ b/wagtail/tests/utils.py @@ -136,11 +136,41 @@ class WagtailTestUtils(object): return False return True - def _count_tag_occurrences(self, needle, haystack): - count = 1 if self._tag_is_equal(needle, haystack) else 0 + def _tag_matches_with_extra_attrs(self, thin_tag, fat_tag): + # return true if thin_tag and fat_tag have the same name, + # and all attributes on thin_tag exist on fat_tag + if not hasattr(thin_tag, 'name') or not hasattr(fat_tag, 'name'): + return False + if thin_tag.name != fat_tag.name: + return False + for attr, value in thin_tag.attributes: + if value is None: + # attributes without a value is same as attribute with value that + # equals the attributes name: + # == + if (attr, None) not in fat_tag.attributes and (attr, attr) not in fat_tag.attributes: + return False + else: + if (attr, value) not in fat_tag.attributes: + return False + + return True + + def _count_tag_occurrences(self, needle, haystack, allow_extra_attrs=False): + count = 0 + + if allow_extra_attrs: + if self._tag_matches_with_extra_attrs(needle, haystack): + count += 1 + else: + if self._tag_is_equal(needle, haystack): + count += 1 if hasattr(haystack, 'children'): - count += sum(self._count_tag_occurrences(needle, child) for child in haystack.children) + count += sum( + self._count_tag_occurrences(needle, child, allow_extra_attrs=allow_extra_attrs) + for child in haystack.children + ) return count @@ -160,10 +190,10 @@ class WagtailTestUtils(object): for script_tag in self._find_template_script_tags(child): yield script_tag - def assertTagInHTML(self, needle, haystack, count=None, msg_prefix=''): + def assertTagInHTML(self, needle, haystack, count=None, msg_prefix='', allow_extra_attrs=False): needle = assert_and_parse_html(self, needle, None, 'First argument is not valid HTML:') haystack = assert_and_parse_html(self, haystack, None, 'Second argument is not valid HTML:') - real_count = self._count_tag_occurrences(needle, haystack) + real_count = self._count_tag_occurrences(needle, haystack, allow_extra_attrs=allow_extra_attrs) if count is not None: self.assertEqual( real_count, count, diff --git a/wagtail/wagtailcore/tests/test_tests.py b/wagtail/wagtailcore/tests/test_tests.py index 8756361de4..2f034f0b39 100644 --- a/wagtail/wagtailcore/tests/test_tests.py +++ b/wagtail/wagtailcore/tests/test_tests.py @@ -29,6 +29,20 @@ class TestAssertTagInHTML(TestCase, WagtailTestUtils): with self.assertRaises(AssertionError): self.assertTagInHTML('