kopia lustrzana https://dev.funkwhale.audio/funkwhale/funkwhale
172 wiersze
5.4 KiB
Python
172 wiersze
5.4 KiB
Python
import persisting_theory
|
|
from django.db import models, transaction
|
|
from rest_framework import serializers
|
|
|
|
|
|
class ConfNotFound(KeyError):
    """Raised when no mutation configuration matches a type or object class."""
|
|
|
|
|
|
class Registry(persisting_theory.Registry):
    """Registry mapping mutation types to per-class serializer configuration.

    Entries are stored as ``{type: {klass: conf}}`` where ``conf`` is a dict
    with the keys ``"serializer_class"`` and ``"perm_checkers"``. A ``klass``
    of ``None`` is the fallback configuration for a mutation type.
    """

    look_into = "mutations"

    def connect(self, type, klass, perm_checkers=None):
        """Return a class decorator registering a serializer for ``type``/``klass``.

        Args:
            type: Mutation type, used as the first-level registry key.
            klass: Class of the objects the serializer handles, or ``None``
                to register the fallback used when no exact class matches.
            perm_checkers: Optional mapping of permission name to a callable
                invoked as ``checker(obj=..., actor=...)``.

        Returns:
            A decorator that stores the decorated serializer class in the
            registry and returns it unchanged.
        """

        def decorator(serializer_class):
            type_conf = self.setdefault(type, {})
            type_conf[klass] = {
                "serializer_class": serializer_class,
                "perm_checkers": perm_checkers or {},
            }
            return serializer_class

        return decorator

    def _get_serializer(self, type, obj, payload):
        """Instantiate the configured serializer for ``obj`` bound to ``payload``."""
        conf = self.get_conf(type, obj)
        return conf["serializer_class"](obj, data=payload)

    @transaction.atomic
    def apply(self, type, obj, payload):
        """Validate ``payload``, apply the mutation to ``obj``, return the old state.

        Returns:
            The value of the serializer's ``get_previous_state``.

        Raises:
            ConfNotFound: If no configuration matches ``type`` and ``obj``.
            Exception: Whatever ``is_valid(raise_exception=True)`` raises on
                an invalid payload.
        """
        serializer = self._get_serializer(type, obj, payload)
        serializer.is_valid(raise_exception=True)
        # Computed before the mutation is applied so it reflects the
        # pre-mutation object.
        previous_state = serializer.get_previous_state(obj, serializer.validated_data)
        serializer.apply(obj, serializer.validated_data)
        return previous_state

    def is_valid(self, type, obj, payload):
        """Validate ``payload`` for ``obj``, raising on invalid data.

        Returns:
            The result of the serializer's ``is_valid(raise_exception=True)``.

        Raises:
            ConfNotFound: If no configuration matches ``type`` and ``obj``.
        """
        serializer = self._get_serializer(type, obj, payload)
        return serializer.is_valid(raise_exception=True)

    def get_validated_payload(self, type, obj, payload):
        """Validate ``payload`` and return it passed through ``payload_serialize``.

        Raises:
            ConfNotFound: If no configuration matches ``type`` and ``obj``.
        """
        serializer = self._get_serializer(type, obj, payload)
        serializer.is_valid(raise_exception=True)
        return serializer.payload_serialize(serializer.validated_data)

    def has_perm(self, perm, type, obj, actor):
        """Tell whether ``actor`` holds permission ``perm`` on ``obj``.

        Args:
            perm: Either ``"approve"`` or ``"suggest"``.
            type: Mutation type.
            obj: Object the mutation targets.
            actor: Passed to the checker as ``actor``.

        Returns:
            ``False`` when no checker is registered for ``perm``, otherwise
            the checker's return value.

        Raises:
            ValueError: If ``perm`` is not a known permission name.
            ConfNotFound: If no configuration matches ``type`` and ``obj``.
        """
        if perm not in ("approve", "suggest"):
            raise ValueError(f"Invalid permission {perm}")
        conf = self.get_conf(type, obj)
        checker = conf["perm_checkers"].get(perm)
        if not checker:
            return False
        return checker(obj=obj, actor=actor)

    def get_conf(self, type, obj):
        """Return the configuration dict registered for ``type`` and ``obj``.

        The configuration for the exact class of ``obj`` is preferred; the
        one registered under ``None`` is used as a fallback.

        Raises:
            ConfNotFound: If ``type`` is not registered, or if neither the
                class of ``obj`` nor the ``None`` fallback is configured.
        """
        try:
            type_conf = self[type]
        except KeyError:
            # The message already names the missing type; drop the redundant
            # internal KeyError from the traceback.
            raise ConfNotFound(f"{type} is not a registered mutation") from None

        for key in (obj.__class__, None):
            try:
                return type_conf[key]
            except KeyError:
                continue

        # Raised outside any handler so the traceback is not cluttered with
        # the intermediate lookup failures.
        raise ConfNotFound(f"No mutation configuration found for {obj.__class__}")
|
|
|
|
|
|
class MutationSerializer(serializers.Serializer):
    """Base serializer declaring the hooks a mutation serializer provides."""

    def apply(self, obj, validated_data):
        """Apply the mutation to ``obj``; subclasses must implement this."""
        raise NotImplementedError()

    def post_apply(self, obj, validated_data):
        """Hook subclasses may override; does nothing by default."""
        return None

    def get_previous_state(self, obj, validated_data):
        """Return the state of ``obj`` prior to the mutation (``None`` here)."""
        return None

    def payload_serialize(self, data):
        """Return ``data`` unchanged; subclasses may convert it for storage."""
        return data
|
|
|
|
|
|
class UpdateMutationSerializer(serializers.ModelSerializer, MutationSerializer):
    """Model serializer for mutations that update fields of an existing object."""

    def __init__(self, *args, **kwargs):
        # Update mutations only carry the fields being changed, so partial
        # mode is the default unless the caller explicitly overrides it.
        kwargs.setdefault("partial", True)
        super().__init__(*args, **kwargs)

    @transaction.atomic
    def apply(self, obj, validated_data):
        """Update ``obj`` with ``validated_data``, run ``post_apply``, return it."""
        updated = self.update(obj, validated_data)
        self.post_apply(updated, validated_data)
        return updated

    def validate(self, validated_data):
        """Reject empty payloads, then defer to the parent validation."""
        if validated_data:
            return super().validate(validated_data)
        raise serializers.ValidationError("You must update at least one field")

    def db_serialize(self, validated_data):
        """Return a copy of ``validated_data`` with model instances flattened.

        Each model instance is replaced by the attribute named for its key in
        ``get_serialized_relations()``; other values are kept as they are.
        """
        relations = self.get_serialized_relations()
        return {
            name: (
                getattr(value, relations[name])
                if isinstance(value, models.Model)
                else value
            )
            for name, value in validated_data.items()
        }

    def payload_serialize(self, data):
        """Replace related objects in ``data`` by their configured attribute.

        Uses the ``get_serialized_relations()`` mapping so that identifiers,
        rather than model instances, end up in the JSON payload. ``data`` is
        modified in place and returned.
        """
        data = super().payload_serialize(data)
        for field, attr in self.get_serialized_relations().items():
            try:
                related = data[field]
            except KeyError:
                # field is not part of this payload
                continue
            data[field] = None if related is None else getattr(related, attr)
        return data

    def create(self, validated_data):
        """Create an instance from ``validated_data`` after ``db_serialize``."""
        return super().create(self.db_serialize(validated_data))

    def get_previous_state(self, obj, validated_data):
        """Return the current values on ``obj`` of the fields being updated."""
        return get_update_previous_state(
            obj,
            *validated_data,
            serialized_relations=self.get_serialized_relations(),
            handlers=self.get_previous_state_handlers(),
        )

    def get_serialized_relations(self):
        """Return a ``{field: attribute}`` mapping for relations (empty here)."""
        return {}

    def get_previous_state_handlers(self):
        """Return a ``{field: callable(obj)}`` mapping of handlers (empty here)."""
        return {}
|
|
|
|
|
|
def get_update_previous_state(obj, *fields, serialized_relations=None, handlers=None):
    """Capture the current value of ``fields`` on ``obj``.

    Args:
        obj: Object whose attributes are read.
        *fields: Names of the attributes to capture; at least one is required.
        serialized_relations: Optional ``{field: attribute}`` mapping naming,
            for each field holding a model instance, the attribute of the
            related object to store as its value.
        handlers: Optional ``{field: callable}`` mapping. When a field has a
            handler, ``handler(obj)`` is stored as that field's state and the
            default logic is skipped.

    Returns:
        A dict keyed by field name. Without a handler, each entry is
        ``{"value": ...}``, plus a ``"repr"`` key for model instances.

    Raises:
        ValueError: If no field is given.
        KeyError: If a field holds a model instance but has no entry in
            ``serialized_relations``.
    """
    if not fields:
        raise ValueError("You need to provide at least one field")

    # ``None`` sentinels instead of ``{}`` defaults: a dict default is created
    # once and shared by every call.
    if serialized_relations is None:
        serialized_relations = {}
    if handlers is None:
        handlers = {}

    state = {}
    for field in fields:
        if field in handlers:
            state[field] = handlers[field](obj)
            continue
        value = getattr(obj, field)
        if isinstance(value, models.Model):
            # we store the related object id and repr for better UX
            id_field = serialized_relations[field]
            related_value = getattr(value, id_field)
            state[field] = {"value": related_value, "repr": str(value)}
        else:
            state[field] = {"value": value}

    return state
|
|
|
|
|
|
# Module-level Registry instance. ``Registry.connect`` returns a decorator, so
# serializers can be registered with ``@registry.connect(type, klass)``.
registry = Registry()
|