See #248: better structure for action serializers

merge-requests/267/head
Eliot Berriot 2018-06-21 19:21:51 +02:00
rodzic 107b1ea7dc
commit bf8b143700
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: DD6965E2476E5C27
4 zmienionych plików z 55 dodań i 32 usunięć

Wyświetl plik

@ -1,6 +1,17 @@
import collections
from rest_framework import serializers from rest_framework import serializers
class Action(object):
def __init__(self, name, allow_all=False, filters=None):
self.name = name
self.allow_all = allow_all
self.filters = filters or {}
def __repr__(self):
return "<Action {}>".format(self.name)
class ActionSerializer(serializers.Serializer): class ActionSerializer(serializers.Serializer):
""" """
A special serializer that can operate on a list of objects A special serializer that can operate on a list of objects
@ -11,19 +22,16 @@ class ActionSerializer(serializers.Serializer):
objects = serializers.JSONField(required=True) objects = serializers.JSONField(required=True)
filters = serializers.DictField(required=False) filters = serializers.DictField(required=False)
actions = None actions = None
filterset_class = None
# those are actions identifier where we don't want to allow the "all"
# selector because it's to dangerous. Like object deletion.
dangerous_actions = []
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.actions_by_name = {a.name: a for a in self.actions}
self.queryset = kwargs.pop("queryset") self.queryset = kwargs.pop("queryset")
if self.actions is None: if self.actions is None:
raise ValueError( raise ValueError(
"You must declare a list of actions on " "the serializer class" "You must declare a list of actions on " "the serializer class"
) )
for action in self.actions: for action in self.actions_by_name.keys():
handler_name = "handle_{}".format(action) handler_name = "handle_{}".format(action)
assert hasattr(self, handler_name), "{} miss a {} method".format( assert hasattr(self, handler_name), "{} miss a {} method".format(
self.__class__.__name__, handler_name self.__class__.__name__, handler_name
@ -31,13 +39,14 @@ class ActionSerializer(serializers.Serializer):
super().__init__(self, *args, **kwargs) super().__init__(self, *args, **kwargs)
def validate_action(self, value): def validate_action(self, value):
if value not in self.actions: try:
return self.actions_by_name[value]
except KeyError:
raise serializers.ValidationError( raise serializers.ValidationError(
"{} is not a valid action. Pick one of {}.".format( "{} is not a valid action. Pick one of {}.".format(
value, ", ".join(self.actions) value, ", ".join(self.actions_by_name.keys())
) )
) )
return value
def validate_objects(self, value): def validate_objects(self, value):
if value == "all": if value == "all":
@ -51,15 +60,15 @@ class ActionSerializer(serializers.Serializer):
) )
def validate(self, data): def validate(self, data):
dangerous = data["action"] in self.dangerous_actions allow_all = data["action"].allow_all
if dangerous and self.initial_data["objects"] == "all": if not allow_all and self.initial_data["objects"] == "all":
raise serializers.ValidationError( raise serializers.ValidationError(
"This action is to dangerous to be applied to all objects" "You cannot apply this action on all objects"
)
if self.filterset_class and "filters" in data:
qs_filterset = self.filterset_class(
data["filters"], queryset=data["objects"]
) )
final_filters = data.get("filters", {}) or {}
final_filters.update(data["action"].filters)
if self.filterset_class and final_filters:
qs_filterset = self.filterset_class(final_filters, queryset=data["objects"])
try: try:
assert qs_filterset.form.is_valid() assert qs_filterset.form.is_valid()
except (AssertionError, TypeError): except (AssertionError, TypeError):
@ -72,12 +81,12 @@ class ActionSerializer(serializers.Serializer):
return data return data
def save(self): def save(self):
handler_name = "handle_{}".format(self.validated_data["action"]) handler_name = "handle_{}".format(self.validated_data["action"].name)
handler = getattr(self, handler_name) handler = getattr(self, handler_name)
result = handler(self.validated_data["objects"]) result = handler(self.validated_data["objects"])
payload = { payload = {
"updated": self.validated_data["count"], "updated": self.validated_data["count"],
"action": self.validated_data["action"], "action": self.validated_data["action"].name,
"result": result, "result": result,
} }
return payload return payload

Wyświetl plik

@ -769,7 +769,7 @@ class CollectionSerializer(serializers.Serializer):
class LibraryTrackActionSerializer(common_serializers.ActionSerializer): class LibraryTrackActionSerializer(common_serializers.ActionSerializer):
actions = ["import"] actions = [common_serializers.Action('import', allow_all=True)]
filterset_class = filters.LibraryTrackFilter filterset_class = filters.LibraryTrackFilter
@transaction.atomic @transaction.atomic

Wyświetl plik

@ -61,8 +61,7 @@ class ManageTrackFileSerializer(serializers.ModelSerializer):
class ManageTrackFileActionSerializer(common_serializers.ActionSerializer): class ManageTrackFileActionSerializer(common_serializers.ActionSerializer):
actions = ["delete"] actions = [common_serializers.Action("delete", allow_all=False)]
dangerous_actions = ["delete"]
filterset_class = filters.ManageTrackFileFilterSet filterset_class = filters.ManageTrackFileFilterSet
@transaction.atomic @transaction.atomic

Wyświetl plik

@ -11,7 +11,7 @@ class TestActionFilterSet(django_filters.FilterSet):
class TestSerializer(serializers.ActionSerializer): class TestSerializer(serializers.ActionSerializer):
actions = ["test"] actions = [serializers.Action("test", allow_all=True)]
filterset_class = TestActionFilterSet filterset_class = TestActionFilterSet
def handle_test(self, objects): def handle_test(self, objects):
@ -19,8 +19,10 @@ class TestSerializer(serializers.ActionSerializer):
class TestDangerousSerializer(serializers.ActionSerializer): class TestDangerousSerializer(serializers.ActionSerializer):
actions = ["test", "test_dangerous"] actions = [
dangerous_actions = ["test_dangerous"] serializers.Action("test", allow_all=True),
serializers.Action("test_dangerous"),
]
def handle_test(self, objects): def handle_test(self, objects):
pass pass
@ -29,6 +31,14 @@ class TestDangerousSerializer(serializers.ActionSerializer):
pass pass
class TestDeleteOnlyInactiveSerializer(serializers.ActionSerializer):
actions = [serializers.Action("test", allow_all=True, filters={"is_active": False})]
filterset_class = TestActionFilterSet
def handle_test(self, objects):
pass
def test_action_serializer_validates_action(): def test_action_serializer_validates_action():
data = {"objects": "all", "action": "nope"} data = {"objects": "all", "action": "nope"}
serializer = TestSerializer(data, queryset=models.User.objects.none()) serializer = TestSerializer(data, queryset=models.User.objects.none())
@ -52,7 +62,7 @@ def test_action_serializers_objects_clean_ids(factories):
data = {"objects": [user1.pk], "action": "test"} data = {"objects": [user1.pk], "action": "test"}
serializer = TestSerializer(data, queryset=models.User.objects.all()) serializer = TestSerializer(data, queryset=models.User.objects.all())
assert serializer.is_valid() is True assert serializer.is_valid(raise_exception=True) is True
assert list(serializer.validated_data["objects"]) == [user1] assert list(serializer.validated_data["objects"]) == [user1]
@ -63,7 +73,7 @@ def test_action_serializers_objects_clean_all(factories):
data = {"objects": "all", "action": "test"} data = {"objects": "all", "action": "test"}
serializer = TestSerializer(data, queryset=models.User.objects.all()) serializer = TestSerializer(data, queryset=models.User.objects.all())
assert serializer.is_valid() is True assert serializer.is_valid(raise_exception=True) is True
assert list(serializer.validated_data["objects"]) == [user1, user2] assert list(serializer.validated_data["objects"]) == [user1, user2]
@ -75,7 +85,7 @@ def test_action_serializers_save(factories, mocker):
data = {"objects": "all", "action": "test"} data = {"objects": "all", "action": "test"}
serializer = TestSerializer(data, queryset=models.User.objects.all()) serializer = TestSerializer(data, queryset=models.User.objects.all())
assert serializer.is_valid() is True assert serializer.is_valid(raise_exception=True) is True
result = serializer.save() result = serializer.save()
assert result == {"updated": 2, "action": "test", "result": {"hello": "world"}} assert result == {"updated": 2, "action": "test", "result": {"hello": "world"}}
handler.assert_called_once() handler.assert_called_once()
@ -88,7 +98,7 @@ def test_action_serializers_filterset(factories):
data = {"objects": "all", "action": "test", "filters": {"is_active": True}} data = {"objects": "all", "action": "test", "filters": {"is_active": True}}
serializer = TestSerializer(data, queryset=models.User.objects.all()) serializer = TestSerializer(data, queryset=models.User.objects.all())
assert serializer.is_valid() is True assert serializer.is_valid(raise_exception=True) is True
assert list(serializer.validated_data["objects"]) == [user2] assert list(serializer.validated_data["objects"]) == [user2]
@ -109,9 +119,14 @@ def test_dangerous_actions_refuses_all(factories):
assert "non_field_errors" in serializer.errors assert "non_field_errors" in serializer.errors
def test_dangerous_actions_refuses_not_listed(factories): def test_action_serializers_can_require_filter(factories):
factories["users.User"]() user1 = factories["users.User"](is_active=False)
data = {"objects": "all", "action": "test"} factories["users.User"](is_active=True)
serializer = TestDangerousSerializer(data, queryset=models.User.objects.all())
assert serializer.is_valid() is True data = {"objects": "all", "action": "test"}
serializer = TestDeleteOnlyInactiveSerializer(
data, queryset=models.User.objects.all()
)
assert serializer.is_valid(raise_exception=True) is True
assert list(serializer.validated_data["objects"]) == [user1]