diff --git a/api/funkwhale_api/audio/serializers.py b/api/funkwhale_api/audio/serializers.py index a946df9a9..205ec383c 100644 --- a/api/funkwhale_api/audio/serializers.py +++ b/api/funkwhale_api/audio/serializers.py @@ -97,6 +97,15 @@ class ChannelSerializer(serializers.ModelSerializer): def get_artist(self, obj): return music_serializers.serialize_artist_simple(obj.artist) + def to_representation(self, obj): + data = super().to_representation(obj) + if self.context.get("subscriptions_count"): + data["subscriptions_count"] = self.get_subscriptions_count(obj) + return data + + def get_subscriptions_count(self, obj): + return obj.actor.received_follows.exclude(approved=False).count() + class SubscriptionSerializer(serializers.Serializer): approved = serializers.BooleanField(read_only=True) diff --git a/api/funkwhale_api/audio/views.py b/api/funkwhale_api/audio/views.py index ba9983672..5162730a6 100644 --- a/api/funkwhale_api/audio/views.py +++ b/api/funkwhale_api/audio/views.py @@ -92,6 +92,11 @@ class ChannelViewSet( request.user.actor.emitted_follows.filter(target=object.actor).delete() return response.Response(status=204) + def get_serializer_context(self): + context = super().get_serializer_context() + context["subscriptions_count"] = self.action in ["retrieve", "create", "update"] + return context + class SubscriptionsViewSet( ChannelsMixin, diff --git a/api/tests/audio/test_serializers.py b/api/tests/audio/test_serializers.py index b431e8e96..7f39bb338 100644 --- a/api/tests/audio/test_serializers.py +++ b/api/tests/audio/test_serializers.py @@ -90,6 +90,16 @@ def test_channel_serializer_representation(factories, to_api_date): assert serializers.ChannelSerializer(channel).data == expected +def test_channel_serializer_representation_subscriptions_count(factories, to_api_date): + channel = factories["audio.Channel"]() + factories["federation.Follow"](target=channel.actor) + factories["federation.Follow"](target=channel.actor, approved=False) + serializer = serializers.ChannelSerializer( + channel, context={"subscriptions_count": True} + ) + assert serializer.data["subscriptions_count"] == 1 + + def test_subscription_serializer(factories, to_api_date): subscription = factories["audio.Subscription"]() expected = { diff --git a/api/tests/audio/test_views.py b/api/tests/audio/test_views.py index 0724bfa98..935ee4342 100644 --- a/api/tests/audio/test_views.py +++ b/api/tests/audio/test_views.py @@ -41,7 +41,9 @@ def test_channel_create(logged_in_api_client): def test_channel_detail(factories, logged_in_api_client): channel = factories["audio.Channel"](artist__description=None) url = reverse("api:v1:channels-detail", kwargs={"uuid": channel.uuid}) - expected = serializers.ChannelSerializer(channel).data + expected = serializers.ChannelSerializer( + channel, context={"subscriptions_count": True} + ).data response = logged_in_api_client.get(url) assert response.status_code == 200