diff --git a/api/funkwhale_api/common/models.py b/api/funkwhale_api/common/models.py index 1b9cc1e57..c277fb9df 100644 --- a/api/funkwhale_api/common/models.py +++ b/api/funkwhale_api/common/models.py @@ -3,6 +3,7 @@ import uuid from django.contrib.postgres.fields import JSONField from django.contrib.contenttypes.fields import GenericForeignKey from django.contrib.contenttypes.models import ContentType +from django.conf import settings from django.db import models, transaction from django.utils import timezone from django.urls import reverse @@ -10,6 +11,18 @@ from django.urls import reverse from funkwhale_api.federation import utils as federation_utils +class LocalFromFidQuerySet: + def local(self, include=True): + host = settings.FEDERATION_HOSTNAME + query = models.Q(fid__startswith="http://{}/".format(host)) | models.Q( + fid__startswith="https://{}/".format(host) + ) + if include: + return self.filter(query) + else: + return self.filter(~query) + + class MutationQuerySet(models.QuerySet): def get_for_target(self, target): content_type = ContentType.objects.get_for_model(target) diff --git a/api/funkwhale_api/federation/views.py b/api/funkwhale_api/federation/views.py index 13791ec21..665a08b17 100644 --- a/api/funkwhale_api/federation/views.py +++ b/api/funkwhale_api/federation/views.py @@ -7,6 +7,7 @@ from rest_framework.decorators import action from funkwhale_api.common import preferences from funkwhale_api.music import models as music_models +from funkwhale_api.music import utils as music_utils from . import activity, authentication, models, renderers, serializers, utils, webfinger @@ -202,9 +203,17 @@ class MusicUploadViewSet( authentication_classes = [authentication.SignatureAuthentication] permission_classes = [] renderer_classes = [renderers.ActivityPubRenderer] - queryset = music_models.Upload.objects.none() + queryset = music_models.Upload.objects.local().select_related( + "library__actor", "track__artist", "track__album__artist" + ) + serializer_class = serializers.UploadSerializer lookup_field = "uuid" + def get_queryset(self): + queryset = super().get_queryset() + actor = music_utils.get_actor_from_request(self.request) + return queryset.playable_by(actor) + class MusicArtistViewSet( FederationMixin, mixins.RetrieveModelMixin, viewsets.GenericViewSet @@ -212,7 +221,8 @@ class MusicArtistViewSet( authentication_classes = [authentication.SignatureAuthentication] permission_classes = [] renderer_classes = [renderers.ActivityPubRenderer] - queryset = music_models.Artist.objects.none() + queryset = music_models.Artist.objects.local() + serializer_class = serializers.ArtistSerializer lookup_field = "uuid" @@ -222,7 +232,8 @@ class MusicAlbumViewSet( authentication_classes = [authentication.SignatureAuthentication] permission_classes = [] renderer_classes = [renderers.ActivityPubRenderer] - queryset = music_models.Album.objects.none() + queryset = music_models.Album.objects.local().select_related("artist") + serializer_class = serializers.AlbumSerializer lookup_field = "uuid" @@ -232,5 +243,8 @@ class MusicTrackViewSet( authentication_classes = [authentication.SignatureAuthentication] permission_classes = [] renderer_classes = [renderers.ActivityPubRenderer] - queryset = music_models.Track.objects.none() + queryset = music_models.Track.objects.local().select_related( + "album__artist", "artist" + ) + serializer_class = serializers.TrackSerializer lookup_field = "uuid" diff --git a/api/funkwhale_api/music/models.py b/api/funkwhale_api/music/models.py index 4ba832717..e40483c80 100644 --- a/api/funkwhale_api/music/models.py +++ b/api/funkwhale_api/music/models.py @@ -24,6 +24,7 @@ from versatileimagefield.image_warmer import VersatileImageFieldWarmer from funkwhale_api import musicbrainz from funkwhale_api.common import fields +from funkwhale_api.common import models as common_models from funkwhale_api.common import session from funkwhale_api.common import utils as common_utils from funkwhale_api.federation import models as federation_models @@ -141,7 +142,7 @@ class License(models.Model): logger.warning("%s do not match any registered license", self.code) -class ArtistQuerySet(models.QuerySet): +class ArtistQuerySet(common_models.LocalFromFidQuerySet, models.QuerySet): def with_albums_count(self): return self.annotate(_albums_count=models.Count("albums")) @@ -215,7 +216,7 @@ def import_tracks(instance, cleaned_data, raw_data): importers.load(Track, track_cleaned_data, track_data, Track.import_hooks) -class AlbumQuerySet(models.QuerySet): +class AlbumQuerySet(common_models.LocalFromFidQuerySet, models.QuerySet): def with_tracks_count(self): return self.annotate(_tracks_count=models.Count("tracks")) @@ -416,7 +417,7 @@ class Lyrics(models.Model): ) -class TrackQuerySet(models.QuerySet): +class TrackQuerySet(common_models.LocalFromFidQuerySet, models.QuerySet): def for_nested_serialization(self): return self.select_related().select_related("album__artist", "artist") diff --git a/api/tests/federation/test_views.py b/api/tests/federation/test_views.py index 282ee16fe..a7d64366b 100644 --- a/api/tests/federation/test_views.py +++ b/api/tests/federation/test_views.py @@ -174,3 +174,75 @@ def test_music_library_retrieve_page_follow( response = api_client.get(url, {"page": 1}) assert response.status_code == expected + + +@pytest.mark.parametrize( + "factory, serializer_class, namespace", + [ + ("music.Artist", serializers.ArtistSerializer, "artists"), + ("music.Album", serializers.AlbumSerializer, "albums"), + ("music.Track", serializers.TrackSerializer, "tracks"), + ], +) +def test_music_local_entity_detail( + factories, api_client, factory, serializer_class, namespace, settings +): + obj = factories[factory](fid="http://{}/1".format(settings.FEDERATION_HOSTNAME)) + url = reverse( + "federation:music:{}-detail".format(namespace), kwargs={"uuid": obj.uuid} + ) + response = api_client.get(url) + + assert response.status_code == 200 + assert response.data == serializer_class(obj).data + + +@pytest.mark.parametrize( + "factory, namespace", + [("music.Artist", "artists"), ("music.Album", "albums"), ("music.Track", "tracks")], +) +def test_music_non_local_entity_detail( + factories, api_client, factory, namespace, settings +): + obj = factories[factory](fid="http://wrong-domain/1") + url = reverse( + "federation:music:{}-detail".format(namespace), kwargs={"uuid": obj.uuid} + ) + response = api_client.get(url) + + assert response.status_code == 404 + + +@pytest.mark.parametrize( + "privacy_level, expected", [("me", 404), ("instance", 404), ("everyone", 200)] +) +def test_music_upload_detail(factories, api_client, privacy_level, expected): + upload = factories["music.Upload"]( + library__privacy_level=privacy_level, + library__actor__local=True, + import_status="finished", + ) + url = reverse("federation:music:uploads-detail", kwargs={"uuid": upload.uuid}) + response = api_client.get(url) + + assert response.status_code == expected + if expected == 200: + assert response.data == serializers.UploadSerializer(upload).data + + +@pytest.mark.parametrize("privacy_level", ["me", "instance"]) +def test_music_upload_detail_private_approved_follow( + factories, api_client, authenticated_actor, privacy_level +): + upload = factories["music.Upload"]( + library__privacy_level=privacy_level, + library__actor__local=True, + import_status="finished", + ) + factories["federation.LibraryFollow"]( + actor=authenticated_actor, target=upload.library, approved=True + ) + url = reverse("federation:music:uploads-detail", kwargs={"uuid": upload.uuid}) + response = api_client.get(url) + + assert response.status_code == 200 diff --git a/api/tests/music/test_models.py b/api/tests/music/test_models.py index ab32579b7..cc73e85a3 100644 --- a/api/tests/music/test_models.py +++ b/api/tests/music/test_models.py @@ -522,3 +522,14 @@ def test_track_order_for_album(factories): t4 = factories["music.Track"](album=album, position=2, disc_number=2) assert list(models.Track.objects.order_for_album()) == [t1, t3, t2, t4] + + +@pytest.mark.parametrize("factory", ["music.Artist", "music.Album", "music.Track"]) +def test_queryset_local_entities(factories, settings, factory): + settings.FEDERATION_HOSTNAME = "test.com" + obj1 = factories[factory](fid="http://test.com/1") + obj2 = factories[factory](fid="https://test.com/2") + factories[factory](fid="https://test.coma/3") + factories[factory](fid="https://noope/3") + + assert list(obj1.__class__.objects.local().order_by("id")) == [obj1, obj2]