funkwhale/api/funkwhale_api/common/serializers.py

219 wiersze
7.2 KiB
Python

import collections
import io
import PIL
import os
from rest_framework import serializers
from django.core.exceptions import ObjectDoesNotExist
from django.core.files.uploadedfile import SimpleUploadedFile
from django.utils.encoding import smart_text
from django.utils.translation import ugettext_lazy as _
class RelatedField(serializers.RelatedField):
default_error_messages = {
"does_not_exist": _("Object with {related_field_name}={value} does not exist."),
"invalid": _("Invalid value."),
}
def __init__(self, related_field_name, serializer, **kwargs):
self.related_field_name = related_field_name
self.serializer = serializer
self.filters = kwargs.pop("filters", None)
kwargs["queryset"] = kwargs.pop(
"queryset", self.serializer.Meta.model.objects.all()
)
super().__init__(**kwargs)
def get_filters(self, data):
filters = {self.related_field_name: data}
if self.filters:
filters.update(self.filters(self.context))
return filters
def to_internal_value(self, data):
try:
queryset = self.get_queryset()
filters = self.get_filters(data)
return queryset.get(**filters)
except ObjectDoesNotExist:
self.fail(
"does_not_exist",
related_field_name=self.related_field_name,
value=smart_text(data),
)
except (TypeError, ValueError):
self.fail("invalid")
def to_representation(self, obj):
return self.serializer.to_representation(obj)
def get_choices(self, cutoff=None):
queryset = self.get_queryset()
if queryset is None:
# Ensure that field.choices returns something sensible
# even when accessed with a read-only field.
return {}
if cutoff is not None:
queryset = queryset[:cutoff]
return collections.OrderedDict(
[
(
self.to_representation(item)[self.related_field_name],
self.display_value(item),
)
for item in queryset
]
)
class Action(object):
def __init__(self, name, allow_all=False, qs_filter=None):
self.name = name
self.allow_all = allow_all
self.qs_filter = qs_filter
def __repr__(self):
return "<Action {}>".format(self.name)
class ActionSerializer(serializers.Serializer):
"""
A special serializer that can operate on a list of objects
and apply actions on it.
"""
action = serializers.CharField(required=True)
objects = serializers.JSONField(required=True)
filters = serializers.DictField(required=False)
actions = None
pk_field = "pk"
def __init__(self, *args, **kwargs):
self.actions_by_name = {a.name: a for a in self.actions}
self.queryset = kwargs.pop("queryset")
if self.actions is None:
raise ValueError(
"You must declare a list of actions on " "the serializer class"
)
for action in self.actions_by_name.keys():
handler_name = "handle_{}".format(action)
assert hasattr(self, handler_name), "{} miss a {} method".format(
self.__class__.__name__, handler_name
)
super().__init__(self, *args, **kwargs)
def validate_action(self, value):
try:
return self.actions_by_name[value]
except KeyError:
raise serializers.ValidationError(
"{} is not a valid action. Pick one of {}.".format(
value, ", ".join(self.actions_by_name.keys())
)
)
def validate_objects(self, value):
if value == "all":
return self.queryset.all().order_by("id")
if type(value) in [list, tuple]:
return self.queryset.filter(
**{"{}__in".format(self.pk_field): value}
).order_by(self.pk_field)
raise serializers.ValidationError(
"{} is not a valid value for objects. You must provide either a "
'list of identifiers or the string "all".'.format(value)
)
def validate(self, data):
allow_all = data["action"].allow_all
if not allow_all and self.initial_data["objects"] == "all":
raise serializers.ValidationError(
"You cannot apply this action on all objects"
)
final_filters = data.get("filters", {}) or {}
if self.filterset_class and final_filters:
qs_filterset = self.filterset_class(final_filters, queryset=data["objects"])
try:
assert qs_filterset.form.is_valid()
except (AssertionError, TypeError):
raise serializers.ValidationError("Invalid filters")
data["objects"] = qs_filterset.qs
if data["action"].qs_filter:
data["objects"] = data["action"].qs_filter(data["objects"])
data["count"] = data["objects"].count()
if data["count"] < 1:
raise serializers.ValidationError("No object matching your request")
return data
def save(self):
handler_name = "handle_{}".format(self.validated_data["action"].name)
handler = getattr(self, handler_name)
result = handler(self.validated_data["objects"])
payload = {
"updated": self.validated_data["count"],
"action": self.validated_data["action"].name,
"result": result,
}
return payload
def track_fields_for_update(*fields):
"""
Apply this decorator to serializer to call function when specific values
are updated on an object:
.. code-block:: python
@track_fields_for_update('privacy_level')
class LibrarySerializer(serializers.ModelSerializer):
def on_updated_privacy_level(self, obj, old_value, new_value):
print('Do someting')
"""
def decorator(serializer_class):
original_update = serializer_class.update
def new_update(self, obj, validated_data):
tracked_fields_before = {f: getattr(obj, f) for f in fields}
obj = original_update(self, obj, validated_data)
tracked_fields_after = {f: getattr(obj, f) for f in fields}
if tracked_fields_before != tracked_fields_after:
self.on_updated_fields(obj, tracked_fields_before, tracked_fields_after)
return obj
serializer_class.update = new_update
return serializer_class
return decorator
class StripExifImageField(serializers.ImageField):
def to_internal_value(self, data):
file_obj = super().to_internal_value(data)
image = PIL.Image.open(file_obj)
data = list(image.getdata())
image_without_exif = PIL.Image.new(image.mode, image.size)
image_without_exif.putdata(data)
with io.BytesIO() as output:
image_without_exif.save(
output,
format=PIL.Image.EXTENSION[os.path.splitext(file_obj.name)[-1]],
quality=100,
)
content = output.getvalue()
return SimpleUploadedFile(
file_obj.name, content, content_type=file_obj.content_type
)