diff --git a/api/config/urls/api_v2.py b/api/config/urls/api_v2.py index 95c776a0c..d5e040337 100644 --- a/api/config/urls/api_v2.py +++ b/api/config/urls/api_v2.py @@ -10,6 +10,10 @@ v2_patterns += [ r"^instance/", include(("funkwhale_api.instance.urls", "instance"), namespace="instance"), ), + url( + r"^radios/", + include(("funkwhale_api.radios.urls_v2", "radios"), namespace="radios"), + ), ] urlpatterns = [url("", include((v2_patterns, "v2"), namespace="v2"))] diff --git a/api/funkwhale_api/radios/models.py b/api/funkwhale_api/radios/models.py index 9d8753608..fdd6cd481 100644 --- a/api/funkwhale_api/radios/models.py +++ b/api/funkwhale_api/radios/models.py @@ -54,10 +54,6 @@ class RadioSession(models.Model): CONFIG_VERSION = 0 config = JSONField(encoder=DjangoJSONEncoder, blank=True, null=True) - def save(self, **kwargs): - self.radio.clean(self) - super().save(**kwargs) - @property def next_position(self): next_position = 1 @@ -68,16 +64,24 @@ class RadioSession(models.Model): return next_position - def add(self, track): - new_session_track = RadioSessionTrack.objects.create( - track=track, session=self, position=self.next_position - ) + def add(self, tracks): + next_position = self.next_position + radio_session_tracks = [] + for i, track in enumerate(tracks): + radio_session_track = RadioSessionTrack( + track=track, session=self, position=next_position + i + ) + radio_session_tracks.append(radio_session_track) - return new_session_track + new_session_tracks = RadioSessionTrack.objects.bulk_create(radio_session_tracks) - @property - def radio(self): - from .registries import registry + return new_session_tracks + + def radio(self, api_version): + if api_version == 2: + from .registries_v2 import registry + else: + from .registries import registry return registry[self.radio_type](session=self) diff --git a/api/funkwhale_api/radios/radios.py b/api/funkwhale_api/radios/radios.py index 821a181cf..d80c01910 100644 --- a/api/funkwhale_api/radios/radios.py +++ b/api/funkwhale_api/radios/radios.py @@ -13,10 +13,9 @@ from funkwhale_api.federation import fields as federation_fields from funkwhale_api.federation import models as federation_models from funkwhale_api.moderation import filters as moderation_filters from funkwhale_api.music.models import Artist, Library, Track, Upload -from funkwhale_api.radios import lb_recommendations from funkwhale_api.tags.models import Tag -from . import filters, models +from . import filters, lb_recommendations, models from .registries import registry logger = logging.getLogger(__name__) @@ -63,11 +62,19 @@ class SessionRadio(SimpleRadio): return self.session def get_queryset(self, **kwargs): - qs = Track.objects.all() - if not self.session: - return qs - if not self.session.user: - return qs + if not self.session or not self.session.user: + return ( + Track.objects.all() + .with_playable_uploads(actor=None) + .select_related("artist", "album__artist", "attributed_to") + ) + else: + qs = ( + Track.objects.all() + .with_playable_uploads(self.session.user.actor) + .select_related("artist", "album__artist", "attributed_to") + ) + query = moderation_filters.get_filtered_content_query( config=moderation_filters.USER_FILTER_CONFIG["TRACK"], user=self.session.user, @@ -77,6 +84,16 @@ class SessionRadio(SimpleRadio): def get_queryset_kwargs(self): return {} + def filter_queryset(self, queryset): + return queryset + + def filter_from_session(self, queryset): + already_played = self.session.session_tracks.all().values_list( + "track", flat=True + ) + queryset = queryset.exclude(pk__in=already_played) + return queryset + def get_choices(self, **kwargs): kwargs.update(self.get_queryset_kwargs()) queryset = self.get_queryset(**kwargs) @@ -89,16 +106,6 @@ class SessionRadio(SimpleRadio): queryset = self.filter_queryset(queryset) return queryset - def filter_queryset(self, queryset): - return queryset - - def filter_from_session(self, queryset): - already_played = self.session.session_tracks.all().values_list( - "track", flat=True - ) - queryset = queryset.exclude(pk__in=already_played) - return queryset - def pick(self, **kwargs): return self.pick_many(quantity=1, **kwargs)[0] @@ -106,8 +113,7 @@ class SessionRadio(SimpleRadio): choices = self.get_choices(**kwargs) picked_choices = super().pick_many(choices=choices, quantity=quantity) if self.session: - for choice in picked_choices: - self.session.add(choice) + self.session.add(picked_choices) return picked_choices def validate_session(self, data, **context): @@ -191,7 +197,9 @@ class CustomMultiple(SessionRadio): def validate_session(self, data, **context): data = super().validate_session(data, **context) - if data.get("config") is None: + try: + data["config"] is not None + except KeyError: raise serializers.ValidationError( "You must provide a configuration for this radio" ) diff --git a/api/funkwhale_api/radios/radios_v2.py b/api/funkwhale_api/radios/radios_v2.py new file mode 100644 index 000000000..dc7290980 --- /dev/null +++ b/api/funkwhale_api/radios/radios_v2.py @@ -0,0 +1,510 @@ +import datetime +import json +import logging +import pickle +import random +from typing import List, Optional, Tuple + +from django.core.cache import cache +from django.core.exceptions import ValidationError +from django.db import connection +from django.db.models import Q +from rest_framework import serializers + +from funkwhale_api.federation import fields as federation_fields +from funkwhale_api.federation import models as federation_models +from funkwhale_api.moderation import filters as moderation_filters +from funkwhale_api.music.models import Artist, Library, Track, Upload +from funkwhale_api.tags.models import Tag + +from . import filters, lb_recommendations, models +from .registries_v2 import registry + +logger = logging.getLogger(__name__) + + +class SimpleRadio: + related_object_field = None + + def clean(self, instance): + return + + def weighted_pick( + self, + choices: List[Tuple[int, int]], + previous_choices: Optional[List[int]] = None, + ) -> int: + total = sum(weight for c, weight in choices) + r = random.uniform(0, total) + upto = 0 + for choice, weight in choices: + if upto + weight >= r: + return choice + upto += weight + + +class SessionRadio(SimpleRadio): + def __init__(self, session=None): + self.session = session + + def start_session(self, user, **kwargs): + self.session = models.RadioSession.objects.create( + user=user, radio_type=self.radio_type, **kwargs + ) + return self.session + + def get_queryset(self, **kwargs): + actor = None + try: + actor = self.session.user.actor + except KeyError: + pass # Maybe logging would be helpful + + qs = ( + Track.objects.all() + .with_playable_uploads(actor=actor) + .select_related("artist", "album__artist", "attributed_to") + ) + + query = moderation_filters.get_filtered_content_query( + config=moderation_filters.USER_FILTER_CONFIG["TRACK"], + user=self.session.user, + ) + return qs.exclude(query) + + def get_queryset_kwargs(self): + return {} + + def filter_queryset(self, queryset): + return queryset + + def filter_from_session(self, queryset): + already_played = self.session.session_tracks.all().values_list( + "track", flat=True + ) + queryset = queryset.exclude(pk__in=already_played) + return queryset + + def cache_batch_radio_track(self, **kwargs): + BATCH_SIZE = 100 + # get cached RadioTracks if any + try: + cached_evaluated_radio_tracks = pickle.loads( + cache.get(f"radiotracks{self.session.id}") + ) + except TypeError: + cached_evaluated_radio_tracks = None + + # get the queryset and apply filters + kwargs.update(self.get_queryset_kwargs()) + queryset = self.get_queryset(**kwargs) + queryset = self.filter_from_session(queryset) + + if kwargs["filter_playable"] is True: + queryset = queryset.playable_by( + self.session.user.actor if self.session.user else None + ) + queryset = self.filter_queryset(queryset) + + # select a random batch of the qs + sliced_queryset = queryset.order_by("?")[:BATCH_SIZE] + if len(sliced_queryset) <= 0 and not cached_evaluated_radio_tracks: + raise ValueError("No more radio candidates") + + # create the radio session tracks into db in bulk + self.session.add(sliced_queryset) + + # evaluate the queryset to save it in cache + radio_tracks = list(sliced_queryset) + + if cached_evaluated_radio_tracks is not None: + radio_tracks.extend(cached_evaluated_radio_tracks) + logger.info( + f"Setting redis cache for radio generation with radio id {self.session.id}" + ) + cache.set(f"radiotracks{self.session.id}", pickle.dumps(radio_tracks), 3600) + cache.set(f"radioqueryset{self.session.id}", sliced_queryset, 3600) + + return sliced_queryset + + def get_choices(self, quantity, **kwargs): + if cache.get(f"radiotracks{self.session.id}"): + cached_radio_tracks = pickle.loads( + cache.get(f"radiotracks{self.session.id}") + ) + logger.info("Using redis cache for radio generation") + radio_tracks = cached_radio_tracks + if len(radio_tracks) < quantity: + logger.info( + "Not enough radio tracks in cache. Trying to generate new cache" + ) + sliced_queryset = self.cache_batch_radio_track(**kwargs) + sliced_queryset = cache.get(f"radioqueryset{self.session.id}") + else: + sliced_queryset = self.cache_batch_radio_track(**kwargs) + + return sliced_queryset[:quantity] + + def pick_many(self, quantity, **kwargs): + if self.session: + sliced_queryset = self.get_choices(quantity=quantity, **kwargs) + else: + logger.info( + "No radio session. Can't track user playback. Won't cache queryset results" + ) + sliced_queryset = self.get_choices(quantity=quantity, **kwargs) + + return sliced_queryset + + def validate_session(self, data, **context): + return data + + +@registry.register(name="random") +class RandomRadio(SessionRadio): + def get_queryset(self, **kwargs): + qs = super().get_queryset(**kwargs) + return qs.filter(artist__content_category="music").order_by("?") + + +@registry.register(name="random_library") +class RandomLibraryRadio(SessionRadio): + def get_queryset(self, **kwargs): + qs = super().get_queryset(**kwargs) + tracks_ids = self.session.user.actor.attributed_tracks.all().values_list( + "id", flat=True + ) + query = Q(artist__content_category="music") & Q(pk__in=tracks_ids) + return qs.filter(query).order_by("?") + + +@registry.register(name="favorites") +class FavoritesRadio(SessionRadio): + def get_queryset_kwargs(self): + kwargs = super().get_queryset_kwargs() + if self.session: + kwargs["user"] = self.session.user + return kwargs + + def get_queryset(self, **kwargs): + qs = super().get_queryset(**kwargs) + track_ids = kwargs["user"].track_favorites.all().values_list("track", flat=True) + return qs.filter(pk__in=track_ids, artist__content_category="music") + + +@registry.register(name="custom") +class CustomRadio(SessionRadio): + def get_queryset_kwargs(self): + kwargs = super().get_queryset_kwargs() + kwargs["user"] = self.session.user + kwargs["custom_radio"] = self.session.custom_radio + return kwargs + + def get_queryset(self, **kwargs): + qs = super().get_queryset(**kwargs) + return filters.run(kwargs["custom_radio"].config, candidates=qs) + + def validate_session(self, data, **context): + data = super().validate_session(data, **context) + try: + user = data["user"] + except KeyError: + user = context.get("user") + try: + assert data["custom_radio"].user == user or data["custom_radio"].is_public + except KeyError: + raise serializers.ValidationError("You must provide a custom radio") + except AssertionError: + raise serializers.ValidationError("You don't have access to this radio") + return data + + +@registry.register(name="custom_multiple") +class CustomMultiple(SessionRadio): + """ + Receive a vuejs generated config and use it to launch a radio session + """ + + config = serializers.JSONField(required=True) + + def get_config(self, data): + return data["config"] + + def get_queryset_kwargs(self): + kwargs = super().get_queryset_kwargs() + kwargs["config"] = self.session.config + return kwargs + + def validate_session(self, data, **context): + data = super().validate_session(data, **context) + try: + data["config"] is not None + except KeyError: + raise serializers.ValidationError( + "You must provide a configuration for this radio" + ) + return data + + def get_queryset(self, **kwargs): + qs = super().get_queryset(**kwargs) + return filters.run([kwargs["config"]], candidates=qs) + + +class RelatedObjectRadio(SessionRadio): + """Abstract radio related to an object (tag, artist, user...)""" + + related_object_field = serializers.IntegerField(required=True) + + def clean(self, instance): + super().clean(instance) + if not instance.related_object: + raise ValidationError( + "Cannot start RelatedObjectRadio without related object" + ) + if not isinstance(instance.related_object, self.model): + raise ValidationError("Trying to start radio with bad related object") + + def get_related_object(self, pk): + return self.model.objects.get(pk=pk) + + +@registry.register(name="tag") +class TagRadio(RelatedObjectRadio): + model = Tag + related_object_field = serializers.CharField(required=True) + + def get_related_object(self, name): + return self.model.objects.get(name=name) + + def get_queryset(self, **kwargs): + qs = super().get_queryset(**kwargs) + query = ( + Q(tagged_items__tag=self.session.related_object) + | Q(artist__tagged_items__tag=self.session.related_object) + | Q(album__tagged_items__tag=self.session.related_object) + ) + return qs.filter(query) + + def get_related_object_id_repr(self, obj): + return obj.name + + +def weighted_choice(choices): + total = sum(w for c, w in choices) + r = random.uniform(0, total) + upto = 0 + for c, w in choices: + if upto + w >= r: + return c + upto += w + assert False, "Shouldn't get here" + + +class NextNotFound(Exception): + pass + + +@registry.register(name="similar") +class SimilarRadio(RelatedObjectRadio): + model = Track + + def filter_queryset(self, queryset): + queryset = super().filter_queryset(queryset) + seeds = list( + self.session.session_tracks.all() + .values_list("track_id", flat=True) + .order_by("-id")[:3] + ) + [self.session.related_object.pk] + for seed in seeds: + try: + return queryset.filter(pk=self.find_next_id(queryset, seed)) + except NextNotFound: + continue + + return queryset.none() + + def find_next_id(self, queryset, seed): + with connection.cursor() as cursor: + query = """ + SELECT next, count(next) AS c + FROM ( + SELECT + track_id, + creation_date, + LEAD(track_id) OVER ( + PARTITION by user_id order by creation_date asc + ) AS next + FROM history_listening + INNER JOIN users_user ON (users_user.id = user_id) + WHERE users_user.privacy_level = 'instance' OR users_user.privacy_level = 'everyone' OR user_id = %s + ORDER BY creation_date ASC + ) t WHERE track_id = %s AND next != %s GROUP BY next ORDER BY c DESC; + """ + cursor.execute(query, [self.session.user_id, seed, seed]) + next_candidates = list(cursor.fetchall()) + + if not next_candidates: + raise NextNotFound() + + matching_tracks = list( + queryset.filter(pk__in=[c[0] for c in next_candidates]).values_list( + "id", flat=True + ) + ) + next_candidates = [n for n in next_candidates if n[0] in matching_tracks] + if not next_candidates: + raise NextNotFound() + return random.choice([c[0] for c in next_candidates]) + + +@registry.register(name="artist") +class ArtistRadio(RelatedObjectRadio): + model = Artist + + def get_queryset(self, **kwargs): + qs = super().get_queryset(**kwargs) + return qs.filter(artist=self.session.related_object) + + +@registry.register(name="less-listened") +class LessListenedRadio(SessionRadio): + def clean(self, instance): + instance.related_object = instance.user + super().clean(instance) + + def get_queryset(self, **kwargs): + qs = super().get_queryset(**kwargs) + listened = self.session.user.listenings.all().values_list("track", flat=True) + return ( + qs.filter(artist__content_category="music") + .exclude(pk__in=listened) + .order_by("?") + ) + + +@registry.register(name="less-listened_library") +class LessListenedLibraryRadio(SessionRadio): + def clean(self, instance): + instance.related_object = instance.user + super().clean(instance) + + def get_queryset(self, **kwargs): + qs = super().get_queryset(**kwargs) + listened = self.session.user.listenings.all().values_list("track", flat=True) + tracks_ids = self.session.user.actor.attributed_tracks.all().values_list( + "id", flat=True + ) + query = Q(artist__content_category="music") & Q(pk__in=tracks_ids) + return qs.filter(query).exclude(pk__in=listened).order_by("?") + + +@registry.register(name="actor-content") +class ActorContentRadio(RelatedObjectRadio): + """ + Play content from given actor libraries + """ + + model = federation_models.Actor + related_object_field = federation_fields.ActorRelatedField(required=True) + + def get_related_object(self, value): + return value + + def get_queryset(self, **kwargs): + qs = super().get_queryset(**kwargs) + actor_uploads = Upload.objects.filter( + library__actor=self.session.related_object, + ) + return qs.filter(pk__in=actor_uploads.values("track")) + + def get_related_object_id_repr(self, obj): + return obj.full_username + + +@registry.register(name="library") +class LibraryRadio(RelatedObjectRadio): + """ + Play content from a given library + """ + + model = Library + related_object_field = serializers.UUIDField(required=True) + + def get_related_object(self, value): + return Library.objects.get(uuid=value) + + def get_queryset(self, **kwargs): + qs = super().get_queryset(**kwargs) + actor_uploads = Upload.objects.filter( + library=self.session.related_object, + ) + return qs.filter(pk__in=actor_uploads.values("track")) + + def get_related_object_id_repr(self, obj): + return obj.uuid + + +@registry.register(name="recently-added") +class RecentlyAdded(SessionRadio): + def get_queryset(self, **kwargs): + date = datetime.date.today() - datetime.timedelta(days=30) + qs = super().get_queryset(**kwargs) + return qs.filter( + Q(artist__content_category="music"), + Q(creation_date__gt=date), + ) + + +# Use this to experiment on the custom multiple radio with troi +@registry.register(name="troi") +class Troi(SessionRadio): + """ + Receive a vuejs generated config and use it to launch a troi radio session. + The config data should follow : + {"patch": "troi_patch_name", "troi_arg1":"troi_arg_1", "troi_arg2": ...} + Validation of the config (args) is done by troi during track fetch. + Funkwhale only checks if the patch is implemented + """ + + config = serializers.JSONField(required=True) + + def append_lb_config(self, data): + if self.session.user.settings is None: + logger.warning( + "No lb_user_name set in user settings. Some troi patches will fail" + ) + return data + elif self.session.user.settings.get("lb_user_name") is None: + logger.warning( + "No lb_user_name set in user settings. Some troi patches will fail" + ) + else: + data["user_name"] = self.session.user.settings["lb_user_name"] + + if self.session.user.settings.get("lb_user_token") is None: + logger.warning( + "No lb_user_token set in user settings. Some troi patch will fail" + ) + else: + data["user_token"] = self.session.user.settings["lb_user_token"] + + return data + + def get_queryset_kwargs(self): + kwargs = super().get_queryset_kwargs() + kwargs["config"] = self.session.config + return kwargs + + def validate_session(self, data, **context): + data = super().validate_session(data, **context) + if data.get("config") is None: + raise serializers.ValidationError( + "You must provide a configuration for this radio" + ) + return data + + def get_queryset(self, **kwargs): + qs = super().get_queryset(**kwargs) + config = self.append_lb_config(json.loads(kwargs["config"])) + + return lb_recommendations.run(config, candidates=qs) diff --git a/api/funkwhale_api/radios/registries_v2.py b/api/funkwhale_api/radios/registries_v2.py new file mode 100644 index 000000000..6cd16f206 --- /dev/null +++ b/api/funkwhale_api/radios/registries_v2.py @@ -0,0 +1,10 @@ +import persisting_theory + + +class RadioRegistry_v2(persisting_theory.Registry): + def prepare_name(self, data, name=None): + setattr(data, "radio_type", name) + return name + + +registry = RadioRegistry_v2() diff --git a/api/funkwhale_api/radios/serializers.py b/api/funkwhale_api/radios/serializers.py index 76e847d9e..16886f818 100644 --- a/api/funkwhale_api/radios/serializers.py +++ b/api/funkwhale_api/radios/serializers.py @@ -40,9 +40,11 @@ class RadioSerializer(serializers.ModelSerializer): class RadioSessionTrackSerializerCreate(serializers.ModelSerializer): + count = serializers.IntegerField(required=False, allow_null=True) + class Meta: model = models.RadioSessionTrack - fields = ("session",) + fields = ("session", "count") class RadioSessionTrackSerializer(serializers.ModelSerializer): diff --git a/api/funkwhale_api/radios/urls.py b/api/funkwhale_api/radios/urls.py index 4890b953f..7f1e3864b 100644 --- a/api/funkwhale_api/radios/urls.py +++ b/api/funkwhale_api/radios/urls.py @@ -5,7 +5,7 @@ from . import views router = routers.OptionalSlashRouter() router.register(r"sessions", views.RadioSessionViewSet, "sessions") router.register(r"radios", views.RadioViewSet, "radios") -router.register(r"tracks", views.RadioSessionTrackViewSet, "tracks") +router.register(r"tracks", views.V1_RadioSessionTrackViewSet, "tracks") urlpatterns = router.urls diff --git a/api/funkwhale_api/radios/urls_v2.py b/api/funkwhale_api/radios/urls_v2.py new file mode 100644 index 000000000..bac76f998 --- /dev/null +++ b/api/funkwhale_api/radios/urls_v2.py @@ -0,0 +1,10 @@ +from funkwhale_api.common import routers + +from . import views + +router = routers.OptionalSlashRouter() + +router.register(r"sessions", views.V2_RadioSessionViewSet, "sessions") + + +urlpatterns = router.urls diff --git a/api/funkwhale_api/radios/views.py b/api/funkwhale_api/radios/views.py index adf2fe464..07138ee35 100644 --- a/api/funkwhale_api/radios/views.py +++ b/api/funkwhale_api/radios/views.py @@ -1,3 +1,6 @@ +import pickle + +from django.core.cache import cache from django.db.models import Q from drf_spectacular.utils import extend_schema from rest_framework import mixins, status, viewsets @@ -121,7 +124,7 @@ class RadioSessionViewSet( return context -class RadioSessionTrackViewSet(mixins.CreateModelMixin, viewsets.GenericViewSet): +class V1_RadioSessionTrackViewSet(mixins.CreateModelMixin, viewsets.GenericViewSet): serializer_class = serializers.RadioSessionTrackSerializer queryset = models.RadioSessionTrack.objects.all() permission_classes = [] @@ -133,21 +136,19 @@ class RadioSessionTrackViewSet(mixins.CreateModelMixin, viewsets.GenericViewSet) session = serializer.validated_data["session"] if not request.user.is_authenticated and not request.session.session_key: self.request.session.create() - try: - assert (request.user == session.user) or ( - request.session.session_key == session.session_key - and session.session_key - ) - except AssertionError: + if not request.user == session.user or ( + not request.session.session_key == session.session_key + and not session.session_key + ): return Response(status=status.HTTP_403_FORBIDDEN) + try: - session.radio.pick() + session.radio(api_version=1).pick() except ValueError: return Response( "Radio doesn't have more candidates", status=status.HTTP_404_NOT_FOUND ) session_track = session.session_tracks.all().latest("id") - # self.perform_create(serializer) # dirty override here, since we use a different serializer for creation and detail serializer = self.serializer_class( instance=session_track, context=self.get_serializer_context() @@ -161,3 +162,99 @@ class RadioSessionTrackViewSet(mixins.CreateModelMixin, viewsets.GenericViewSet) if self.action == "create": return serializers.RadioSessionTrackSerializerCreate return super().get_serializer_class(*args, **kwargs) + + +class V2_RadioSessionViewSet( + mixins.CreateModelMixin, mixins.RetrieveModelMixin, viewsets.GenericViewSet +): + """Returns a list of RadioSessions""" + + serializer_class = serializers.RadioSessionSerializer + queryset = models.RadioSession.objects.all() + permission_classes = [] + + @action(detail=True, serializer_class=serializers.RadioSessionTrackSerializerCreate) + def tracks(self, request, pk, *args, **kwargs): + data = {"session": pk} + data["count"] = ( + request.query_params["count"] + if "count" in request.query_params.keys() + else 1 + ) + serializer = serializers.RadioSessionTrackSerializerCreate(data=data) + serializer.is_valid(raise_exception=True) + session = serializer.validated_data["session"] + + count = int(data["count"]) + # this is used for test purpose. + filter_playable = ( + request.query_params["filter_playable"] + if "filter_playable" in request.query_params.keys() + else True + ) + if not request.user.is_authenticated and not request.session.session_key: + self.request.session.create() + + if not request.user == session.user or ( + not request.session.session_key == session.session_key + and not session.session_key + ): + return Response(status=status.HTTP_403_FORBIDDEN) + try: + from . import radios_v2 # noqa + + session.radio(api_version=2).pick_many( + count, filter_playable=filter_playable + ) + except ValueError: + return Response( + "Radio doesn't have more candidates", status=status.HTTP_404_NOT_FOUND + ) + + # dirty override here, since we use a different serializer for creation and detail + evaluated_radio_tracks = pickle.loads(cache.get(f"radiotracks{session.id}")) + batch = evaluated_radio_tracks[:count] + serializer = TrackSerializer( + data=batch, + many="true", + ) + serializer.is_valid() + + # delete the tracks we sent from the cache + new_cached_radiotracks = evaluated_radio_tracks[count:] + cache.set(f"radiotracks{session.id}", pickle.dumps(new_cached_radiotracks)) + + return Response( + serializer.data, + status=status.HTTP_201_CREATED, + ) + + def get_queryset(self): + queryset = super().get_queryset() + if self.request.user.is_authenticated: + return queryset.filter( + Q(user=self.request.user) + | Q(session_key=self.request.session.session_key) + ) + + return queryset.filter(session_key=self.request.session.session_key).exclude( + session_key=None + ) + + def perform_create(self, serializer): + if ( + not self.request.user.is_authenticated + and not self.request.session.session_key + ): + self.request.session.create() + return serializer.save( + user=self.request.user if self.request.user.is_authenticated else None, + session_key=self.request.session.session_key, + ) + + def get_serializer_context(self): + context = super().get_serializer_context() + context["user"] = ( + self.request.user if self.request.user.is_authenticated else None + ) + return context diff --git a/api/tests/radios/test_radios.py b/api/tests/radios/test_radios.py index 1e9c02321..53c604681 100644 --- a/api/tests/radios/test_radios.py +++ b/api/tests/radios/test_radios.py @@ -2,8 +2,8 @@ import json import random import pytest -from django.core.exceptions import ValidationError from django.urls import reverse +from rest_framework.exceptions import ValidationError from funkwhale_api.favorites.models import TrackFavorite from funkwhale_api.radios import models, radios, serializers @@ -98,7 +98,7 @@ def test_can_get_choices_for_custom_radio(factories): session = factories["radios.CustomRadioSession"]( custom_radio__config=[{"type": "artist", "ids": [artist.pk]}] ) - choices = session.radio.get_choices(filter_playable=False) + choices = session.radio(api_version=1).get_choices(filter_playable=False) expected = [t.pk for t in tracks] assert list(choices.values_list("id", flat=True)) == expected @@ -191,16 +191,17 @@ def test_can_get_track_for_session_from_api(factories, logged_in_api_client): def test_related_object_radio_validate_related_object(factories): - user = factories["users.User"]() # cannot start without related object - radio = radios.ArtistRadio() + radio = {"radio_type": "tag"} + serializer = serializers.RadioSessionSerializer() with pytest.raises(ValidationError): - radio.start_session(user) + serializer.validate(data=radio) # cannot start with bad related object type - radio = radios.ArtistRadio() + radio = {"radio_type": "tag", "related_object": "whatever"} + serializer = serializers.RadioSessionSerializer() with pytest.raises(ValidationError): - radio.start_session(user, related_object=user) + serializer.validate(data=radio) def test_can_start_artist_radio(factories): @@ -391,7 +392,7 @@ def test_get_choices_for_custom_radio_exclude_artist(factories): {"type": "artist", "ids": [excluded_artist.pk], "not": True}, ] ) - choices = session.radio.get_choices(filter_playable=False) + choices = session.radio(api_version=1).get_choices(filter_playable=False) expected = [u.track.pk for u in included_uploads] assert list(choices.values_list("id", flat=True)) == expected @@ -409,7 +410,7 @@ def test_get_choices_for_custom_radio_exclude_tag(factories): {"type": "tag", "names": ["rock"], "not": True}, ] ) - choices = session.radio.get_choices(filter_playable=False) + choices = session.radio(api_version=1).get_choices(filter_playable=False) expected = [u.track.pk for u in included_uploads] assert list(choices.values_list("id", flat=True)) == expected @@ -429,28 +430,3 @@ def test_can_start_custom_multiple_radio_from_api(api_client, factories): format="json", ) assert response.status_code == 201 - - -def test_can_start_periodic_jams_troi_radio_from_api(api_client, factories): - factories["music.Track"].create_batch(5) - url = reverse("api:v1:radios:sessions-list") - config = {"patch": "periodic-jams", "type": "daily-jams"} - response = api_client.post( - url, - {"radio_type": "troi", "config": config}, - format="json", - ) - assert response.status_code == 201 - - -# to do : send error to api ? -def test_can_catch_troi_radio_error(api_client, factories): - factories["music.Track"].create_batch(5) - url = reverse("api:v1:radios:sessions-list") - config = {"patch": "periodic-jams", "type": "not_existing_type"} - response = api_client.post( - url, - {"radio_type": "troi", "config": config}, - format="json", - ) - assert response.status_code == 201 diff --git a/api/tests/radios/test_radios_v2.py b/api/tests/radios/test_radios_v2.py new file mode 100644 index 000000000..85fc15db6 --- /dev/null +++ b/api/tests/radios/test_radios_v2.py @@ -0,0 +1,144 @@ +import json +import logging +import pickle +import random + +from django.core.cache import cache +from django.urls import reverse + +from funkwhale_api.favorites.models import TrackFavorite +from funkwhale_api.radios import models, radios_v2 + + +def test_can_get_track_for_session_from_api_v2(factories, logged_in_api_client): + actor = logged_in_api_client.user.create_actor() + track = factories["music.Upload"]( + library__actor=actor, import_status="finished" + ).track + url = reverse("api:v2:radios:sessions-list") + response = logged_in_api_client.post(url, {"radio_type": "random"}) + session = models.RadioSession.objects.latest("id") + + url = reverse("api:v2:radios:sessions-tracks", kwargs={"pk": session.pk}) + response = logged_in_api_client.get(url, {"session": session.pk}) + data = json.loads(response.content.decode("utf-8")) + + assert data[0]["id"] == track.pk + + next_track = factories["music.Upload"]( + library__actor=actor, import_status="finished" + ).track + response = logged_in_api_client.get(url, {"session": session.pk}) + data = json.loads(response.content.decode("utf-8")) + + assert data[0]["id"] == next_track.id + + +def test_can_use_radio_session_to_filter_choices_v2(factories): + factories["music.Upload"].create_batch(10) + user = factories["users.User"]() + radio = radios_v2.RandomRadio() + session = radio.start_session(user) + + radio.pick_many(quantity=10, filter_playable=False) + + # ensure 10 different tracks have been suggested + tracks_id = [ + session_track.track.pk for session_track in session.session_tracks.all() + ] + assert len(set(tracks_id)) == 10 + + +def test_session_radio_excludes_previous_picks_v2(factories, logged_in_api_client): + tracks = factories["music.Track"].create_batch(5) + url = reverse("api:v2:radios:sessions-list") + response = logged_in_api_client.post(url, {"radio_type": "random"}) + session = models.RadioSession.objects.latest("id") + url = reverse("api:v2:radios:sessions-tracks", kwargs={"pk": session.pk}) + + previous_choices = [] + + for i in range(5): + response = logged_in_api_client.get( + url, {"session": session.pk, "filter_playable": False} + ) + pick = json.loads(response.content.decode("utf-8")) + assert pick[0]["title"] not in previous_choices + assert pick[0]["title"] in [t.title for t in tracks] + previous_choices.append(pick[0]["title"]) + + response = logged_in_api_client.get(url, {"session": session.pk}) + assert ( + json.loads(response.content.decode("utf-8")) + == "Radio doesn't have more candidates" + ) + + +def test_can_get_choices_for_favorites_radio_v2(factories): + files = factories["music.Upload"].create_batch(10) + tracks = [f.track for f in files] + user = factories["users.User"]() + for i in range(5): + TrackFavorite.add(track=random.choice(tracks), user=user) + + radio = radios_v2.FavoritesRadio() + session = radio.start_session(user=user) + choices = session.radio(api_version=2).get_choices( + quantity=100, filter_playable=False + ) + + assert len(choices) == user.track_favorites.all().count() + + for favorite in user.track_favorites.all(): + assert favorite.track in choices + + +def test_can_get_choices_for_custom_radio_v2(factories): + artist = factories["music.Artist"]() + files = factories["music.Upload"].create_batch(5, track__artist=artist) + tracks = [f.track for f in files] + factories["music.Upload"].create_batch(5) + + session = factories["radios.CustomRadioSession"]( + custom_radio__config=[{"type": "artist", "ids": [artist.pk]}] + ) + choices = session.radio(api_version=2).get_choices( + quantity=1, filter_playable=False + ) + + expected = [t.pk for t in tracks] + for t in choices: + assert t.id in expected + + +def test_can_cache_radio_track(factories): + uploads = factories["music.Track"].create_batch(10) + user = factories["users.User"]() + radio = radios_v2.RandomRadio() + session = radio.start_session(user) + picked = session.radio(api_version=2).pick_many(quantity=1, filter_playable=False) + assert len(picked) == 1 + for t in pickle.loads(cache.get(f"radiotracks{session.id}")): + assert t in uploads + + +def test_regenerate_cache_if_not_enought_tracks_in_it( + factories, caplog, logged_in_api_client +): + logger = logging.getLogger("funkwhale_api.radios.radios_v2") + caplog.set_level(logging.INFO) + logger.addHandler(caplog.handler) + + factories["music.Track"].create_batch(10) + factories["users.User"]() + url = reverse("api:v2:radios:sessions-list") + response = logged_in_api_client.post(url, {"radio_type": "random"}) + session = models.RadioSession.objects.latest("id") + url = reverse("api:v2:radios:sessions-tracks", kwargs={"pk": session.pk}) + logged_in_api_client.get(url, {"count": 9, "filter_playable": False}) + response = logged_in_api_client.get(url, {"count": 10, "filter_playable": False}) + pick = json.loads(response.content.decode("utf-8")) + assert ( + "Not enough radio tracks in cache. Trying to generate new cache" in caplog.text + ) + assert len(pick) == 1 diff --git a/changes/changelog.d/2135.feature b/changes/changelog.d/2135.feature new file mode 100644 index 000000000..d099eeb56 --- /dev/null +++ b/changes/changelog.d/2135.feature @@ -0,0 +1 @@ +Cache radio queryset into redis. New radio track endpoint for api v2 is /api/v2/radios/sessions/{radiosessionid}/tracks (#2135) diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml index b947a6dbc..fe32450ad 100644 --- a/deploy/docker-compose.yml +++ b/deploy/docker-compose.yml @@ -98,6 +98,8 @@ services: env_file: - .env image: typesense/typesense:0.24.0 + networks: + - internal volumes: - ./typesense/data:/data command: --data-dir /data --enable-cors