kopia lustrzana https://dev.funkwhale.audio/funkwhale/funkwhale
See #248: better structure for action serializers
rodzic
107b1ea7dc
commit
bf8b143700
|
@ -1,6 +1,17 @@
|
||||||
|
import collections
|
||||||
from rest_framework import serializers
|
from rest_framework import serializers
|
||||||
|
|
||||||
|
|
||||||
|
class Action(object):
|
||||||
|
def __init__(self, name, allow_all=False, filters=None):
|
||||||
|
self.name = name
|
||||||
|
self.allow_all = allow_all
|
||||||
|
self.filters = filters or {}
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "<Action {}>".format(self.name)
|
||||||
|
|
||||||
|
|
||||||
class ActionSerializer(serializers.Serializer):
|
class ActionSerializer(serializers.Serializer):
|
||||||
"""
|
"""
|
||||||
A special serializer that can operate on a list of objects
|
A special serializer that can operate on a list of objects
|
||||||
|
@ -11,19 +22,16 @@ class ActionSerializer(serializers.Serializer):
|
||||||
objects = serializers.JSONField(required=True)
|
objects = serializers.JSONField(required=True)
|
||||||
filters = serializers.DictField(required=False)
|
filters = serializers.DictField(required=False)
|
||||||
actions = None
|
actions = None
|
||||||
filterset_class = None
|
|
||||||
# those are actions identifier where we don't want to allow the "all"
|
|
||||||
# selector because it's to dangerous. Like object deletion.
|
|
||||||
dangerous_actions = []
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
|
self.actions_by_name = {a.name: a for a in self.actions}
|
||||||
self.queryset = kwargs.pop("queryset")
|
self.queryset = kwargs.pop("queryset")
|
||||||
if self.actions is None:
|
if self.actions is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"You must declare a list of actions on " "the serializer class"
|
"You must declare a list of actions on " "the serializer class"
|
||||||
)
|
)
|
||||||
|
|
||||||
for action in self.actions:
|
for action in self.actions_by_name.keys():
|
||||||
handler_name = "handle_{}".format(action)
|
handler_name = "handle_{}".format(action)
|
||||||
assert hasattr(self, handler_name), "{} miss a {} method".format(
|
assert hasattr(self, handler_name), "{} miss a {} method".format(
|
||||||
self.__class__.__name__, handler_name
|
self.__class__.__name__, handler_name
|
||||||
|
@ -31,13 +39,14 @@ class ActionSerializer(serializers.Serializer):
|
||||||
super().__init__(self, *args, **kwargs)
|
super().__init__(self, *args, **kwargs)
|
||||||
|
|
||||||
def validate_action(self, value):
|
def validate_action(self, value):
|
||||||
if value not in self.actions:
|
try:
|
||||||
|
return self.actions_by_name[value]
|
||||||
|
except KeyError:
|
||||||
raise serializers.ValidationError(
|
raise serializers.ValidationError(
|
||||||
"{} is not a valid action. Pick one of {}.".format(
|
"{} is not a valid action. Pick one of {}.".format(
|
||||||
value, ", ".join(self.actions)
|
value, ", ".join(self.actions_by_name.keys())
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return value
|
|
||||||
|
|
||||||
def validate_objects(self, value):
|
def validate_objects(self, value):
|
||||||
if value == "all":
|
if value == "all":
|
||||||
|
@ -51,15 +60,15 @@ class ActionSerializer(serializers.Serializer):
|
||||||
)
|
)
|
||||||
|
|
||||||
def validate(self, data):
|
def validate(self, data):
|
||||||
dangerous = data["action"] in self.dangerous_actions
|
allow_all = data["action"].allow_all
|
||||||
if dangerous and self.initial_data["objects"] == "all":
|
if not allow_all and self.initial_data["objects"] == "all":
|
||||||
raise serializers.ValidationError(
|
raise serializers.ValidationError(
|
||||||
"This action is to dangerous to be applied to all objects"
|
"You cannot apply this action on all objects"
|
||||||
)
|
|
||||||
if self.filterset_class and "filters" in data:
|
|
||||||
qs_filterset = self.filterset_class(
|
|
||||||
data["filters"], queryset=data["objects"]
|
|
||||||
)
|
)
|
||||||
|
final_filters = data.get("filters", {}) or {}
|
||||||
|
final_filters.update(data["action"].filters)
|
||||||
|
if self.filterset_class and final_filters:
|
||||||
|
qs_filterset = self.filterset_class(final_filters, queryset=data["objects"])
|
||||||
try:
|
try:
|
||||||
assert qs_filterset.form.is_valid()
|
assert qs_filterset.form.is_valid()
|
||||||
except (AssertionError, TypeError):
|
except (AssertionError, TypeError):
|
||||||
|
@ -72,12 +81,12 @@ class ActionSerializer(serializers.Serializer):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def save(self):
|
def save(self):
|
||||||
handler_name = "handle_{}".format(self.validated_data["action"])
|
handler_name = "handle_{}".format(self.validated_data["action"].name)
|
||||||
handler = getattr(self, handler_name)
|
handler = getattr(self, handler_name)
|
||||||
result = handler(self.validated_data["objects"])
|
result = handler(self.validated_data["objects"])
|
||||||
payload = {
|
payload = {
|
||||||
"updated": self.validated_data["count"],
|
"updated": self.validated_data["count"],
|
||||||
"action": self.validated_data["action"],
|
"action": self.validated_data["action"].name,
|
||||||
"result": result,
|
"result": result,
|
||||||
}
|
}
|
||||||
return payload
|
return payload
|
||||||
|
|
|
@ -769,7 +769,7 @@ class CollectionSerializer(serializers.Serializer):
|
||||||
|
|
||||||
|
|
||||||
class LibraryTrackActionSerializer(common_serializers.ActionSerializer):
|
class LibraryTrackActionSerializer(common_serializers.ActionSerializer):
|
||||||
actions = ["import"]
|
actions = [common_serializers.Action('import', allow_all=True)]
|
||||||
filterset_class = filters.LibraryTrackFilter
|
filterset_class = filters.LibraryTrackFilter
|
||||||
|
|
||||||
@transaction.atomic
|
@transaction.atomic
|
||||||
|
|
|
@ -61,8 +61,7 @@ class ManageTrackFileSerializer(serializers.ModelSerializer):
|
||||||
|
|
||||||
|
|
||||||
class ManageTrackFileActionSerializer(common_serializers.ActionSerializer):
|
class ManageTrackFileActionSerializer(common_serializers.ActionSerializer):
|
||||||
actions = ["delete"]
|
actions = [common_serializers.Action("delete", allow_all=False)]
|
||||||
dangerous_actions = ["delete"]
|
|
||||||
filterset_class = filters.ManageTrackFileFilterSet
|
filterset_class = filters.ManageTrackFileFilterSet
|
||||||
|
|
||||||
@transaction.atomic
|
@transaction.atomic
|
||||||
|
|
|
@ -11,7 +11,7 @@ class TestActionFilterSet(django_filters.FilterSet):
|
||||||
|
|
||||||
|
|
||||||
class TestSerializer(serializers.ActionSerializer):
|
class TestSerializer(serializers.ActionSerializer):
|
||||||
actions = ["test"]
|
actions = [serializers.Action("test", allow_all=True)]
|
||||||
filterset_class = TestActionFilterSet
|
filterset_class = TestActionFilterSet
|
||||||
|
|
||||||
def handle_test(self, objects):
|
def handle_test(self, objects):
|
||||||
|
@ -19,8 +19,10 @@ class TestSerializer(serializers.ActionSerializer):
|
||||||
|
|
||||||
|
|
||||||
class TestDangerousSerializer(serializers.ActionSerializer):
|
class TestDangerousSerializer(serializers.ActionSerializer):
|
||||||
actions = ["test", "test_dangerous"]
|
actions = [
|
||||||
dangerous_actions = ["test_dangerous"]
|
serializers.Action("test", allow_all=True),
|
||||||
|
serializers.Action("test_dangerous"),
|
||||||
|
]
|
||||||
|
|
||||||
def handle_test(self, objects):
|
def handle_test(self, objects):
|
||||||
pass
|
pass
|
||||||
|
@ -29,6 +31,14 @@ class TestDangerousSerializer(serializers.ActionSerializer):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TestDeleteOnlyInactiveSerializer(serializers.ActionSerializer):
|
||||||
|
actions = [serializers.Action("test", allow_all=True, filters={"is_active": False})]
|
||||||
|
filterset_class = TestActionFilterSet
|
||||||
|
|
||||||
|
def handle_test(self, objects):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def test_action_serializer_validates_action():
|
def test_action_serializer_validates_action():
|
||||||
data = {"objects": "all", "action": "nope"}
|
data = {"objects": "all", "action": "nope"}
|
||||||
serializer = TestSerializer(data, queryset=models.User.objects.none())
|
serializer = TestSerializer(data, queryset=models.User.objects.none())
|
||||||
|
@ -52,7 +62,7 @@ def test_action_serializers_objects_clean_ids(factories):
|
||||||
data = {"objects": [user1.pk], "action": "test"}
|
data = {"objects": [user1.pk], "action": "test"}
|
||||||
serializer = TestSerializer(data, queryset=models.User.objects.all())
|
serializer = TestSerializer(data, queryset=models.User.objects.all())
|
||||||
|
|
||||||
assert serializer.is_valid() is True
|
assert serializer.is_valid(raise_exception=True) is True
|
||||||
assert list(serializer.validated_data["objects"]) == [user1]
|
assert list(serializer.validated_data["objects"]) == [user1]
|
||||||
|
|
||||||
|
|
||||||
|
@ -63,7 +73,7 @@ def test_action_serializers_objects_clean_all(factories):
|
||||||
data = {"objects": "all", "action": "test"}
|
data = {"objects": "all", "action": "test"}
|
||||||
serializer = TestSerializer(data, queryset=models.User.objects.all())
|
serializer = TestSerializer(data, queryset=models.User.objects.all())
|
||||||
|
|
||||||
assert serializer.is_valid() is True
|
assert serializer.is_valid(raise_exception=True) is True
|
||||||
assert list(serializer.validated_data["objects"]) == [user1, user2]
|
assert list(serializer.validated_data["objects"]) == [user1, user2]
|
||||||
|
|
||||||
|
|
||||||
|
@ -75,7 +85,7 @@ def test_action_serializers_save(factories, mocker):
|
||||||
data = {"objects": "all", "action": "test"}
|
data = {"objects": "all", "action": "test"}
|
||||||
serializer = TestSerializer(data, queryset=models.User.objects.all())
|
serializer = TestSerializer(data, queryset=models.User.objects.all())
|
||||||
|
|
||||||
assert serializer.is_valid() is True
|
assert serializer.is_valid(raise_exception=True) is True
|
||||||
result = serializer.save()
|
result = serializer.save()
|
||||||
assert result == {"updated": 2, "action": "test", "result": {"hello": "world"}}
|
assert result == {"updated": 2, "action": "test", "result": {"hello": "world"}}
|
||||||
handler.assert_called_once()
|
handler.assert_called_once()
|
||||||
|
@ -88,7 +98,7 @@ def test_action_serializers_filterset(factories):
|
||||||
data = {"objects": "all", "action": "test", "filters": {"is_active": True}}
|
data = {"objects": "all", "action": "test", "filters": {"is_active": True}}
|
||||||
serializer = TestSerializer(data, queryset=models.User.objects.all())
|
serializer = TestSerializer(data, queryset=models.User.objects.all())
|
||||||
|
|
||||||
assert serializer.is_valid() is True
|
assert serializer.is_valid(raise_exception=True) is True
|
||||||
assert list(serializer.validated_data["objects"]) == [user2]
|
assert list(serializer.validated_data["objects"]) == [user2]
|
||||||
|
|
||||||
|
|
||||||
|
@ -109,9 +119,14 @@ def test_dangerous_actions_refuses_all(factories):
|
||||||
assert "non_field_errors" in serializer.errors
|
assert "non_field_errors" in serializer.errors
|
||||||
|
|
||||||
|
|
||||||
def test_dangerous_actions_refuses_not_listed(factories):
|
def test_action_serializers_can_require_filter(factories):
|
||||||
factories["users.User"]()
|
user1 = factories["users.User"](is_active=False)
|
||||||
data = {"objects": "all", "action": "test"}
|
factories["users.User"](is_active=True)
|
||||||
serializer = TestDangerousSerializer(data, queryset=models.User.objects.all())
|
|
||||||
|
|
||||||
assert serializer.is_valid() is True
|
data = {"objects": "all", "action": "test"}
|
||||||
|
serializer = TestDeleteOnlyInactiveSerializer(
|
||||||
|
data, queryset=models.User.objects.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
assert serializer.is_valid(raise_exception=True) is True
|
||||||
|
assert list(serializer.validated_data["objects"]) == [user1]
|
||||||
|
|
Ładowanie…
Reference in New Issue