diff --git a/api/funkwhale_api/common/serializers.py b/api/funkwhale_api/common/serializers.py index 029338ef9..b3e2d3101 100644 --- a/api/funkwhale_api/common/serializers.py +++ b/api/funkwhale_api/common/serializers.py @@ -1,6 +1,17 @@ +import collections from rest_framework import serializers +class Action(object): + def __init__(self, name, allow_all=False, filters=None): + self.name = name + self.allow_all = allow_all + self.filters = filters or {} + + def __repr__(self): + return "".format(self.name) + + class ActionSerializer(serializers.Serializer): """ A special serializer that can operate on a list of objects @@ -11,19 +22,16 @@ class ActionSerializer(serializers.Serializer): objects = serializers.JSONField(required=True) filters = serializers.DictField(required=False) actions = None - filterset_class = None - # those are actions identifier where we don't want to allow the "all" - # selector because it's to dangerous. Like object deletion. - dangerous_actions = [] def __init__(self, *args, **kwargs): + self.actions_by_name = {a.name: a for a in self.actions} self.queryset = kwargs.pop("queryset") if self.actions is None: raise ValueError( "You must declare a list of actions on " "the serializer class" ) - for action in self.actions: + for action in self.actions_by_name.keys(): handler_name = "handle_{}".format(action) assert hasattr(self, handler_name), "{} miss a {} method".format( self.__class__.__name__, handler_name @@ -31,13 +39,14 @@ class ActionSerializer(serializers.Serializer): super().__init__(self, *args, **kwargs) def validate_action(self, value): - if value not in self.actions: + try: + return self.actions_by_name[value] + except KeyError: raise serializers.ValidationError( "{} is not a valid action. Pick one of {}.".format( - value, ", ".join(self.actions) + value, ", ".join(self.actions_by_name.keys()) ) ) - return value def validate_objects(self, value): if value == "all": @@ -51,15 +60,15 @@ class ActionSerializer(serializers.Serializer): ) def validate(self, data): - dangerous = data["action"] in self.dangerous_actions - if dangerous and self.initial_data["objects"] == "all": + allow_all = data["action"].allow_all + if not allow_all and self.initial_data["objects"] == "all": raise serializers.ValidationError( - "This action is to dangerous to be applied to all objects" - ) - if self.filterset_class and "filters" in data: - qs_filterset = self.filterset_class( - data["filters"], queryset=data["objects"] + "You cannot apply this action on all objects" ) + final_filters = data.get("filters", {}) or {} + final_filters.update(data["action"].filters) + if self.filterset_class and final_filters: + qs_filterset = self.filterset_class(final_filters, queryset=data["objects"]) try: assert qs_filterset.form.is_valid() except (AssertionError, TypeError): @@ -72,12 +81,12 @@ class ActionSerializer(serializers.Serializer): return data def save(self): - handler_name = "handle_{}".format(self.validated_data["action"]) + handler_name = "handle_{}".format(self.validated_data["action"].name) handler = getattr(self, handler_name) result = handler(self.validated_data["objects"]) payload = { "updated": self.validated_data["count"], - "action": self.validated_data["action"], + "action": self.validated_data["action"].name, "result": result, } return payload diff --git a/api/funkwhale_api/federation/serializers.py b/api/funkwhale_api/federation/serializers.py index 062f74f47..451c199ce 100644 --- a/api/funkwhale_api/federation/serializers.py +++ b/api/funkwhale_api/federation/serializers.py @@ -769,7 +769,7 @@ class CollectionSerializer(serializers.Serializer): class LibraryTrackActionSerializer(common_serializers.ActionSerializer): - actions = ["import"] + actions = [common_serializers.Action('import', allow_all=True)] filterset_class = filters.LibraryTrackFilter @transaction.atomic diff --git a/api/funkwhale_api/manage/serializers.py b/api/funkwhale_api/manage/serializers.py index e8f1e328e..f5d52bcac 100644 --- a/api/funkwhale_api/manage/serializers.py +++ b/api/funkwhale_api/manage/serializers.py @@ -61,8 +61,7 @@ class ManageTrackFileSerializer(serializers.ModelSerializer): class ManageTrackFileActionSerializer(common_serializers.ActionSerializer): - actions = ["delete"] - dangerous_actions = ["delete"] + actions = [common_serializers.Action("delete", allow_all=False)] filterset_class = filters.ManageTrackFileFilterSet @transaction.atomic diff --git a/api/tests/common/test_serializers.py b/api/tests/common/test_serializers.py index ca5e5ad8f..dbbd38a0d 100644 --- a/api/tests/common/test_serializers.py +++ b/api/tests/common/test_serializers.py @@ -11,7 +11,7 @@ class TestActionFilterSet(django_filters.FilterSet): class TestSerializer(serializers.ActionSerializer): - actions = ["test"] + actions = [serializers.Action("test", allow_all=True)] filterset_class = TestActionFilterSet def handle_test(self, objects): @@ -19,8 +19,10 @@ class TestSerializer(serializers.ActionSerializer): class TestDangerousSerializer(serializers.ActionSerializer): - actions = ["test", "test_dangerous"] - dangerous_actions = ["test_dangerous"] + actions = [ + serializers.Action("test", allow_all=True), + serializers.Action("test_dangerous"), + ] def handle_test(self, objects): pass @@ -29,6 +31,14 @@ class TestDangerousSerializer(serializers.ActionSerializer): pass +class TestDeleteOnlyInactiveSerializer(serializers.ActionSerializer): + actions = [serializers.Action("test", allow_all=True, filters={"is_active": False})] + filterset_class = TestActionFilterSet + + def handle_test(self, objects): + pass + + def test_action_serializer_validates_action(): data = {"objects": "all", "action": "nope"} serializer = TestSerializer(data, queryset=models.User.objects.none()) @@ -52,7 +62,7 @@ def test_action_serializers_objects_clean_ids(factories): data = {"objects": [user1.pk], "action": "test"} serializer = TestSerializer(data, queryset=models.User.objects.all()) - assert serializer.is_valid() is True + assert serializer.is_valid(raise_exception=True) is True assert list(serializer.validated_data["objects"]) == [user1] @@ -63,7 +73,7 @@ def test_action_serializers_objects_clean_all(factories): data = {"objects": "all", "action": "test"} serializer = TestSerializer(data, queryset=models.User.objects.all()) - assert serializer.is_valid() is True + assert serializer.is_valid(raise_exception=True) is True assert list(serializer.validated_data["objects"]) == [user1, user2] @@ -75,7 +85,7 @@ def test_action_serializers_save(factories, mocker): data = {"objects": "all", "action": "test"} serializer = TestSerializer(data, queryset=models.User.objects.all()) - assert serializer.is_valid() is True + assert serializer.is_valid(raise_exception=True) is True result = serializer.save() assert result == {"updated": 2, "action": "test", "result": {"hello": "world"}} handler.assert_called_once() @@ -88,7 +98,7 @@ def test_action_serializers_filterset(factories): data = {"objects": "all", "action": "test", "filters": {"is_active": True}} serializer = TestSerializer(data, queryset=models.User.objects.all()) - assert serializer.is_valid() is True + assert serializer.is_valid(raise_exception=True) is True assert list(serializer.validated_data["objects"]) == [user2] @@ -109,9 +119,14 @@ def test_dangerous_actions_refuses_all(factories): assert "non_field_errors" in serializer.errors -def test_dangerous_actions_refuses_not_listed(factories): - factories["users.User"]() - data = {"objects": "all", "action": "test"} - serializer = TestDangerousSerializer(data, queryset=models.User.objects.all()) +def test_action_serializers_can_require_filter(factories): + user1 = factories["users.User"](is_active=False) + factories["users.User"](is_active=True) - assert serializer.is_valid() is True + data = {"objects": "all", "action": "test"} + serializer = TestDeleteOnlyInactiveSerializer( + data, queryset=models.User.objects.all() + ) + + assert serializer.is_valid(raise_exception=True) is True + assert list(serializer.validated_data["objects"]) == [user1]