Merge branch '432-tags-mutation' into 'develop'

See #432: API for tags

See merge request funkwhale/funkwhale!830
environments/review-front-arti-0habim/deployments/2230
Eliot Berriot 2019-07-18 09:53:42 +02:00
commit 03a470deaf
17 zmienionych plików z 261 dodań i 40 usunięć

Wyświetl plik

@ -716,3 +716,6 @@ ACTOR_KEY_ROTATION_DELAY = env.int("ACTOR_KEY_ROTATION_DELAY", default=3600 * 48
SUBSONIC_DEFAULT_TRANSCODING_FORMAT = (
env("SUBSONIC_DEFAULT_TRANSCODING_FORMAT", default="mp3") or None
)
# extra tags will be ignored
TAGS_MAX_BY_OBJ = env.int("TAGS_MAX_BY_OBJ", default=30)

Wyświetl plik

@ -86,6 +86,7 @@ class MutationSerializer(serializers.Serializer):
class UpdateMutationSerializer(serializers.ModelSerializer, MutationSerializer):
serialized_relations = {}
previous_state_handlers = {}
def __init__(self, *args, **kwargs):
# we force partial mode, because update mutations are partial
@ -139,16 +140,20 @@ class UpdateMutationSerializer(serializers.ModelSerializer, MutationSerializer):
return get_update_previous_state(
obj,
*list(validated_data.keys()),
serialized_relations=self.serialized_relations
serialized_relations=self.serialized_relations,
handlers=self.previous_state_handlers,
)
def get_update_previous_state(obj, *fields, serialized_relations={}):
def get_update_previous_state(obj, *fields, serialized_relations={}, handlers={}):
if not fields:
raise ValueError("You need to provide at least one field")
state = {}
for field in fields:
if field in handlers:
state[field] = handlers[field](obj)
continue
value = getattr(obj, field)
if isinstance(value, models.Model):
# we store the related object id and repr for better UX

Wyświetl plik

@ -9,9 +9,20 @@ from . import models
from . import utils
def filter_tags(queryset, name, value):
non_empty_tags = [v.lower() for v in value if v]
for tag in non_empty_tags:
queryset = queryset.filter(tagged_items__tag__name=tag).distinct()
return queryset
TAG_FILTER = common_filters.MultipleQueryFilter(method=filter_tags)
class ArtistFilter(moderation_filters.HiddenContentFilterSet):
q = fields.SearchFilter(search_fields=["name"])
playable = filters.BooleanFilter(field_name="_", method="filter_playable")
tag = TAG_FILTER
class Meta:
model = models.Artist
@ -29,7 +40,7 @@ class ArtistFilter(moderation_filters.HiddenContentFilterSet):
class TrackFilter(moderation_filters.HiddenContentFilterSet):
q = fields.SearchFilter(search_fields=["title", "album__title", "artist__name"])
playable = filters.BooleanFilter(field_name="_", method="filter_playable")
tag = common_filters.MultipleQueryFilter(method="filter_tags")
tag = TAG_FILTER
id = common_filters.MultipleQueryFilter(coerce=int)
class Meta:
@ -48,12 +59,6 @@ class TrackFilter(moderation_filters.HiddenContentFilterSet):
actor = utils.get_actor_from_request(self.request)
return queryset.playable_by(actor, value)
def filter_tags(self, queryset, name, value):
non_empty_tags = [v.lower() for v in value if v]
for tag in non_empty_tags:
queryset = queryset.filter(tagged_items__tag__name=tag).distinct()
return queryset
class UploadFilter(filters.FilterSet):
library = filters.CharFilter("library__uuid")
@ -101,6 +106,7 @@ class UploadFilter(filters.FilterSet):
class AlbumFilter(moderation_filters.HiddenContentFilterSet):
playable = filters.BooleanFilter(field_name="_", method="filter_playable")
q = fields.SearchFilter(search_fields=["title", "artist__name"])
tag = TAG_FILTER
class Meta:
model = models.Album

Wyświetl plik

@ -2,7 +2,6 @@ import base64
import datetime
import logging
import pendulum
import re
import mutagen._util
import mutagen.oggtheora
@ -12,6 +11,8 @@ import mutagen.flac
from rest_framework import serializers
from rest_framework.compat import Mapping
from funkwhale_api.tags import models as tags_models
logger = logging.getLogger(__name__)
NODEFAULT = object()
# default title used when imported tracks miss the `Album` tag, see #122
@ -491,9 +492,6 @@ class PermissiveDateField(serializers.CharField):
return None
TAG_REGEX = re.compile(r"^((\w+)([\d_]*))$")
def extract_tags_from_genre(string):
tags = []
delimiter = "@@@@@"
@ -511,7 +509,7 @@ def extract_tags_from_genre(string):
if not tag:
continue
final_tag = ""
if not TAG_REGEX.match(tag.replace(" ", "")):
if not tags_models.TAG_REGEX.match(tag.replace(" ", "")):
# the string contains some non words chars ($, €, etc.), right now
# we simply skip such tags
continue

Wyświetl plik

@ -1,5 +1,7 @@
from funkwhale_api.common import mutations
from funkwhale_api.federation import routes
from funkwhale_api.tags import models as tags_models
from funkwhale_api.tags import serializers as tags_serializers
from . import models
@ -12,17 +14,32 @@ def can_approve(obj, actor):
return obj.is_local and actor.user and actor.user.get_permissions()["library"]
class TagMutation(mutations.UpdateMutationSerializer):
tags = tags_serializers.TagsListField()
previous_state_handlers = {
"tags": lambda obj: list(
sorted(obj.tagged_items.values_list("tag__name", flat=True))
)
}
def update(self, instance, validated_data):
tags = validated_data.pop("tags", [])
r = super().update(instance, validated_data)
tags_models.set_tags(instance, *tags)
return r
@mutations.registry.connect(
"update",
models.Track,
perm_checkers={"suggest": can_suggest, "approve": can_approve},
)
class TrackMutationSerializer(mutations.UpdateMutationSerializer):
class TrackMutationSerializer(TagMutation):
serialized_relations = {"license": "code"}
class Meta:
model = models.Track
fields = ["license", "title", "position", "copyright"]
fields = ["license", "title", "position", "copyright", "tags"]
def post_apply(self, obj, validated_data):
routes.outbox.dispatch(
@ -35,10 +52,10 @@ class TrackMutationSerializer(mutations.UpdateMutationSerializer):
models.Artist,
perm_checkers={"suggest": can_suggest, "approve": can_approve},
)
class ArtistMutationSerializer(mutations.UpdateMutationSerializer):
class ArtistMutationSerializer(TagMutation):
class Meta:
model = models.Artist
fields = ["name"]
fields = ["name", "tags"]
def post_apply(self, obj, validated_data):
routes.outbox.dispatch(
@ -51,10 +68,10 @@ class ArtistMutationSerializer(mutations.UpdateMutationSerializer):
models.Album,
perm_checkers={"suggest": can_suggest, "approve": can_approve},
)
class AlbumMutationSerializer(mutations.UpdateMutationSerializer):
class AlbumMutationSerializer(TagMutation):
class Meta:
model = models.Album
fields = ["title", "release_date"]
fields = ["title", "release_date", "tags"]
def post_apply(self, obj, validated_data):
routes.outbox.dispatch(

Wyświetl plik

@ -67,10 +67,24 @@ class ArtistAlbumSerializer(serializers.ModelSerializer):
class ArtistWithAlbumsSerializer(serializers.ModelSerializer):
albums = ArtistAlbumSerializer(many=True, read_only=True)
tags = serializers.SerializerMethodField()
class Meta:
model = models.Artist
fields = ("id", "fid", "mbid", "name", "creation_date", "albums", "is_local")
fields = (
"id",
"fid",
"mbid",
"name",
"creation_date",
"albums",
"is_local",
"tags",
)
def get_tags(self, obj):
tagged_items = getattr(obj, "_prefetched_tagged_items", [])
return [ti.tag.name for ti in tagged_items]
class ArtistSimpleSerializer(serializers.ModelSerializer):
@ -124,6 +138,7 @@ class AlbumSerializer(serializers.ModelSerializer):
artist = ArtistSimpleSerializer(read_only=True)
cover = cover_field
is_playable = serializers.SerializerMethodField()
tags = serializers.SerializerMethodField()
class Meta:
model = models.Album
@ -139,6 +154,7 @@ class AlbumSerializer(serializers.ModelSerializer):
"creation_date",
"is_playable",
"is_local",
"tags",
)
def get_tracks(self, o):
@ -153,6 +169,10 @@ class AlbumSerializer(serializers.ModelSerializer):
except AttributeError:
return None
def get_tags(self, obj):
tagged_items = getattr(obj, "_prefetched_tagged_items", [])
return [ti.tag.name for ti in tagged_items]
class TrackAlbumSerializer(serializers.ModelSerializer):
artist = ArtistSimpleSerializer(read_only=True)
@ -192,6 +212,7 @@ class TrackSerializer(serializers.ModelSerializer):
album = TrackAlbumSerializer(read_only=True)
uploads = serializers.SerializerMethodField()
listen_url = serializers.SerializerMethodField()
tags = serializers.SerializerMethodField()
class Meta:
model = models.Track
@ -210,6 +231,7 @@ class TrackSerializer(serializers.ModelSerializer):
"copyright",
"license",
"is_local",
"tags",
)
def get_listen_url(self, obj):
@ -219,6 +241,10 @@ class TrackSerializer(serializers.ModelSerializer):
uploads = getattr(obj, "playable_uploads", [])
return TrackUploadSerializer(uploads, many=True).data
def get_tags(self, obj):
tagged_items = getattr(obj, "_prefetched_tagged_items", [])
return [ti.tag.name for ti in tagged_items]
@common_serializers.track_fields_for_update("name", "description", "privacy_level")
class LibraryForOwnerSerializer(serializers.ModelSerializer):

Wyświetl plik

@ -23,13 +23,19 @@ from funkwhale_api.federation import actors
from funkwhale_api.federation import api_serializers as federation_api_serializers
from funkwhale_api.federation import decorators as federation_decorators
from funkwhale_api.federation import routes
from funkwhale_api.tags.models import Tag
from funkwhale_api.tags.models import Tag, TaggedItem
from funkwhale_api.users.oauth import permissions as oauth_permissions
from . import filters, licenses, models, serializers, tasks, utils
logger = logging.getLogger(__name__)
TAG_PREFETCH = Prefetch(
"tagged_items",
queryset=TaggedItem.objects.all().select_related().order_by("tag__name"),
to_attr="_prefetched_tagged_items",
)
def get_libraries(filter_uploads):
def libraries(self, request, *args, **kwargs):
@ -71,7 +77,9 @@ class ArtistViewSet(common_views.SkipFilterForGetObject, viewsets.ReadOnlyModelV
albums = albums.annotate_playable_by_actor(
utils.get_actor_from_request(self.request)
)
return queryset.prefetch_related(Prefetch("albums", queryset=albums))
return queryset.prefetch_related(
Prefetch("albums", queryset=albums), TAG_PREFETCH
)
libraries = action(methods=["get"], detail=True)(
get_libraries(
@ -103,7 +111,9 @@ class AlbumViewSet(common_views.SkipFilterForGetObject, viewsets.ReadOnlyModelVi
.with_playable_uploads(utils.get_actor_from_request(self.request))
.order_for_album()
)
qs = queryset.prefetch_related(Prefetch("tracks", queryset=tracks))
qs = queryset.prefetch_related(
Prefetch("tracks", queryset=tracks), TAG_PREFETCH
)
return qs
libraries = action(methods=["get"], detail=True)(
@ -206,7 +216,7 @@ class TrackViewSet(common_views.SkipFilterForGetObject, viewsets.ReadOnlyModelVi
queryset = queryset.with_playable_uploads(
utils.get_actor_from_request(self.request)
)
return queryset
return queryset.prefetch_related(TAG_PREFETCH)
libraries = action(methods=["get"], detail=True)(
get_libraries(filter_uploads=lambda o, uploads: uploads.filter(track=o))

Wyświetl plik

@ -1,3 +1,6 @@
import re
from django.conf import settings
from django.contrib.contenttypes.fields import GenericForeignKey
from django.contrib.contenttypes.models import ContentType
from django.contrib.postgres.fields import CICharField
@ -8,6 +11,9 @@ from django.utils import timezone
from django.utils.translation import gettext_lazy as _
TAG_REGEX = re.compile(r"^((\w+)([\d_]*))$")
class Tag(models.Model):
name = CICharField(max_length=100, unique=True)
creation_date = models.DateTimeField(default=timezone.now)
@ -60,6 +66,9 @@ def add_tags(obj, *tags):
@transaction.atomic
def set_tags(obj, *tags):
# we ignore any extra tags if the length of the list is higher
# than our accepted size
tags = tags[: settings.TAGS_MAX_BY_OBJ]
tags = set(tags)
existing = set(
TaggedItem.objects.for_content_object(obj).values_list("tag__name", flat=True)

Wyświetl plik

@ -1,5 +1,7 @@
from rest_framework import serializers
from django.conf import settings
from . import models
@ -7,3 +9,26 @@ class TagSerializer(serializers.ModelSerializer):
class Meta:
model = models.Tag
fields = ["name", "creation_date"]
class TagNameField(serializers.CharField):
def to_internal_value(self, value):
value = super().to_internal_value(value)
if not models.TAG_REGEX.match(value):
raise serializers.ValidationError('Invalid tag "{}"'.format(value))
return value
class TagsListField(serializers.ListField):
def __init__(self, *args, **kwargs):
kwargs.setdefault("min_length", 0)
kwargs.setdefault("child", TagNameField())
super().__init__(*args, **kwargs)
def to_internal_value(self, value):
value = super().to_internal_value(value)
if not value:
return value
# we ignore any extra tags if the length of the list is higher
# than our accepted size
return value[: settings.TAGS_MAX_BY_OBJ]

Wyświetl plik

@ -51,7 +51,7 @@ def test_apply_update_mutation(factories, mutations_registry, mocker):
)
assert previous_state == get_update_previous_state.return_value
get_update_previous_state.assert_called_once_with(
user, "username", serialized_relations={}
user, "username", serialized_relations={}, handlers={}
)
user.refresh_from_db()

Wyświetl plik

@ -1,3 +1,5 @@
import pytest
from funkwhale_api.music import filters
from funkwhale_api.music import models
@ -54,28 +56,54 @@ def test_artist_filter_track_album_artist(factories, mocker, queryset_equal_list
assert filterset.qs == [hidden_track]
@pytest.mark.parametrize(
"factory_name, filterset_class",
[
("music.Track", filters.TrackFilter),
("music.Artist", filters.TrackFilter),
("music.Album", filters.TrackFilter),
],
)
def test_track_filter_tag_single(
factories, mocker, queryset_equal_list, anonymous_user
factory_name,
filterset_class,
factories,
mocker,
queryset_equal_list,
anonymous_user,
):
factories["music.Track"]()
factories[factory_name]()
# tag name partially match the query, so this shouldn't match
factories["music.Track"](set_tags=["TestTag1"])
tagged = factories["music.Track"](set_tags=["TestTag"])
qs = models.Track.objects.all()
filterset = filters.TrackFilter(
factories[factory_name](set_tags=["TestTag1"])
tagged = factories[factory_name](set_tags=["TestTag"])
qs = tagged.__class__.objects.all()
filterset = filterset_class(
{"tag": "testTaG"}, request=mocker.Mock(user=anonymous_user), queryset=qs
)
assert filterset.qs == [tagged]
@pytest.mark.parametrize(
"factory_name, filterset_class",
[
("music.Track", filters.TrackFilter),
("music.Artist", filters.ArtistFilter),
("music.Album", filters.AlbumFilter),
],
)
def test_track_filter_tag_multiple(
factories, mocker, queryset_equal_list, anonymous_user
factory_name,
filterset_class,
factories,
mocker,
queryset_equal_list,
anonymous_user,
):
factories["music.Track"](set_tags=["TestTag1"])
tagged = factories["music.Track"](set_tags=["TestTag1", "TestTag2"])
qs = models.Track.objects.all()
filterset = filters.TrackFilter(
factories[factory_name](set_tags=["TestTag1"])
tagged = factories[factory_name](set_tags=["TestTag1", "TestTag2"])
qs = tagged.__class__.objects.all()
filterset = filterset_class(
{"tag": ["testTaG1", "TestTag2"]},
request=mocker.Mock(user=anonymous_user),
queryset=qs,

Wyświetl plik

@ -2,6 +2,7 @@ import datetime
import pytest
from funkwhale_api.music import licenses
from funkwhale_api.tags import models as tags_models
@pytest.mark.parametrize(
@ -117,3 +118,25 @@ def test_track_mutation_apply_outbox(factories, mocker):
dispatch.assert_called_once_with(
{"type": "Update", "object": {"type": "Track"}}, context={"track": track}
)
@pytest.mark.parametrize("factory_name", ["music.Artist", "music.Album", "music.Track"])
def test_mutation_set_tags(factory_name, factories, now, mocker):
tags = ["tag1", "tag2"]
dispatch = mocker.patch("funkwhale_api.federation.routes.outbox.dispatch")
set_tags = mocker.spy(tags_models, "set_tags")
obj = factories[factory_name]()
assert obj.tagged_items.all().count() == 0
mutation = factories["common.Mutation"](
type="update", target=obj, payload={"tags": tags}
)
mutation.apply()
obj.refresh_from_db()
assert sorted(obj.tagged_items.all().values_list("tag__name", flat=True)) == tags
set_tags.assert_called_once_with(obj, *tags)
obj_type = factory_name.lstrip("music.")
dispatch.assert_called_once_with(
{"type": "Update", "object": {"type": obj_type}},
context={obj_type.lower(): obj},
)

Wyświetl plik

@ -69,6 +69,7 @@ def test_artist_with_albums_serializer(factories, to_api_date):
"is_local": artist.is_local,
"creation_date": to_api_date(artist.creation_date),
"albums": [serializers.ArtistAlbumSerializer(album).data],
"tags": [],
}
serializer = serializers.ArtistWithAlbumsSerializer(artist)
assert serializer.data == expected
@ -175,6 +176,7 @@ def test_album_serializer(factories, to_api_date):
"release_date": to_api_date(album.release_date),
"tracks": serializers.AlbumTrackSerializer([track2, track1], many=True).data,
"is_local": album.is_local,
"tags": [],
}
serializer = serializers.AlbumSerializer(album)
@ -202,6 +204,7 @@ def test_track_serializer(factories, to_api_date):
"license": upload.track.license.code,
"copyright": upload.track.copyright,
"is_local": upload.track.is_local,
"tags": [],
}
serializer = serializers.TrackSerializer(track)
assert serializer.data == expected

Wyświetl plik

@ -16,8 +16,11 @@ DATA_DIR = os.path.dirname(os.path.abspath(__file__))
def test_artist_list_serializer(api_request, factories, logged_in_api_client):
tags = ["tag1", "tag2"]
track = factories["music.Upload"](
library__privacy_level="everyone", import_status="finished"
library__privacy_level="everyone",
import_status="finished",
track__album__artist__set_tags=tags,
).track
artist = track.artist
request = api_request.get("/")
@ -27,8 +30,10 @@ def test_artist_list_serializer(api_request, factories, logged_in_api_client):
)
expected = {"count": 1, "next": None, "previous": None, "results": serializer.data}
for artist in serializer.data:
artist["tags"] = tags
for album in artist["albums"]:
album["is_playable"] = True
url = reverse("api:v1:artists-list")
response = logged_in_api_client.get(url)
@ -37,8 +42,11 @@ def test_artist_list_serializer(api_request, factories, logged_in_api_client):
def test_album_list_serializer(api_request, factories, logged_in_api_client):
tags = ["tag1", "tag2"]
track = factories["music.Upload"](
library__privacy_level="everyone", import_status="finished"
library__privacy_level="everyone",
import_status="finished",
track__album__set_tags=tags,
).track
album = track.album
request = api_request.get("/")
@ -47,6 +55,8 @@ def test_album_list_serializer(api_request, factories, logged_in_api_client):
qs, many=True, context={"request": request}
)
expected = {"count": 1, "next": None, "previous": None, "results": serializer.data}
for album in serializer.data:
album["tags"] = tags
url = reverse("api:v1:albums-list")
response = logged_in_api_client.get(url)
@ -55,8 +65,11 @@ def test_album_list_serializer(api_request, factories, logged_in_api_client):
def test_track_list_serializer(api_request, factories, logged_in_api_client):
tags = ["tag1", "tag2"]
track = factories["music.Upload"](
library__privacy_level="everyone", import_status="finished"
library__privacy_level="everyone",
import_status="finished",
track__set_tags=tags,
).track
request = api_request.get("/")
qs = track.__class__.objects.with_playable_uploads(None)
@ -64,6 +77,8 @@ def test_track_list_serializer(api_request, factories, logged_in_api_client):
qs, many=True, context={"request": request}
)
expected = {"count": 1, "next": None, "previous": None, "results": serializer.data}
for track in serializer.data:
track["tags"] = tags
url = reverse("api:v1:tracks-list")
response = logged_in_api_client.get(url)

Wyświetl plik

@ -53,6 +53,24 @@ def test_set_tags(factories, existing, given, expected):
assert match.content_object == obj
@pytest.mark.parametrize(
"max, tags, expected",
[
(5, ["hello", "world"], ["hello", "world"]),
# we truncate extra tags
(1, ["hello", "world"], ["hello"]),
(2, ["hello", "world", "foo"], ["hello", "world"]),
],
)
def test_set_tags_honor_TAGS_MAX_BY_OBJ(factories, max, tags, expected, settings):
settings.TAGS_MAX_BY_OBJ = max
obj = factories["music.Artist"]()
models.set_tags(obj, *tags)
assert sorted(obj.tagged_items.values_list("tag__name", flat=True)) == expected
@pytest.mark.parametrize("factory_name", ["music.Track", "music.Album", "music.Artist"])
def test_models_that_support_tags(factories, factory_name):
tags = ["tag1", "tag2"]

Wyświetl plik

@ -1,3 +1,5 @@
import pytest
from funkwhale_api.tags import serializers
@ -12,3 +14,33 @@ def test_tag_serializer(factories):
}
assert serializer.data == expected
@pytest.mark.parametrize(
"name",
[
"",
"invalid because spaces",
"invalid-because-dashes",
"invalidbecausenonbreakingspaces",
],
)
def test_tag_name_field_validation(name):
field = serializers.TagNameField()
with pytest.raises(serializers.serializers.ValidationError):
field.to_internal_value(name)
@pytest.mark.parametrize(
"max, tags, expected",
[
(5, ["hello", "world"], ["hello", "world"]),
# we truncate extra tags
(1, ["hello", "world"], ["hello"]),
(2, ["hello", "world", "foo"], ["hello", "world"]),
],
)
def test_tags_list_field_honor_TAGS_MAX_BY_OBJ(max, tags, expected, settings):
settings.TAGS_MAX_BY_OBJ = max
field = serializers.TagsListField()
assert field.to_internal_value(tags) == expected

Wyświetl plik

@ -75,6 +75,9 @@ http {
location /front-server/ {
proxy_pass http://funkwhale-front/;
}
location /sockjs-node/ {
proxy_pass http://funkwhale-front/sockjs-node/;
}
location / {
include /etc/nginx/funkwhale_proxy.conf;