pull/12916/merge
Jigyasu Rajput 2025-04-22 21:34:56 +00:00 zatwierdzone przez GitHub
commit 9d6870f948
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: B5690EEEBB952194
2 zmienionych plików z 140 dodań i 17 usunięć

Wyświetl plik

@ -308,8 +308,25 @@ class ReferenceIndex(models.Model):
content_path (str): The path to the piece of content on the source
object instance where the reference was found
"""
# Extract references from fields
for field in object._meta.get_fields():
for field in object._meta.get_fields(include_hidden=True):
# First - Extract references from fields with custom extract_references method
if hasattr(field, "extract_references"):
value = field.value_from_object(object)
if value is not None:
yield from (
(
cls._get_base_content_type(to_model).id,
str(to_object_id),
f"{field.name}.{model_path}",
f"{field.name}.{content_path}",
)
for to_model, to_object_id, model_path, content_path in field.extract_references(
value
)
)
continue
# Second - Process many-to-one relations for fields without extract_references
if field.is_relation and field.many_to_one:
if getattr(field, "wagtail_reference_index_ignore", False):
continue
@ -357,21 +374,6 @@ class ReferenceIndex(models.Model):
field.name,
)
if hasattr(field, "extract_references"):
value = field.value_from_object(object)
if value is not None:
yield from (
(
cls._get_base_content_type(to_model).id,
to_object_id,
f"{field.name}.{model_path}",
f"{field.name}.{content_path}",
)
for to_model, to_object_id, model_path, content_path in field.extract_references(
value
)
)
# Extract references from child relations
if isinstance(object, ClusterableModel):
for child_relation in get_all_child_relations(object):

Wyświetl plik

@ -1,4 +1,5 @@
from io import StringIO
from unittest.mock import Mock
from django.contrib.contenttypes.models import ContentType
from django.core import management
@ -526,3 +527,123 @@ class TestDescribeOnDelete(TestCase):
reference.describe_on_delete(),
"the advert placement will also be deleted",
)
class TestCustomExtractReferences(TestCase):
"""
Tests for the custom extract_references functionality in ReferenceIndex.
This test class verifies that:
1. Fields with custom extract_references methods take precedence over default many-to-one handling
2. Empty results from extract_references are handled properly
3. Fields without extract_references still use the default many-to-one handling
"""
def setUp(self):
self.root_page = Page.objects.get(id=2)
self.test_page = EventPage(
title="Test Page",
slug="test-page",
location="the moon",
audience="public",
cost="free",
date_from="2001-01-01",
)
self.root_page.add_child(instance=self.test_page)
def create_test_model(self, field):
"""Helper to create a test model instance with proper _meta by using a mock."""
meta = Mock()
meta.get_fields.return_value = [field]
meta.get_parent_list.return_value = []
test_model = Mock()
test_model._meta = meta
return test_model
def test_custom_extract_precedence(self):
"""Test that custom extract_references takes precedence over many-to-one."""
class CustomField:
def __init__(self, name, test_page):
self.name = name
self.is_relation = True
self.many_to_one = True
self.related_model = EventPage
self.test_page = test_page
def value_from_object(self, instance):
return self.test_page
def extract_references(self, value):
return [(EventPage, str(value.id), "custom_path", "custom_content")]
field = CustomField("test_field", self.test_page)
test_obj = self.create_test_model(field)
references = set(ReferenceIndex._extract_references_from_object(test_obj))
custom_ref = (
ReferenceIndex._get_base_content_type(EventPage).id,
str(self.test_page.id),
"test_field.custom_path",
"test_field.custom_content",
)
default_ref = (
ReferenceIndex._get_base_content_type(EventPage).id,
str(self.test_page.id),
"test_field",
"test_field",
)
self.assertIn(custom_ref, references)
self.assertNotIn(default_ref, references)
self.assertEqual(len(references), 1)
def test_custom_extract_returns_empty(self):
"""Test various empty return values from extract_references."""
class CustomField:
def __init__(self, name, test_page):
self.name = name
self.is_relation = True
self.many_to_one = True
self.related_model = EventPage
self.test_page = test_page
def value_from_object(self, instance):
return self.test_page
def extract_references(self, value):
return []
field = CustomField("test_field", self.test_page)
test_obj = self.create_test_model(field)
references = set(ReferenceIndex._extract_references_from_object(test_obj))
self.assertEqual(len(references), 0)
def test_default_handling_without_custom_extract(self):
"""Test default many-to-one handling without extract_references."""
class DefaultField:
def __init__(self, name, test_page):
self.name = name
self.is_relation = True
self.many_to_one = True
self.related_model = EventPage
self.test_page = test_page
def value_from_object(self, instance):
return self.test_page.id
field = DefaultField("test_field", self.test_page)
test_obj = self.create_test_model(field)
references = set(ReferenceIndex._extract_references_from_object(test_obj))
expected_ref = (
ReferenceIndex._get_base_content_type(EventPage).id,
str(self.test_page.id),
"test_field",
"test_field",
)
self.assertIn(expected_ref, references)
self.assertEqual(len(references), 1)